Merge pull request #493 from ufownl:bugfix/compress_weights_le

PiperOrigin-RevId: 725585921
Copybara-Service 2025-02-11 05:10:13 -08:00
commit c495b25995
2 changed files with 46 additions and 37 deletions
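In short: the generic LayerWeightsPtrs<Weight>::Reshape, which interleaves the per-head attention output weights by byte-copying rows, moves from the .cc file into the struct definition in the header, guarded by a static_assert that Weight is not NuqStream. The .cc file keeps only the former NuqStream branch, now a full specialization, which goes through float temporaries and the compression API rather than raw byte copies.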


@@ -257,8 +257,8 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
   }
 }
 
-template <class Weight>
-void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
+template <>
+void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
   if (attn_vec_einsum_w.data() == nullptr) return;
   const size_t model_dim = layer_config.model_dim;
@@ -271,7 +271,6 @@ void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
     att_weights.SetPtr(*storage);
   }
-  if (hwy::IsSame<Weight, NuqStream>()) {
   const hwy::HWY_NAMESPACE::ScalableTag<float> df;
   hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
@@ -301,19 +300,6 @@ void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
                /*packed_ofs=*/0, pool);
   att_weights.set_scale(attn_vec_einsum_w.scale());
-    return;
-  }
-
-  for (size_t m = 0; m < model_dim; ++m) {
-    Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
-    for (size_t h = 0; h < heads; ++h) {
-      hwy::CopyBytes(
-          attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
-          out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
-    }
-  }
-  att_weights.set_scale(attn_vec_einsum_w.scale());
 }
 
 }  // namespace gcpp
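Both paths implement the same layout change: attn_vec_einsum_w is stored as [heads, model_dim, qkv_dim], and att_weights wants [model_dim, heads * qkv_dim] so that one contiguous row covers all heads. Below is a minimal standalone sketch of that index math on plain floats; it is hypothetical and self-contained, not the gemma.cpp API (MatStorage, scales, and HWY_RESTRICT are omitted):

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

// Rearranges w from [heads, model_dim, qkv_dim] to
// [model_dim, heads * qkv_dim]: output row m concatenates row m of every
// head. Mirrors the CopyBytes loop removed here and re-added in the header.
std::vector<float> ReshapeAttWeights(const std::vector<float>& w, size_t heads,
                                     size_t model_dim, size_t qkv_dim) {
  std::vector<float> out(model_dim * heads * qkv_dim);
  for (size_t m = 0; m < model_dim; ++m) {
    float* out_row = out.data() + m * heads * qkv_dim;
    for (size_t h = 0; h < heads; ++h) {
      // Source row (h, m) of length qkv_dim becomes chunk h of output row m.
      std::memcpy(out_row + h * qkv_dim,
                  w.data() + (h * model_dim + m) * qkv_dim,
                  qkv_dim * sizeof(float));
    }
  }
  return out;
}

int main() {
  // 2 heads, model_dim 3, qkv_dim 4: row 0 of the output is head 0 row 0
  // (values 0..3) followed by head 1 row 0 (values 12..15).
  std::vector<float> w(2 * 3 * 4);
  for (size_t i = 0; i < w.size(); ++i) w[i] = static_cast<float>(i);
  for (float v : ReshapeAttWeights(w, 2, 3, 4)) std::printf("%g ", v);
  return 0;
}

NuqStream cannot take this byte-copy route: as the retained branch suggests, a NUQ-compressed tensor is a packed stream whose rows are not independently addressable fixed-size elements, so the specialization decompresses to a float buffer, permutes, and re-compresses into att_weights.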


@@ -179,7 +179,30 @@ struct LayerWeightsPtrs {
   // Initializes att_weights from attn_vec_einsum_w, hence this must be called
   // after loading weights via ForEachTensor.
   // TODO: update compression/convert_weights to bake this in.
-  void Reshape(MatStorage* storage);
+  void Reshape(MatStorage* storage) {
+    static_assert(!hwy::IsSame<Weight, NuqStream>());
+    if (attn_vec_einsum_w.data() == nullptr) return;
+
+    const size_t model_dim = layer_config.model_dim;
+    const size_t heads = layer_config.heads;
+    const size_t qkv_dim = layer_config.qkv_dim;
+
+    // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
+    if (storage != nullptr) {
+      storage->Allocate();
+      att_weights.SetPtr(*storage);
+    }
+
+    for (size_t m = 0; m < model_dim; ++m) {
+      Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
+      for (size_t h = 0; h < heads; ++h) {
+        hwy::CopyBytes(
+            attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
+            out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
+      }
+    }
+    att_weights.set_scale(attn_vec_einsum_w.scale());
+  }
 
   // Used by ForEachTensor for per-layer tensors.
 #define GEMMA_CALL_FUNC(member) \
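A note on the dispatch pattern: the header now defines the generic member inline and uses static_assert to forbid instantiating it for NuqStream, whose full specialization lives out of line in the .cc file. The following self-contained sketch shows the same pattern with hypothetical names (LayerPtrs is illustrative, not the real struct):

#include <cstdio>
#include <type_traits>

struct NuqStream {};  // stand-in for the packed quantized type

template <class Weight>
struct LayerPtrs {
  // Generic inline path for plain element types. The assert only fires if
  // this body is instantiated, so the specialization below remains legal.
  void Reshape() {
    static_assert(!std::is_same_v<Weight, NuqStream>,
                  "NuqStream uses the out-of-line specialization");
    std::puts("byte-copy reshape");
  }
};

// In the real code this definition would sit in a .cc file, with only a
// declaration of the specialization visible to other callers.
template <>
void LayerPtrs<NuqStream>::Reshape() {
  std::puts("decompress, permute, recompress");
}

int main() {
  LayerPtrs<float>{}.Reshape();      // generic inline path
  LayerPtrs<NuqStream>{}.Reshape();  // full specialization
  return 0;
}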