mirror of https://github.com/google/gemma.cpp.git
Merge pull request #493 from ufownl:bugfix/compress_weights_le
PiperOrigin-RevId: 725585921
This commit is contained in:
commit
c495b25995
|
|
@ -257,8 +257,8 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Weight>
|
template <>
|
||||||
void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
|
void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
||||||
if (attn_vec_einsum_w.data() == nullptr) return;
|
if (attn_vec_einsum_w.data() == nullptr) return;
|
||||||
|
|
||||||
const size_t model_dim = layer_config.model_dim;
|
const size_t model_dim = layer_config.model_dim;
|
||||||
|
|
@ -271,7 +271,6 @@ void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
|
||||||
att_weights.SetPtr(*storage);
|
att_weights.SetPtr(*storage);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hwy::IsSame<Weight, NuqStream>()) {
|
|
||||||
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
|
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
|
hwy::AlignedFreeUniquePtr<float[]> attn_vec_einsum_w_tmp =
|
||||||
|
|
@ -301,19 +300,6 @@ void LayerWeightsPtrs<Weight>::Reshape(MatStorage* storage) {
|
||||||
/*packed_ofs=*/0, pool);
|
/*packed_ofs=*/0, pool);
|
||||||
|
|
||||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
att_weights.set_scale(attn_vec_einsum_w.scale());
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t m = 0; m < model_dim; ++m) {
|
|
||||||
Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
|
|
||||||
for (size_t h = 0; h < heads; ++h) {
|
|
||||||
hwy::CopyBytes(
|
|
||||||
attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
|
|
||||||
out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -179,7 +179,30 @@ struct LayerWeightsPtrs {
|
||||||
// Initializes att_weights from attn_vec_einsum_w, hence this must be called
|
// Initializes att_weights from attn_vec_einsum_w, hence this must be called
|
||||||
// after loading weights via ForEachTensor.
|
// after loading weights via ForEachTensor.
|
||||||
// TODO: update compression/convert_weights to bake this in.
|
// TODO: update compression/convert_weights to bake this in.
|
||||||
void Reshape(MatStorage* storage);
|
void Reshape(MatStorage* storage) {
|
||||||
|
static_assert(!hwy::IsSame<Weight, NuqStream>());
|
||||||
|
|
||||||
|
if (attn_vec_einsum_w.data() == nullptr) return;
|
||||||
|
|
||||||
|
const size_t model_dim = layer_config.model_dim;
|
||||||
|
const size_t heads = layer_config.heads;
|
||||||
|
const size_t qkv_dim = layer_config.qkv_dim;
|
||||||
|
|
||||||
|
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
|
||||||
|
if (storage != nullptr) {
|
||||||
|
storage->Allocate();
|
||||||
|
att_weights.SetPtr(*storage);
|
||||||
|
}
|
||||||
|
for (size_t m = 0; m < model_dim; ++m) {
|
||||||
|
Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
|
||||||
|
for (size_t h = 0; h < heads; ++h) {
|
||||||
|
hwy::CopyBytes(
|
||||||
|
attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
|
||||||
|
out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
att_weights.set_scale(attn_vec_einsum_w.scale());
|
||||||
|
}
|
||||||
|
|
||||||
// Used by ForEachTensor for per-layer tensors.
|
// Used by ForEachTensor for per-layer tensors.
|
||||||
#define GEMMA_CALL_FUNC(member) \
|
#define GEMMA_CALL_FUNC(member) \
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue