Split W1/W2 as a load-time preprocess.

Remove kOnlyAllocate - no longer used. Rename ReadOrAllocate -> ReadFromBlobs.
Rename Reshape -> Fixup to reflect the new scope.
Remove no longer used ShrinkRows.

This simplifies gemma-inl and is a prerequisite for removing ConstMat
(whose .ofs was previously used for merged tensors)

PiperOrigin-RevId: 758214083
This commit is contained in:
Jan Wassenberg 2025-05-13 07:39:16 -07:00 committed by Copybara-Service
parent 2038dfd9cc
commit 8a312e9b89
8 changed files with 169 additions and 107 deletions

View File

@ -100,7 +100,7 @@ TEST(OptimizeTest, GradientDescent) {
}; };
gemma.MutableWeights().RandInit(1.0f, gen); gemma.MutableWeights().RandInit(1.0f, gen);
gemma.MutableWeights().Reshape(pool); gemma.MutableWeights().Fixup(pool);
printf("Initial weights:\n"); printf("Initial weights:\n");
gemma.MutableWeights().LogWeightStatsF32(); gemma.MutableWeights().LogWeightStatsF32();
@ -129,7 +129,7 @@ TEST(OptimizeTest, GradientDescent) {
CrossEntropyLossBackwardPass(prompt, *gemma.Weights().GetF32(), forward, CrossEntropyLossBackwardPass(prompt, *gemma.Weights().GetF32(), forward,
*grad.GetF32(), backward, inv_timescale, *grad.GetF32(), backward, inv_timescale,
pool); pool);
gemma.MutableWeights().Reshape(pool); gemma.MutableWeights().Fixup(pool);
num_ok += verify(prompt) ? 1 : 0; num_ok += verify(prompt) ? 1 : 0;
} }
total_loss /= kBatchSize; total_loss /= kBatchSize;

View File

@ -246,10 +246,6 @@ class GemmaAttention {
const size_t heads = layer_config_.heads; const size_t heads = layer_config_.heads;
const size_t kv_heads = layer_config_.kv_heads; const size_t kv_heads = layer_config_.kv_heads;
using WeightT = typename decltype(layer_weights_.qkv_einsum_w)::T;
ConstMat<WeightT> w_q1(layer_weights_.qkv_einsum_w.HasPtr()
? layer_weights_.qkv_einsum_w
: layer_weights_.qkv_einsum_w1);
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim, // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows. // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
// We must shrink to the actual size because MatMul verifies // We must shrink to the actual size because MatMul verifies
@ -257,23 +253,16 @@ class GemmaAttention {
// rows are used. Otherwise, `QStride() == qkv_dim` and KV will be // rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
// computed in the second MatMul. // computed in the second MatMul.
const size_t w1_rows = heads * layer_config_.QStride(); const size_t w1_rows = heads * layer_config_.QStride();
w_q1.ShrinkRows(w1_rows); HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows);
MatMul(activations_.pre_att_rms_out, w_q1, MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
/*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q)); /*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q));
if (is_mha_) { if (is_mha_) {
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already. // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
} else { } else {
decltype(w_q1) w_q2(layer_weights_.qkv_einsum_w.HasPtr()
? layer_weights_.qkv_einsum_w
: layer_weights_.qkv_einsum_w2);
if (layer_weights_.qkv_einsum_w.HasPtr()) {
// Skip first half of the matrix.
w_q2.ofs = w_q2.Row(w1_rows);
}
// KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v). // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim; const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
w_q2.ShrinkRows(w_rows_kv_cols); HWY_DASSERT(layer_weights_.qkv_einsum_w2.Rows() == w_rows_kv_cols);
// Single query and no wraparound means we can use a matmul and write // Single query and no wraparound means we can use a matmul and write
// directly into the KV cache with a stride of cache_pos_size_. // directly into the KV cache with a stride of cache_pos_size_.
@ -284,7 +273,7 @@ class GemmaAttention {
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
RowPtrF kv_rows(kv, w_rows_kv_cols); RowPtrF kv_rows(kv, w_rows_kv_cols);
kv_rows.SetStride(cache_pos_size_); kv_rows.SetStride(cache_pos_size_);
MatMul(activations_.pre_att_rms_out, w_q2, MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
/*add=*/nullptr, *activations_.env, kv_rows); /*add=*/nullptr, *activations_.env, kv_rows);
} else { } else {
// Proceed row by row because there will be wraparound. // Proceed row by row because there will be wraparound.
@ -299,7 +288,8 @@ class GemmaAttention {
const size_t kv_offset = const size_t kv_offset =
cache_pos * cache_pos_size_ + layer_ * cache_layer_size_; cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
MatVec(w_q2, w_q2.ofs, w_rows_kv_cols, model_dim, x, kv, pool_); MatVec(layer_weights_.qkv_einsum_w2, 0, w_rows_kv_cols, model_dim, x,
kv, pool_);
} }
} }
} // !is_mha_ } // !is_mha_
@ -825,32 +815,11 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
const float* output_bias = const float* output_bias =
add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr; add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
// Define slightly more readable names for the weights and activations.
auto hidden_activations = RowPtrFromMat(activations.C1);
auto multiplier = RowPtrFromMat(activations.C2);
auto ffw_out = RowPtrFromMat(activations.ffw_out);
using WeightT = typename decltype(layer_weights->gating_einsum_w)::T;
// gating_einsum_w holds two half-matrices. We plan to change the importer to
// avoid this confusion by splitting into gating_einsum_w1 and
// gating_einsum_w2. TODO: move into Reshape().
const bool split = layer_weights->gating_einsum_w.HasPtr();
ConstMat<WeightT> w1(split ? layer_weights->gating_einsum_w
: layer_weights->gating_einsum_w1);
ConstMat<WeightT> w2(split ? layer_weights->gating_einsum_w
: layer_weights->gating_einsum_w2);
if (split) {
w2.ofs = w2.Row(ffh_hidden_dim);
// Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
w1.ShrinkRows(ffh_hidden_dim);
w2.ShrinkRows(ffh_hidden_dim);
}
// Compute the hidden layer activations. // Compute the hidden layer activations.
MatMul(activations.pre_ffw_rms_out, w1, bias1, *activations.env, MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, bias1,
hidden_activations); *activations.env, RowPtrFromMat(activations.C1));
MatMul(activations.pre_ffw_rms_out, w2, bias2, *activations.env, multiplier); MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, bias2,
*activations.env, RowPtrFromMat(activations.C2));
// Activation (Gelu) and maybe multiply by gate. Store activations in act. // Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_weights->layer_config.activation, activations.C1, ActivationBatched(layer_weights->layer_config.activation, activations.C1,
@ -858,7 +827,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
// Hidden layer -> output layer. // Hidden layer -> output layer.
MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env, MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
ffw_out); RowPtrFromMat(activations.ffw_out));
} }
// Same as FFWNoVit, but with different layer_weights members and no second // Same as FFWNoVit, but with different layer_weights members and no second

