From 421a2ab8acca2dcfee31cbf97c90070ec97cac1d Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Mon, 26 May 2025 03:03:05 -0700
Subject: [PATCH] Add comments explaining non-padded tensors, kNoPad -> kPacked

PiperOrigin-RevId: 763352173
---
 gemma/weights.cc |  4 ++--
 gemma/weights.h  | 12 ++++++++----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/gemma/weights.cc b/gemma/weights.cc
index 4e8325d..509a0ee 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -100,7 +100,7 @@ void LayerWeightsPtrs::Fixup(std::vector& mat_owners) {
 struct TensorToRead {
   MatPtr* mat;
   BlobRange range;
-  // Some tensors opt out of padding via kNoPad flags.
+  // Some tensors opt out of padding via kPacked flags.
   MatPadding padding;
 };
 
@@ -266,7 +266,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
   // Enumerate all weights (negligible cost).
   CallT([&](const auto& weights) {
     weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
-      const MatPadding padding = (t.flags & TensorArgs::kNoPad)
+      const MatPadding padding = (t.flags & TensorArgs::kPacked)
                                      ? MatPadding::kPacked
                                      : MatPadding::kOdd;
       size_t key_idx;
diff --git a/gemma/weights.h b/gemma/weights.h
index 80341d0..7f6a12c 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -63,7 +63,7 @@ struct TensorArgs {
 
     // Avoid padding tensor rows when reading. Used for some Griffin tensors
     // whose index computations do not use Row() accessors.
-    kNoPad = 2,
+    kPacked = 2,
   };
   const int flags;
 };
@@ -224,7 +224,9 @@ struct LayerWeightsPtrs {
     func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
     func(TENSOR_ARGS(vit.attn_out_b, kMustRead));
     func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead));
-    func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kNoPad));
+    // Used as 1D MatMul bias, but has `heads + 2 * kv_heads` rows, hence
+    // must not be padded.
+    func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kPacked));
     // MlpBlock.
     func(TENSOR_ARGS(vit.linear_0_w, kMustRead));
     func(TENSOR_ARGS(vit.linear_0_b, kMustRead));
@@ -251,9 +253,11 @@ struct LayerWeightsPtrs {
     func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
     func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
     func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
-    func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kNoPad));
+    // conv_w and gate_w are not accessed via Row(), hence must not be padded.
+    // Note that *biases are 1D, hence packing/padding does not matter.
+    func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked));
     func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
-    func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kNoPad));
+    func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked));
     func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
     func(TENSOR_ARGS(griffin.a, kMustRead));
   }