Add comments explaining non-padded tensors, kNoPad -> kPacked

PiperOrigin-RevId: 763352173
This commit is contained in:
Jan Wassenberg 2025-05-26 03:03:05 -07:00 committed by Copybara-Service
parent eb8a463038
commit 421a2ab8ac
2 changed files with 10 additions and 6 deletions

View File

@ -100,7 +100,7 @@ void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
struct TensorToRead {
MatPtr* mat;
BlobRange range;
// Some tensors opt out of padding via kNoPad flags.
// Some tensors opt out of padding via kPacked flags.
MatPadding padding;
};
@ -266,7 +266,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
// Enumerate all weights (negligible cost).
CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
const MatPadding padding = (t.flags & TensorArgs::kNoPad)
const MatPadding padding = (t.flags & TensorArgs::kPacked)
? MatPadding::kPacked
: MatPadding::kOdd;
size_t key_idx;

View File

@ -63,7 +63,7 @@ struct TensorArgs {
// Avoid padding tensor rows when reading. Used for some Griffin tensors
// whose index computations do not use Row() accessors.
kNoPad = 2,
kPacked = 2,
};
const int flags;
};
@ -224,7 +224,9 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
func(TENSOR_ARGS(vit.attn_out_b, kMustRead));
func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead));
func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kNoPad));
// Used as 1D MatMul bias, but has `heads + 2 * kv_heads` rows, hence
// must not be padded.
func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kPacked));
// MlpBlock.
func(TENSOR_ARGS(vit.linear_0_w, kMustRead));
func(TENSOR_ARGS(vit.linear_0_b, kMustRead));
@ -251,9 +253,11 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kNoPad));
// conv_w and gate_w are not accessed via Row(), hence must not be padded.
// Note that *biases are 1D, hence packing/padding does not matter.
func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked));
func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kNoPad));
func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked));
func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
func(TENSOR_ARGS(griffin.a, kMustRead));
}