View File

@ -57,7 +57,7 @@ Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
model_(*reader_, loader.tokenizer, loader.wrapping), model_(*reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config().weight), weights_(model_.Config().weight),
chat_template_(model_.Tokenizer(), model_.Config().model) { chat_template_(model_.Tokenizer(), model_.Config().model) {
weights_.ReadOrAllocate(model_, *reader_, env_.ctx.pools.Pool()); weights_.ReadFromBlobs(model_, *reader_, env_.ctx.pools.Pool());
reader_.reset(); reader_.reset();
} }

View File

@ -37,13 +37,15 @@
#include "hwy/profiler.h" #include "hwy/profiler.h"
#include "hwy/stats.h" #include "hwy/stats.h"
// TODO: move into foreach_target; this is only used for NUQ Reshape. // TODO: move into foreach_target; this is only used for NUQ Fixup.
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
namespace gcpp { namespace gcpp {
template <> static void InitAttWeightsNUQ(const LayerConfig& layer_config,
void LayerWeightsPtrs<NuqStream>::Reshape() { MatPtrT<NuqStream>& attn_vec_einsum_w,
MatPtrT<NuqStream>& att_weights,
MatOwners& mat_owners) {
if (!attn_vec_einsum_w.HasPtr()) return; if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ); HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
@ -83,6 +85,16 @@ void LayerWeightsPtrs<NuqStream>::Reshape() {
att_weights.SetScale(attn_vec_einsum_w.Scale()); att_weights.SetScale(attn_vec_einsum_w.Scale());
} }
static void SplitW1NUQ(const LayerConfig& layer_config) {
// TODO(janwas): implement.
}
template <>
void LayerWeightsPtrs<NuqStream>::Fixup(MatOwners& mat_owners) {
InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners);
SplitW1NUQ(layer_config);
}
// Aborts on error. // Aborts on error.
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader, static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
const std::vector<BlobRange>& ranges, const std::vector<BlobRange>& ranges,
@ -134,8 +146,8 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
reader.ReadAll(pool); reader.ReadAll(pool);
} }
void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader, void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from. // List of tensors to read/map, and where from.
std::vector<MatPtr*> mats; std::vector<MatPtr*> mats;
std::vector<BlobRange> ranges; std::vector<BlobRange> ranges;
@ -148,10 +160,6 @@ void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
// Enumerate all weights (negligible cost). // Enumerate all weights (negligible cost).
CallT([&](const auto& weights) { CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
if (t.flags & TensorArgs::kOnlyAllocate) {
mat_owners_.AllocateFor(t.mat, padding);
return;
}
size_t key_idx; size_t key_idx;
if (model.FindAndUpdateMatPtr(t.mat, key_idx)) { if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
mats.push_back(&t.mat); mats.push_back(&t.mat);
@ -165,7 +173,7 @@ void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
MapOrRead(mats, reader, ranges, mat_owners_, padding, pool); MapOrRead(mats, reader, ranges, mat_owners_, padding, pool);
Reshape(pool); Fixup(pool);
} }
// Allocates `*_weights_`, but not yet the tensors inside. This is split out // Allocates `*_weights_`, but not yet the tensors inside. This is split out
@ -218,6 +226,7 @@ void WeightsOwner::LogWeightStatsF32() {
HWY_ASSERT(weight_type_ == Type::kF32); // Only for float weights. HWY_ASSERT(weight_type_ == Type::kF32); // Only for float weights.
float_weights_->ForEachTensor( float_weights_->ForEachTensor(
nullptr, nullptr, [&total_weights](const TensorArgs& t) { nullptr, nullptr, [&total_weights](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
if (t.mat.Scale() != 1.0f) { if (t.mat.Scale() != 1.0f) {
printf("[scale=%f] ", t.mat.Scale()); printf("[scale=%f] ", t.mat.Scale());
} }
@ -238,9 +247,9 @@ void WeightsOwner::LogWeightStatsF32() {
printf("%-20s %12zu\n", "Total", total_weights); printf("%-20s %12zu\n", "Total", total_weights);
} }
void WeightsOwner::Reshape(hwy::ThreadPool& pool) { void WeightsOwner::Fixup(hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.Reshape"); PROFILER_ZONE("Startup.Fixup");
CallT([&pool](const auto& weights) { weights->Reshape(pool); }); CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); });
} }
std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter( std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
@ -248,7 +257,6 @@ std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
std::vector<uint32_t> serialized_mat_ptrs; std::vector<uint32_t> serialized_mat_ptrs;
CallT([&](const auto& weights) { CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
if (t.flags & TensorArgs::kOnlyAllocate) return;
if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return; if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name()); HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes()); writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes());

View File

@ -21,6 +21,7 @@
#include <complex> #include <complex>
#include <memory> #include <memory>
#include <mutex> // NOLINT
#include <random> #include <random>
#include <string> #include <string>
#include <vector> #include <vector>
@ -43,10 +44,10 @@ struct TensorArgs {
// `flags` is a combination of zero or more `Flags`. // `flags` is a combination of zero or more `Flags`.
TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2, TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
int flags) int flags)
: mat(mat), other_mat1(other_mat1), other_mat2(other_mat2), flags(flags) { : mat(mat),
// Does not make sense to combine both flags. other_mat1(other_mat1),
HWY_ASSERT(flags != (kMaybeRead | kOnlyAllocate)); other_mat2(other_mat2),
} flags(flags) {}
MatPtr& mat; MatPtr& mat;
const MatPtr* other_mat1; // either/both can be nullptr. const MatPtr* other_mat1; // either/both can be nullptr.
@ -59,8 +60,6 @@ struct TensorArgs {
// Not an error if the tensor is not present in the file. For example, // Not an error if the tensor is not present in the file. For example,
// the _w1/_w2 tensors are not always present. // the _w1/_w2 tensors are not always present.
kMaybeRead = 1, kMaybeRead = 1,
// Do not attempt to read, just allocate the tensor. Used for `Reshape`.
kOnlyAllocate = 2,
}; };
const int flags; const int flags;
}; };
@ -87,7 +86,6 @@ struct LayerWeightsPtrs {
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config, LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
const TensorInfoRegistry& tensors) const TensorInfoRegistry& tensors)
: suffix_(LayerSuffix(layer_idx)), : suffix_(LayerSuffix(layer_idx)),
attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
qkv_einsum_w(Concat("qkv_ein", suffix_), tensors), qkv_einsum_w(Concat("qkv_ein", suffix_), tensors),
qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors), qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors),
qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors), qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors),
@ -127,9 +125,13 @@ struct LayerWeightsPtrs {
post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors), post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors),
ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors), ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors),
ffw_output_biases(Concat("ffw_out_b", suffix_), tensors), ffw_output_biases(Concat("ffw_out_b", suffix_), tensors),
attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
att_weights(Concat("att_w", suffix_), tensors), att_weights(Concat("att_w", suffix_), tensors),
key_norm_scale(Concat("key_norm", suffix_), tensors), key_norm_scale(Concat("key_norm", suffix_), tensors),
query_norm_scale(Concat("query_norm", suffix_), tensors), query_norm_scale(Concat("query_norm", suffix_), tensors),
layer_config(config) {} layer_config(config) {}
~LayerWeightsPtrs() = default; ~LayerWeightsPtrs() = default;
@ -144,10 +146,8 @@ struct LayerWeightsPtrs {
hwy::If<hwy::IsSame<Weight, double>(), double, hwy::If<hwy::IsSame<Weight, double>(), double,
hwy::If<IsF32<Weight>(), float, BF16>>>; hwy::If<IsF32<Weight>(), float, BF16>>>;
MatPtrT<Weight> attn_vec_einsum_w; // Files either have qkv_einsum_w with 2 stacked matrices or separate
// qkv_einsum_w holds 2 different matrices, which may be separated out. // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
// On reading, which is used depends on what is in the file.
// At inference, the one with a non-null ptr is used.
MatPtrT<Weight> qkv_einsum_w; MatPtrT<Weight> qkv_einsum_w;
MatPtrT<Weight> qkv_einsum_w1; MatPtrT<Weight> qkv_einsum_w1;
MatPtrT<Weight> qkv_einsum_w2; MatPtrT<Weight> qkv_einsum_w2;
@ -185,9 +185,8 @@ struct LayerWeightsPtrs {
MatPtrT<WeightF32OrBF16> layer_norm_1_scale; MatPtrT<WeightF32OrBF16> layer_norm_1_scale;
} vit; } vit;
// gating_einsum_w holds 2 different matrices, which may be separated out. // Files either have gating_einsum_w with 2 stacked matrices or separate
// On reading, which is used depends on what is in the file. // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
// At inference, the one with a non-null ptr is used.
MatPtrT<Weight> gating_einsum_w; MatPtrT<Weight> gating_einsum_w;
MatPtrT<Weight> gating_einsum_w1; MatPtrT<Weight> gating_einsum_w1;
MatPtrT<Weight> gating_einsum_w2; MatPtrT<Weight> gating_einsum_w2;
@ -201,7 +200,8 @@ struct LayerWeightsPtrs {
MatPtrT<float> ffw_gating_biases; MatPtrT<float> ffw_gating_biases;
MatPtrT<float> ffw_output_biases; MatPtrT<float> ffw_output_biases;
MatPtrT<Weight> att_weights; // For Reshape(); kOnlyAllocate. MatPtrT<Weight> attn_vec_einsum_w; // Use att_weights instead of this.
MatPtrT<Weight> att_weights; // Use this instead of attn_vec_einsum_w.
MatPtrT<WeightF32OrBF16> key_norm_scale; MatPtrT<WeightF32OrBF16> key_norm_scale;
MatPtrT<WeightF32OrBF16> query_norm_scale; MatPtrT<WeightF32OrBF16> query_norm_scale;
@ -234,10 +234,10 @@ struct LayerWeightsPtrs {
return; return;
} }
if (layer_config.type == LayerAttentionType::kGemma) { if (layer_config.type == LayerAttentionType::kGemma) {
// Not read, will be filled by Reshape() from `attn_vec_einsum_w`. // Either read from file, or allocated during Fixup().
func(TENSOR_ARGS(att_weights, kOnlyAllocate)); func(TENSOR_ARGS(att_weights, kMaybeRead));
func(TENSOR_ARGS(attn_vec_einsum_w, kMustRead)); func(TENSOR_ARGS(attn_vec_einsum_w, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w, kMustRead)); func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead)); func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead)); func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
} else { } else {
@ -254,7 +254,7 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(griffin.a, kMustRead)); func(TENSOR_ARGS(griffin.a, kMustRead));
} }
{ {
func(TENSOR_ARGS(gating_einsum_w, kMustRead)); func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead)); func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead));
func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead)); func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead));
func(TENSOR_ARGS(linear_w, kMustRead)); func(TENSOR_ARGS(linear_w, kMustRead));
@ -306,14 +306,25 @@ struct LayerWeightsPtrs {
}); });
} }
// Initializes att_weights from `attn_vec_einsum_w`, hence this must be called // Must be called after reading weights via `ForEachTensor`.
// after reading weights via `ForEachTensor`. // TODO: exporters should bake this into the weights already.
// TODO: update compression/convert_weights to bake this in. // WARNING: called from multiple threads; `mat_owners` requires a lock.
void Reshape() { void Fixup(MatOwners& mat_owners) {
// We only have/allocate this tensor for Gemma layers. InitAttWeights(mat_owners);
HWY_ASSERT(att_weights.HasPtr() == SplitW1();
(layer_config.type == LayerAttentionType::kGemma)); SplitAttW1();
if (!att_weights.HasPtr()) return; }
private:
// Copies att_weights from `attn_vec_einsum_w`.
void InitAttWeights(MatOwners& mat_owners) {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// Files must have one or the other, and backprop/ allocates both.
HWY_ASSERT(attn_vec_einsum_w.HasPtr() || att_weights.HasPtr());
// Done if we already read the transposed tensor.
if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;
// NUQ is handled by a specialization in weights.cc. // NUQ is handled by a specialization in weights.cc.
HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ); HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);
@ -326,9 +337,15 @@ struct LayerWeightsPtrs {
HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType()); HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
HWY_ASSERT(att_weights.Rows() == model_dim); HWY_ASSERT(att_weights.Rows() == model_dim);
HWY_ASSERT(att_weights.Cols() == heads * qkv_dim); HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
HWY_ASSERT(attn_vec_einsum_w.HasPtr());
HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim); HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim); HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.AllocateFor(att_weights, MatPadding::kOdd);
}
const size_t T_bytes = att_weights.ElementBytes(); const size_t T_bytes = att_weights.ElementBytes();
for (size_t m = 0; m < model_dim; ++m) { for (size_t m = 0; m < model_dim; ++m) {
uint8_t* HWY_RESTRICT out_row = uint8_t* HWY_RESTRICT out_row =
@ -340,6 +357,77 @@ struct LayerWeightsPtrs {
} }
att_weights.SetScale(attn_vec_einsum_w.Scale()); att_weights.SetScale(attn_vec_einsum_w.Scale());
} }
// For FFN. Fast, only updates pointers.
void SplitW1() {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// Files have both or neither of w1 and w2, and backprop/ allocates both.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
// w is mutually exclusive with w1 and w2 in the file, but backprop/
// allocates both, so we can only rule out both being null.
HWY_ASSERT(gating_einsum_w.HasPtr() || gating_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that they are not
// necessarily the same type.
if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
// Cols are the model_dim but we don't have ModelConfig here.
HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols());
HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols());
const size_t stride = gating_einsum_w.Stride();
gating_einsum_w1.SetPtr(gating_einsum_w.Row(0), stride);
gating_einsum_w2.SetPtr(gating_einsum_w.Row(ff_hidden_dim), stride);
gating_einsum_w1.SetType(gating_einsum_w.GetType());
gating_einsum_w2.SetType(gating_einsum_w.GetType());
gating_einsum_w1.SetScale(gating_einsum_w.Scale());
gating_einsum_w2.SetScale(gating_einsum_w.Scale());
// Do not invalidate gating_einsum_w: backprop/ calls this repeatedly.
}
// For attention, which might not have a w2. Fast, only updates pointers.
void SplitAttW1() {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// w is mutually exclusive with w1 in the file, but backprop/ allocates
// both, so we can only rule out both being null.
HWY_ASSERT(qkv_einsum_w.HasPtr() || qkv_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that w2 does not exist for
// MHA, and otherwise might not be the same type.
if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
const size_t w1_rows = layer_config.heads * layer_config.QStride();
if (layer_config.IsMHA()) { // MHA only requires w1.
qkv_einsum_w1 = qkv_einsum_w;
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
return;
}
const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
// Cols are the model_dim but we don't have ModelConfig here.
HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols());
HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols());
const size_t stride = qkv_einsum_w.Stride();
qkv_einsum_w1.SetPtr(qkv_einsum_w.Row(0), stride);
qkv_einsum_w2.SetPtr(qkv_einsum_w.Row(w1_rows), stride);
qkv_einsum_w1.SetType(qkv_einsum_w.GetType());
qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
// Do not invalidate qkv_einsum_w: backprop/ calls this repeatedly.
}
}; };
// Holds layer-independent weight metadata and pointers plus per-layer // Holds layer-independent weight metadata and pointers plus per-layer
@ -502,13 +590,13 @@ struct ModelWeightsPtrs {
// For reshaping file tensors to the shape expected by the code. This would // For reshaping file tensors to the shape expected by the code. This would
// ideally already happen in the importer. Must be called after reading and // ideally already happen in the importer. Must be called after reading and
// updating the attention weights. // updating the attention weights.
void Reshape(hwy::ThreadPool& pool) { void Fixup(MatOwners& mat_owners, hwy::ThreadPool& pool) {
pool.Run(0, c_layers.size(), [this](uint64_t layer, size_t /*thread*/) { pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Reshape(); GetLayer(layer)->Fixup(mat_owners);
}); });
pool.Run(0, vit_layers.size(), [this](uint64_t layer, size_t /*thread*/) { pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
VitLayer(layer)->Reshape(); VitLayer(layer)->Fixup(mat_owners);
}); });
} }
}; // `WeightsPtrs` }; // `WeightsPtrs`
@ -521,10 +609,9 @@ class WeightsOwner {
// `weight_type` is obtained from `ModelConfig` in `ModelStore`. // `weight_type` is obtained from `ModelConfig` in `ModelStore`.
WeightsOwner(Type weight_type) : weight_type_(weight_type) {} WeightsOwner(Type weight_type) : weight_type_(weight_type) {}
// Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`, // Reads tensor data from `BlobStore` or aborts on error.
// allocates memory and reshapes. Aborts on error. void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
void ReadOrAllocate(const ModelStore& model, BlobReader& reader, hwy::ThreadPool& pool);
hwy::ThreadPool& pool);
// Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically // Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
// calls `ForEachTensor`. // calls `ForEachTensor`.
@ -562,7 +649,7 @@ class WeightsOwner {
// Usually taken care of by `ReadOrAllocate`, but must also be called by // Usually taken care of by `ReadOrAllocate`, but must also be called by
// `optimize_test, which updates the attention weights from which this copies. // `optimize_test, which updates the attention weights from which this copies.
void Reshape(hwy::ThreadPool& pool); void Fixup(hwy::ThreadPool& pool);
private: private:
Type weight_type_; Type weight_type_;

View File

@ -714,13 +714,6 @@ struct ConstMat {
size_t Rows() const { return extents.rows; } size_t Rows() const { return extents.rows; }
size_t Cols() const { return extents.cols; } size_t Cols() const { return extents.cols; }
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a
// subrange of the original rows starting at row 0.
void ShrinkRows(size_t rows) {
HWY_ASSERT(rows <= extents.rows);
extents.rows = rows;
}
const T* HWY_RESTRICT ptr; const T* HWY_RESTRICT ptr;
Extents2D extents; Extents2D extents;
size_t stride; size_t stride;

View File

@ -57,12 +57,13 @@ HWY_INLINE float Dot(const ConstMat<WT>& w, size_t w_ofs, const VT* vec_aligned,
const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride()); const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride());
return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num); return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num);
} }
// For callers that pass `MatPtrT`. // For callers that pass `MatPtrT`, which is not necessarily packed - callers
// should use Stride() to compute `w_ofs`.
template <typename WT, typename VT> template <typename WT, typename VT>
HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned, HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned,
size_t num) { size_t num) {
const hn::ScalableTag<VT> d; const hn::ScalableTag<VT> d;
return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num); return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num);
} }
// ArrayT is either MatPtrT or ConstMat. // ArrayT is either MatPtrT or ConstMat.

View File

@ -257,6 +257,10 @@ class MatPtrT : public MatPtr {
const MatT* Row(size_t row) const { return this->RowT<MatT>(row); } const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
MatT* Row(size_t row) { return this->RowT<MatT>(row); } MatT* Row(size_t row) { return this->RowT<MatT>(row); }
PackedSpan<const MatT> PaddedSpan() const {
return MakeConstSpan(Row(0), Rows() * Stride());
}
// For `compress-inl.h` functions, which assume contiguous streams and thus // For `compress-inl.h` functions, which assume contiguous streams and thus
// require packed layout. // require packed layout.
PackedSpan<const MatT> Span() const { PackedSpan<const MatT> Span() const {