Add comments explaining non-padded tensors, kNoPad -> kPacked

PiperOrigin-RevId: 763352173
Author: Jan Wassenberg, 2025-05-26 03:03:05 -07:00; committed by Copybara-Service
parent eb8a463038
commit 421a2ab8ac
2 changed files with 10 additions and 6 deletions


@@ -100,7 +100,7 @@ void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
 struct TensorToRead {
   MatPtr* mat;
   BlobRange range;
-  // Some tensors opt out of padding via kNoPad flags.
+  // Some tensors opt out of padding via kPacked flags.
   MatPadding padding;
 };
@@ -266,7 +266,7 @@ void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
   // Enumerate all weights (negligible cost).
   CallT([&](const auto& weights) {
     weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
-      const MatPadding padding = (t.flags & TensorArgs::kNoPad)
+      const MatPadding padding = (t.flags & TensorArgs::kPacked)
                                      ? MatPadding::kPacked
                                      : MatPadding::kOdd;
       size_t key_idx;
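For context, a standalone sketch of the padding selection in the hunk above. The Flags/MatPadding enums and the RowStride() helper are simplified stand-ins rather than the real gemma.cpp types, and the padded-stride rule shown is only an assumption, since the actual kOdd rule is not part of this diff.

// Standalone sketch, not the real gemma.cpp API: a per-tensor kPacked flag
// selects between packed rows and padded ("odd"-stride) rows when reading.
#include <cstddef>
#include <cstdio>

enum Flags { kMustRead = 1, kPacked = 2 };  // mirrors TensorArgs::kPacked
enum class MatPadding { kPacked, kOdd };    // as in the hunk above

// Assumed stride rule for illustration only; the real kOdd rule is elsewhere.
size_t RowStride(MatPadding padding, size_t cols) {
  if (padding == MatPadding::kPacked) return cols;  // no per-row padding
  return (cols + 63) / 64 * 64 + 64;                // hypothetical padded stride
}

int main() {
  const int flags = kMustRead | kPacked;  // e.g. griffin.conv_w below
  const MatPadding padding =
      (flags & kPacked) ? MatPadding::kPacked : MatPadding::kOdd;
  std::printf("row stride for 2048 cols: %zu\n", RowStride(padding, 2048));
  return 0;
}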


@@ -63,7 +63,7 @@ struct TensorArgs {
     // Avoid padding tensor rows when reading. Used for some Griffin tensors
     // whose index computations do not use Row() accessors.
-    kNoPad = 2,
+    kPacked = 2,
   };
   const int flags;
 };
@@ -224,7 +224,9 @@ struct LayerWeightsPtrs {
     func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
     func(TENSOR_ARGS(vit.attn_out_b, kMustRead));
     func(TENSOR_ARGS(vit.qkv_einsum_w, kMustRead));
-    func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kNoPad));
+    // Used as 1D MatMul bias, but has `heads + 2 * kv_heads` rows, hence
+    // must not be padded.
+    func(TENSOR_ARGS(vit.qkv_einsum_b, kMustRead | TensorArgs::kPacked));
     // MlpBlock.
     func(TENSOR_ARGS(vit.linear_0_w, kMustRead));
     func(TENSOR_ARGS(vit.linear_0_b, kMustRead));
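For context, a small worked example of the new comment on qkv_einsum_b: the tensor has `heads + 2 * kv_heads` rows but is consumed as one contiguous 1D bias, so any per-row padding would leave gaps inside the flattened vector. The sizes below are illustrative assumptions, not the real model configuration.

// Sketch only: a 2D bias whose rows are reused as one flat 1D vector works
// iff the rows are packed (row stride == number of columns).
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  const size_t heads = 8, kv_heads = 4, qkv_dim = 256;  // made-up sizes
  const size_t rows = heads + 2 * kv_heads;             // Q, K and V bias rows

  // Packed: element (r, c) lives at r * qkv_dim + c, so all rows form one
  // contiguous run of rows * qkv_dim values, usable directly as a 1D bias.
  std::vector<float> packed(rows * qkv_dim, 0.0f);
  assert(&packed[1 * qkv_dim] == &packed[0] + qkv_dim);  // no gap between rows

  // Padded: each row occupies `stride` > qkv_dim floats. The same flat index
  // r * qkv_dim + c no longer points at element (r, c), so reinterpreting the
  // storage as a 1D bias would read the wrong values.
  const size_t stride = qkv_dim + 16;                    // hypothetical stride
  std::vector<float> padded(rows * stride, 0.0f);
  assert(&padded[1 * stride] != &padded[0] + qkv_dim);   // gap between rows
  return 0;
}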
@@ -251,9 +253,11 @@ struct LayerWeightsPtrs {
     func(TENSOR_ARGS(griffin.linear_y_biases, kMustRead));
     func(TENSOR_ARGS(griffin.linear_out_w, kMustRead));
     func(TENSOR_ARGS(griffin.linear_out_biases, kMustRead));
-    func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kNoPad));
+    // conv_w and gate_w are not accessed via Row(), hence must not be padded.
+    // Note that *biases are 1D, hence packing/padding does not matter.
+    func(TENSOR_ARGS(griffin.conv_w, kMustRead | TensorArgs::kPacked));
     func(TENSOR_ARGS(griffin.conv_biases, kMustRead));
-    func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kNoPad));
+    func(TENSOR_ARGS(griffin.gate_w, kMustRead | TensorArgs::kPacked));
     func(TENSOR_ARGS(griffin.gate_biases, kMustRead));
     func(TENSOR_ARGS(griffin.a, kMustRead));
   }
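For context, a sketch of the distinction behind the new conv_w/gate_w comment, using a hypothetical Mat type rather than the real MatPtr API: indexing through a Row(r) accessor is stride-agnostic, while manual r * cols + c indexing is only correct for packed rows, hence the kPacked flag.

// Sketch with a hypothetical Mat type (not the real MatPtr API): Row() hides
// the row stride, whereas manual r * cols + c indexing assumes packed rows.
#include <cstddef>
#include <vector>

struct Mat {
  size_t cols, stride;  // stride == cols only for packed (kPacked) storage
  std::vector<float> data;
  Mat(size_t rows, size_t c, size_t s) : cols(c), stride(s), data(rows * s) {}
  float* Row(size_t r) { return data.data() + r * stride; }
};

int main() {
  Mat padded(/*rows=*/4, /*cols=*/10, /*stride=*/16);
  padded.Row(2)[3] = 1.0f;  // correct for any stride

  // Per the comment above, conv_w / gate_w are indexed without Row(), i.e.:
  //   padded.data[2 * padded.cols + 3]   // wrong element once stride != cols
  // With packed storage (stride == cols) both forms address the same float:
  Mat packed(/*rows=*/4, /*cols=*/10, /*stride=*/10);
  packed.data[2 * packed.cols + 3] = 1.0f;  // same element as packed.Row(2)[3]
  return 0;
}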