Fix gcc build error and gemma3 crash, thanks @ufownl, fixes #551

PiperOrigin-RevId: 755729478
This commit is contained in:
Jan Wassenberg 2025-05-07 00:58:45 -07:00 committed by Copybara-Service
parent c8d92948f4
commit e9ecb7794d
2 changed files with 6 additions and 2 deletions

View File

@ -310,6 +310,11 @@ struct LayerWeightsPtrs {
// after reading weights via `ForEachTensor`. // after reading weights via `ForEachTensor`.
// TODO: update compression/convert_weights to bake this in. // TODO: update compression/convert_weights to bake this in.
void Reshape() { void Reshape() {
// We only have/allocate this tensor for Gemma layers.
HWY_ASSERT(att_weights.HasPtr() ==
(layer_config.type == LayerAttentionType::kGemma));
if (!att_weights.HasPtr()) return;
// NUQ is handled by a specialization in weights.cc. // NUQ is handled by a specialization in weights.cc.
HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ); HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);
@ -318,7 +323,6 @@ struct LayerWeightsPtrs {
const size_t qkv_dim = layer_config.qkv_dim; const size_t qkv_dim = layer_config.qkv_dim;
// Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim]. // Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim].
HWY_ASSERT(att_weights.HasPtr());
HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType()); HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
HWY_ASSERT(att_weights.Rows() == model_dim); HWY_ASSERT(att_weights.Rows() == model_dim);
HWY_ASSERT(att_weights.Cols() == heads * qkv_dim); HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);

View File

@ -105,7 +105,7 @@ class BlobStore {
// Returns the end of the directory, including padding, which is also the // Returns the end of the directory, including padding, which is also the
// start of the first payload. `num_blobs` is `NumBlobs()` if the header is // start of the first payload. `num_blobs` is `NumBlobs()` if the header is
// already available, otherwise the number of blobs to be written. // already available, otherwise the number of blobs to be written.
static constexpr size_t PaddedDirEnd(size_t num_blobs) { static HWY_CXX17_CONSTEXPR size_t PaddedDirEnd(size_t num_blobs) {
HWY_ASSERT(num_blobs < kMaxBlobs); HWY_ASSERT(num_blobs < kMaxBlobs);
// Per blob, a key and offset/size. // Per blob, a key and offset/size.
return RoundUpToAlign(sizeof(Header) + 2 * kU128Bytes * num_blobs); return RoundUpToAlign(sizeof(Header) + 2 * kU128Bytes * num_blobs);