mirror of https://github.com/google/gemma.cpp.git
Merge pull request #493 from ufownl:bugfix/compress_weights_le
PiperOrigin-RevId: 725585921
This commit is contained in:
commit
c495b25995
|
|
@ -257,8 +257,8 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
|
|||
}
|
||||
}
|
||||
|
||||
template <class Weight>
|
||||
void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
|
||||
template <>
|
||||
void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
||||
if (attn_vec_einsum_w.data() == nullptr) return;
|
||||
|
||||
const size_t model_dim = layer_config.model_dim;
|
||||
|
|
@ -271,7 +271,6 @@ void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
|
|||
att_weights.SetPtr(*storage);
|
||||
}
|
||||
|
||||
if (hwy::IsSame<Weight, NuqStream>()) {
|
||||
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
|
||||
|
||||
hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
|
||||
|
|
@ -301,19 +300,6 @@ void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
|
|||
/*packed_ofs=*/0, pool);
|
||||
|
||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t m = 0; m < model_dim; ++m) {
|
||||
Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
|
||||
for (size_t h = 0; h < heads; ++h) {
|
||||
hwy::CopyBytes(
|
||||
attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
|
||||
out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
|
||||
}
|
||||
}
|
||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -179,7 +179,30 @@ struct LayerWeightsPtrs {
|
|||
// Initializes att_weights from attn_vec_einsum_w, hence this must be called
|
||||
// after loading weights via ForEachTensor.
|
||||
// TODO: update compression/convert_weights to bake this in.
|
||||
void Reshape(MatStorage* storage);
|
||||
void Reshape(MatStorage* storage) {
|
||||
static_assert(!hwy::IsSame<Weight, NuqStream>());
|
||||
|
||||
if (attn_vec_einsum_w.data() == nullptr) return;
|
||||
|
||||
const size_t model_dim = layer_config.model_dim;
|
||||
const size_t heads = layer_config.heads;
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
|
||||
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
|
||||
if (storage != nullptr) {
|
||||
storage->Allocate();
|
||||
att_weights.SetPtr(*storage);
|
||||
}
|
||||
for (size_t m = 0; m < model_dim; ++m) {
|
||||
Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
|
||||
for (size_t h = 0; h < heads; ++h) {
|
||||
hwy::CopyBytes(
|
||||
attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
|
||||
out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
|
||||
}
|
||||
}
|
||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
||||
}
|
||||
|
||||
// Used by ForEachTensor for per-layer tensors.
|
||||
#define GEMMA_CALL_FUNC(member) \
|
||||
|
|
|
|||
Loading…
Reference in New Issue