diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
index 811aedf..fdf73ec 100644
--- a/backprop/optimize_test.cc
+++ b/backprop/optimize_test.cc
@@ -100,7 +100,7 @@ TEST(OptimizeTest, GradientDescent) {
   };
   gemma.MutableWeights().RandInit(1.0f, gen);
-  gemma.MutableWeights().Reshape(pool);
+  gemma.MutableWeights().Fixup(pool);
   printf("Initial weights:\n");
   gemma.MutableWeights().LogWeightStatsF32();
@@ -129,7 +129,7 @@ TEST(OptimizeTest, GradientDescent) {
     CrossEntropyLossBackwardPass(prompt, *gemma.Weights().GetF32(), forward,
                                  *grad.GetF32(), backward, inv_timescale, pool);
-    gemma.MutableWeights().Reshape(pool);
+    gemma.MutableWeights().Fixup(pool);
     num_ok += verify(prompt) ? 1 : 0;
   }
   total_loss /= kBatchSize;
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index f658099..a861155 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -246,10 +246,6 @@ class GemmaAttention {
     const size_t heads = layer_config_.heads;
     const size_t kv_heads = layer_config_.kv_heads;
-    using WeightT = typename decltype(layer_weights_.qkv_einsum_w)::T;
-    ConstMat<WeightT> w_q1(layer_weights_.qkv_einsum_w.HasPtr()
-                               ? layer_weights_.qkv_einsum_w
-                               : layer_weights_.qkv_einsum_w1);
     // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
     // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
     // We must shrink to the actual size because MatMul verifies
@@ -257,23 +253,16 @@ class GemmaAttention {
     // rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
    // computed in the second MatMul.
     const size_t w1_rows = heads * layer_config_.QStride();
-    w_q1.ShrinkRows(w1_rows);
-    MatMul(activations_.pre_att_rms_out, w_q1,
+    HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows);
+    MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
            /*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q));
     if (is_mha_) {
       // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
     } else {
-      decltype(w_q1) w_q2(layer_weights_.qkv_einsum_w.HasPtr()
-                              ? layer_weights_.qkv_einsum_w
-                              : layer_weights_.qkv_einsum_w2);
-      if (layer_weights_.qkv_einsum_w.HasPtr()) {
-        // Skip first half of the matrix.
-        w_q2.ofs = w_q2.Row(w1_rows);
-      }
       // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
       const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
-      w_q2.ShrinkRows(w_rows_kv_cols);
+      HWY_DASSERT(layer_weights_.qkv_einsum_w2.Rows() == w_rows_kv_cols);

       // Single query and no wraparound means we can use a matmul and write
       // directly into the KV cache with a stride of cache_pos_size_.
@@ -284,7 +273,7 @@ class GemmaAttention {
         float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
         RowPtrF kv_rows(kv, w_rows_kv_cols);
         kv_rows.SetStride(cache_pos_size_);
-        MatMul(activations_.pre_att_rms_out, w_q2,
+        MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
               /*add=*/nullptr, *activations_.env, kv_rows);
       } else {
         // Proceed row by row because there will be wraparound.
@@ -299,7 +288,8 @@ class GemmaAttention {
           const size_t kv_offset = cache_pos * cache_pos_size_ +
                                    layer_ * cache_layer_size_;
           float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-          MatVec(w_q2, w_q2.ofs, w_rows_kv_cols, model_dim, x, kv, pool_);
+          MatVec(layer_weights_.qkv_einsum_w2, 0, w_rows_kv_cols, model_dim, x,
+                 kv, pool_);
         }
       }
     }  // !is_mha_
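The single-query branch above relies on MatMul accepting an output view whose row stride (cache_pos_size_) exceeds its logical width (w_rows_kv_cols), so each token's K/V projections land directly in the right KV-cache slot. The following standalone sketch illustrates that strided-output idea with a naive dot-product loop and hypothetical names; it is not the real MatMul/RowPtrF API.

#include <cstddef>
#include <vector>

// C[r] = A[r] * B^T, where output row r is written at C + r * out_stride.
// out_stride may be much larger than cols_c (e.g. the per-position size of a
// KV cache), so rows land directly in their cache slots without a second copy.
void MatMulStridedOutput(const float* A, const float* B, float* C,
                         size_t rows_a, size_t cols_a, size_t cols_c,
                         size_t out_stride) {
  for (size_t r = 0; r < rows_a; ++r) {
    float* out_row = C + r * out_stride;
    for (size_t c = 0; c < cols_c; ++c) {
      float dot = 0.0f;
      for (size_t k = 0; k < cols_a; ++k) {
        dot += A[r * cols_a + k] * B[c * cols_a + k];
      }
      out_row[c] = dot;
    }
  }
}

int main() {
  const size_t rows_a = 2, cols_a = 4, cols_c = 3, out_stride = 16;
  std::vector<float> A(rows_a * cols_a, 1.0f), B(cols_c * cols_a, 1.0f);
  std::vector<float> cache(rows_a * out_stride, 0.0f);
  MatMulStridedOutput(A.data(), B.data(), cache.data(), rows_a, cols_a, cols_c,
                      out_stride);
  return cache[out_stride] == 4.0f ? 0 : 1;  // Row 1 starts at out_stride.
}
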
@@ -825,32 +815,11 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
   const float* output_bias = add_bias ?
       layer_weights->ffw_output_biases.PackedScale1() : nullptr;

-  // Define slightly more readable names for the weights and activations.
-  auto hidden_activations = RowPtrFromMat(activations.C1);
-  auto multiplier = RowPtrFromMat(activations.C2);
-  auto ffw_out = RowPtrFromMat(activations.ffw_out);
-
-  using WeightT = typename decltype(layer_weights->gating_einsum_w)::T;
-
-  // gating_einsum_w holds two half-matrices. We plan to change the importer to
-  // avoid this confusion by splitting into gating_einsum_w1 and
-  // gating_einsum_w2. TODO: move into Reshape().
-  const bool split = layer_weights->gating_einsum_w.HasPtr();
-  ConstMat<WeightT> w1(split ? layer_weights->gating_einsum_w
-                             : layer_weights->gating_einsum_w1);
-  ConstMat<WeightT> w2(split ? layer_weights->gating_einsum_w
-                             : layer_weights->gating_einsum_w2);
-  if (split) {
-    w2.ofs = w2.Row(ffh_hidden_dim);
-    // Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
-    w1.ShrinkRows(ffh_hidden_dim);
-    w2.ShrinkRows(ffh_hidden_dim);
-  }
-  // Compute the hidden layer activations.
-  MatMul(activations.pre_ffw_rms_out, w1, bias1, *activations.env,
-         hidden_activations);
-  MatMul(activations.pre_ffw_rms_out, w2, bias2, *activations.env, multiplier);
+  MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, bias1,
+         *activations.env, RowPtrFromMat(activations.C1));
+  MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, bias2,
+         *activations.env, RowPtrFromMat(activations.C2));

   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
   ActivationBatched(layer_weights->layer_config.activation, activations.C1,
@@ -858,7 +827,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,

   // Hidden layer -> output layer.
   MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
-         ffw_out);
+         RowPtrFromMat(activations.ffw_out));
 }

 // Same as FFWNoVit, but with different layer_weights members and no second
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 5222c00..ebd37d3 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -57,7 +57,7 @@ Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
       model_(*reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config().weight),
       chat_template_(model_.Tokenizer(), model_.Config().model) {
-  weights_.ReadOrAllocate(model_, *reader_, env_.ctx.pools.Pool());
+  weights_.ReadFromBlobs(model_, *reader_, env_.ctx.pools.Pool());
   reader_.reset();
 }

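The weights changes that follow replace the ShrinkRows/ofs view juggling removed above with a one-time split in Fixup(): gating_einsum_w1/w2 and qkv_einsum_w1/w2 become views that alias the stacked tensor's storage, differing only in their base row while sharing its stride. A standalone sketch of that aliasing trick, using a simplified view struct rather than the real MatPtrT:

#include <cstddef>
#include <vector>

// Simplified stand-in for a row-major matrix view (not the real MatPtrT).
struct View {
  float* ptr;
  size_t rows, cols, stride;  // stride >= cols when rows are padded.
  float* Row(size_t r) const { return ptr + r * stride; }
};

// Splits a stacked [2 * half_rows, cols] matrix into two half views that
// alias the parent's storage; no data is copied, only pointers are set.
void Split(const View& stacked, size_t half_rows, View& w1, View& w2) {
  w1 = View{stacked.Row(0), half_rows, stacked.cols, stacked.stride};
  w2 = View{stacked.Row(half_rows), half_rows, stacked.cols, stacked.stride};
}

int main() {
  const size_t half_rows = 4, cols = 8;
  std::vector<float> storage(2 * half_rows * cols, 1.0f);
  View stacked{storage.data(), 2 * half_rows, cols, cols};
  View w1{}, w2{};
  Split(stacked, half_rows, w1, w2);
  // w2 starts exactly half_rows rows below w1, as in SplitW1()/SplitAttW1().
  return w2.Row(0) == stacked.Row(half_rows) ? 0 : 1;
}
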
diff --git a/gemma/weights.cc b/gemma/weights.cc
index b766ad8..0ef86c7 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -37,13 +37,15 @@
 #include "hwy/profiler.h"
 #include "hwy/stats.h"

-// TODO: move into foreach_target; this is only used for NUQ Reshape.
+// TODO: move into foreach_target; this is only used for NUQ Fixup.
 #include "compression/compress-inl.h"

 namespace gcpp {

-template <>
-void LayerWeightsPtrs<NuqStream>::Reshape() {
+static void InitAttWeightsNUQ(const LayerConfig& layer_config,
+                              MatPtrT<NuqStream>& attn_vec_einsum_w,
+                              MatPtrT<NuqStream>& att_weights,
+                              MatOwners& mat_owners) {
   if (!attn_vec_einsum_w.HasPtr()) return;

   HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
@@ -83,6 +85,16 @@ void LayerWeightsPtrs<NuqStream>::Reshape() {
   att_weights.SetScale(attn_vec_einsum_w.Scale());
 }

+static void SplitW1NUQ(const LayerConfig& layer_config) {
+  // TODO(janwas): implement.
+}
+
+template <>
+void LayerWeightsPtrs<NuqStream>::Fixup(MatOwners& mat_owners) {
+  InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners);
+  SplitW1NUQ(layer_config);
+}
+
 // Aborts on error.
 static void MapOrRead(const std::vector& mats, BlobReader& reader,
                       const std::vector& ranges,
@@ -134,8 +146,8 @@ static void MapOrRead(const std::vector& mats, BlobReader& reader,
   reader.ReadAll(pool);
 }

-void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
-                                  hwy::ThreadPool& pool) {
+void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
+                                 hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
   std::vector mats;
   std::vector ranges;
@@ -148,10 +160,6 @@ void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
   // Enumerate all weights (negligible cost).
   CallT([&](const auto& weights) {
     weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
-      if (t.flags & TensorArgs::kOnlyAllocate) {
-        mat_owners_.AllocateFor(t.mat, padding);
-        return;
-      }
       size_t key_idx;
       if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
         mats.push_back(&t.mat);
@@ -165,7 +173,7 @@ void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,

   MapOrRead(mats, reader, ranges, mat_owners_, padding, pool);

-  Reshape(pool);
+  Fixup(pool);
 }

 // Allocates `*_weights_`, but not yet the tensors inside. This is split out
@@ -218,6 +226,7 @@ void WeightsOwner::LogWeightStatsF32() {
   HWY_ASSERT(weight_type_ == Type::kF32);  // Only for float weights.
   float_weights_->ForEachTensor(
       nullptr, nullptr, [&total_weights](const TensorArgs& t) {
+        if (!t.mat.HasPtr()) return;
         if (t.mat.Scale() != 1.0f) {
          printf("[scale=%f] ", t.mat.Scale());
         }
@@ -238,9 +247,9 @@ void WeightsOwner::LogWeightStatsF32() {
   printf("%-20s %12zu\n", "Total", total_weights);
 }

-void WeightsOwner::Reshape(hwy::ThreadPool& pool) {
-  PROFILER_ZONE("Startup.Reshape");
-  CallT([&pool](const auto& weights) { weights->Reshape(pool); });
+void WeightsOwner::Fixup(hwy::ThreadPool& pool) {
+  PROFILER_ZONE("Startup.Fixup");
+  CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); });
 }

 std::vector WeightsOwner::AddTensorDataToWriter(
@@ -248,7 +257,6 @@ std::vector WeightsOwner::AddTensorDataToWriter(
   std::vector serialized_mat_ptrs;
   CallT([&](const auto& weights) {
     weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
-      if (t.flags & TensorArgs::kOnlyAllocate) return;
       if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
       HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
       writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes());
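InitAttWeightsNUQ above, and InitAttWeights in weights.h below, reshape attn_vec_einsum_w (heads * model_dim rows of qkv_dim) into att_weights (model_dim rows of heads * qkv_dim); the copy loop itself falls mostly outside the hunks shown. The following standalone sketch spells out the index mapping implied by the assertions and loop bounds, using plain floats, whereas the real code copies ElementBytes()-sized rows of whatever weight type is loaded:

#include <cstddef>

// att_weights[m][h * qkv_dim + k] = attn_vec_einsum_w[h * model_dim + m][k].
// Plain-float sketch of the reshape; the real code works on opaque row bytes.
void ReshapeAttWeights(const float* attn_vec_einsum_w, float* att_weights,
                       size_t heads, size_t model_dim, size_t qkv_dim) {
  for (size_t m = 0; m < model_dim; ++m) {
    float* out_row = att_weights + m * (heads * qkv_dim);
    for (size_t h = 0; h < heads; ++h) {
      const float* in_row = attn_vec_einsum_w + (h * model_dim + m) * qkv_dim;
      for (size_t k = 0; k < qkv_dim; ++k) {
        out_row[h * qkv_dim + k] = in_row[k];
      }
    }
  }
}
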
diff --git a/gemma/weights.h b/gemma/weights.h
index 5a5de9d..e0cbf0b 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -21,6 +21,7 @@
 #include 
 #include 
+#include <mutex>  // NOLINT
 #include 
 #include 
 #include 
@@ -43,10 +44,10 @@ struct TensorArgs {
   // `flags` is a combination of zero or more `Flags`.
   TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
              int flags)
-      : mat(mat), other_mat1(other_mat1), other_mat2(other_mat2), flags(flags) {
-    // Does not make sense to combine both flags.
-    HWY_ASSERT(flags != (kMaybeRead | kOnlyAllocate));
-  }
+      : mat(mat),
+        other_mat1(other_mat1),
+        other_mat2(other_mat2),
+        flags(flags) {}

   MatPtr& mat;
   const MatPtr* other_mat1;  // either/both can be nullptr.
@@ -59,8 +60,6 @@ struct TensorArgs {
     // Not an error if the tensor is not present in the file. For example,
     // the _w1/_w2 tensors are not always present.
     kMaybeRead = 1,
-    // Do not attempt to read, just allocate the tensor. Used for `Reshape`.
-    kOnlyAllocate = 2,
   };
   const int flags;
 };
@@ -87,7 +86,6 @@ struct LayerWeightsPtrs {
   LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
                    const TensorInfoRegistry& tensors)
       : suffix_(LayerSuffix(layer_idx)),
-        attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
         qkv_einsum_w(Concat("qkv_ein", suffix_), tensors),
         qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors),
         qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors),
@@ -127,9 +125,13 @@ struct LayerWeightsPtrs {
         post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors),
         ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors),
         ffw_output_biases(Concat("ffw_out_b", suffix_), tensors),
+
+        attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
         att_weights(Concat("att_w", suffix_), tensors),
+
         key_norm_scale(Concat("key_norm", suffix_), tensors),
         query_norm_scale(Concat("query_norm", suffix_), tensors),
+
         layer_config(config) {}

   ~LayerWeightsPtrs() = default;
@@ -144,10 +146,8 @@ struct LayerWeightsPtrs {
       hwy::If(), double,
              hwy::If(), float, BF16>>>;

-  MatPtrT attn_vec_einsum_w;
-  // qkv_einsum_w holds 2 different matrices, which may be separated out.
-  // On reading, which is used depends on what is in the file.
-  // At inference, the one with a non-null ptr is used.
+  // Files either have qkv_einsum_w with 2 stacked matrices or separate
+  // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
   MatPtrT qkv_einsum_w;
   MatPtrT qkv_einsum_w1;
   MatPtrT qkv_einsum_w2;
@@ -185,9 +185,8 @@ struct LayerWeightsPtrs {
     MatPtrT layer_norm_1_scale;
   } vit;

-  // gating_einsum_w holds 2 different matrices, which may be separated out.
-  // On reading, which is used depends on what is in the file.
-  // At inference, the one with a non-null ptr is used.
+  // Files either have gating_einsum_w with 2 stacked matrices or separate
+  // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
   MatPtrT gating_einsum_w;
   MatPtrT gating_einsum_w1;
   MatPtrT gating_einsum_w2;
@@ -201,7 +200,8 @@ struct LayerWeightsPtrs {
   MatPtrT ffw_gating_biases;
   MatPtrT ffw_output_biases;

-  MatPtrT att_weights;  // For Reshape(); kOnlyAllocate.
+  MatPtrT attn_vec_einsum_w;  // Use att_weights instead of this.
+  MatPtrT att_weights;        // Use this instead of attn_vec_einsum_w.

   MatPtrT key_norm_scale;
   MatPtrT query_norm_scale;
@@ -234,10 +234,10 @@ struct LayerWeightsPtrs {
       return;
     }
     if (layer_config.type == LayerAttentionType::kGemma) {
-      // Not read, will be filled by Reshape() from `attn_vec_einsum_w`.
-      func(TENSOR_ARGS(att_weights, kOnlyAllocate));
-      func(TENSOR_ARGS(attn_vec_einsum_w, kMustRead));
-      func(TENSOR_ARGS(qkv_einsum_w, kMustRead));
+      // Either read from file, or allocated during Fixup().
+      func(TENSOR_ARGS(att_weights, kMaybeRead));
+      func(TENSOR_ARGS(attn_vec_einsum_w, kMaybeRead));
+      func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
       func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
       func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
     } else {
@@ -254,7 +254,7 @@ struct LayerWeightsPtrs {
       func(TENSOR_ARGS(griffin.a, kMustRead));
     }
     {
-      func(TENSOR_ARGS(gating_einsum_w, kMustRead));
+      func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
       func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead));
       func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead));
       func(TENSOR_ARGS(linear_w, kMustRead));
@@ -306,14 +306,25 @@ struct LayerWeightsPtrs {
     });
   }

-  // Initializes att_weights from `attn_vec_einsum_w`, hence this must be called
-  // after reading weights via `ForEachTensor`.
-  // TODO: update compression/convert_weights to bake this in.
-  void Reshape() {
-    // We only have/allocate this tensor for Gemma layers.
-    HWY_ASSERT(att_weights.HasPtr() ==
-               (layer_config.type == LayerAttentionType::kGemma));
-    if (!att_weights.HasPtr()) return;
+  // Must be called after reading weights via `ForEachTensor`.
+  // TODO: exporters should bake this into the weights already.
+  // WARNING: called from multiple threads; `mat_owners` requires a lock.
+  void Fixup(MatOwners& mat_owners) {
+    InitAttWeights(mat_owners);
+    SplitW1();
+    SplitAttW1();
+  }
+
+ private:
+  // Copies att_weights from `attn_vec_einsum_w`.
+  void InitAttWeights(MatOwners& mat_owners) {
+    // We only use this tensor for Gemma layers.
+    if (layer_config.type != LayerAttentionType::kGemma) return;
+
+    // Files must have one or the other, and backprop/ allocates both.
+    HWY_ASSERT(attn_vec_einsum_w.HasPtr() || att_weights.HasPtr());
+    // Done if we already read the transposed tensor.
+    if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;

     // NUQ is handled by a specialization in weights.cc.
     HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);
@@ -326,9 +337,15 @@ struct LayerWeightsPtrs {
     HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
     HWY_ASSERT(att_weights.Rows() == model_dim);
     HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
-    HWY_ASSERT(attn_vec_einsum_w.HasPtr());
     HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
     HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);
+
+    {
+      static std::mutex m;
+      std::lock_guard lock(m);
+      mat_owners.AllocateFor(att_weights, MatPadding::kOdd);
+    }
+
     const size_t T_bytes = att_weights.ElementBytes();
     for (size_t m = 0; m < model_dim; ++m) {
       uint8_t* HWY_RESTRICT out_row =
@@ -340,6 +357,77 @@ struct LayerWeightsPtrs {
     }
     att_weights.SetScale(attn_vec_einsum_w.Scale());
   }
+
+  // For FFN. Fast, only updates pointers.
+  void SplitW1() {
+    // We only use this tensor for Gemma layers.
+    if (layer_config.type != LayerAttentionType::kGemma) return;
+
+    // Files have both or neither of w1 and w2, and backprop/ allocates both.
+    HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
+    // w is mutually exclusive with w1 and w2 in the file, but backprop/
+    // allocates both, so we can only rule out both being null.
+    HWY_ASSERT(gating_einsum_w.HasPtr() || gating_einsum_w1.HasPtr());
+    // Done if we already read split tensors. Note that they are not
+    // necessarily the same type.
+    if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
+
+    const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
+    HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
+    HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
+    HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
+    // Cols are the model_dim but we don't have ModelConfig here.
+    HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols());
+    HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols());
+
+    const size_t stride = gating_einsum_w.Stride();
+    gating_einsum_w1.SetPtr(gating_einsum_w.Row(0), stride);
+    gating_einsum_w2.SetPtr(gating_einsum_w.Row(ff_hidden_dim), stride);
+    gating_einsum_w1.SetType(gating_einsum_w.GetType());
+    gating_einsum_w2.SetType(gating_einsum_w.GetType());
+    gating_einsum_w1.SetScale(gating_einsum_w.Scale());
+    gating_einsum_w2.SetScale(gating_einsum_w.Scale());
+    // Do not invalidate gating_einsum_w: backprop/ calls this repeatedly.
+  }
+
+  // For attention, which might not have a w2. Fast, only updates pointers.
+  void SplitAttW1() {
+    // We only use this tensor for Gemma layers.
+    if (layer_config.type != LayerAttentionType::kGemma) return;
+
+    // w is mutually exclusive with w1 in the file, but backprop/ allocates
+    // both, so we can only rule out both being null.
+    HWY_ASSERT(qkv_einsum_w.HasPtr() || qkv_einsum_w1.HasPtr());
+    // Done if we already read split tensors. Note that w2 does not exist for
+    // MHA, and otherwise might not be the same type.
+    if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
+
+    const size_t w1_rows = layer_config.heads * layer_config.QStride();
+
+    if (layer_config.IsMHA()) {  // MHA only requires w1.
+      qkv_einsum_w1 = qkv_einsum_w;
+      HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
+      return;
+    }
+
+    const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
+
+    HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
+    HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
+    HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
+    // Cols are the model_dim but we don't have ModelConfig here.
+    HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols());
+    HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols());
+
+    const size_t stride = qkv_einsum_w.Stride();
+    qkv_einsum_w1.SetPtr(qkv_einsum_w.Row(0), stride);
+    qkv_einsum_w2.SetPtr(qkv_einsum_w.Row(w1_rows), stride);
+    qkv_einsum_w1.SetType(qkv_einsum_w.GetType());
+    qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
+    qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
+    qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
+    // Do not invalidate qkv_einsum_w: backprop/ calls this repeatedly.
+  }
 };

 // Holds layer-independent weight metadata and pointers plus per-layer
@@ -502,13 +590,13 @@ struct ModelWeightsPtrs {
   // For reshaping file tensors to the shape expected by the code. This would
   // ideally already happen in the importer. Must be called after reading and
   // updating the attention weights.
-  void Reshape(hwy::ThreadPool& pool) {
-    pool.Run(0, c_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
-      GetLayer(layer)->Reshape();
+  void Fixup(MatOwners& mat_owners, hwy::ThreadPool& pool) {
+    pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
+      GetLayer(layer)->Fixup(mat_owners);
     });
-    pool.Run(0, vit_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
-      VitLayer(layer)->Reshape();
+    pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
+      VitLayer(layer)->Fixup(mat_owners);
     });
   }
 };  // `WeightsPtrs`
@@ -521,10 +609,9 @@ class WeightsOwner {
   // `weight_type` is obtained from `ModelConfig` in `ModelStore`.
   WeightsOwner(Type weight_type) : weight_type_(weight_type) {}

-  // Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`,
-  // allocates memory and reshapes. Aborts on error.
-  void ReadOrAllocate(const ModelStore& model, BlobReader& reader,
-                      hwy::ThreadPool& pool);
+  // Reads tensor data from `BlobStore` or aborts on error.
+  void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
+                     hwy::ThreadPool& pool);

   // Calls `func(std::unique_ptr>&, args)`. `func` typically
   // calls `ForEachTensor`.
@@ -562,7 +649,7 @@ class WeightsOwner {

   // Usually taken care of by `ReadOrAllocate`, but must also be called by
   // `optimize_test, which updates the attention weights from which this copies.
-  void Reshape(hwy::ThreadPool& pool);
+  void Fixup(hwy::ThreadPool& pool);

  private:
  Type weight_type_;
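For a concrete feel of the row counts SplitAttW1() asserts, here is the arithmetic with hypothetical head counts that are not taken from any particular Gemma config:

#include <cstddef>

// Hypothetical GQA layer: 8 query heads, 1 KV head, 256-dim heads.
constexpr size_t kHeads = 8, kKVHeads = 1, kQKVDim = 256;

// Non-MHA: QStride() == qkv_dim, so w1 covers only the Q projections and
// w2 covers the kv_heads interleaved (k, v) pairs.
constexpr size_t kW1Rows = kHeads * kQKVDim;        // 2048
constexpr size_t kW2Rows = kKVHeads * 2 * kQKVDim;  // 512
static_assert(kW1Rows + kW2Rows == (kHeads + 2 * kKVHeads) * kQKVDim,
              "w1 + w2 add up to the stacked qkv_einsum_w rows");

// MHA (kv_heads == heads): QStride() == 3 * qkv_dim, so the single w1 view
// already spans Q, K and V and SplitAttW1() skips the w2 split.
constexpr size_t kMhaW1Rows = kHeads * 3 * kQKVDim;
static_assert(kMhaW1Rows == (kHeads + 2 * kHeads) * kQKVDim,
              "MHA: one view covers the whole stacked tensor");

int main() { return 0; }
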
diff --git a/ops/matmul.h b/ops/matmul.h
index 773322f..e00a6ec 100644
--- a/ops/matmul.h
+++ b/ops/matmul.h
@@ -714,13 +714,6 @@ struct ConstMat {
   size_t Rows() const { return extents.rows; }
   size_t Cols() const { return extents.cols; }

-  // Shrinks the row-extent of this matrix view, i.e. reduces the view to a
-  // subrange of the original rows starting at row 0.
-  void ShrinkRows(size_t rows) {
-    HWY_ASSERT(rows <= extents.rows);
-    extents.rows = rows;
-  }
-
   const T* HWY_RESTRICT ptr;
   Extents2D extents;
   size_t stride;
diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h
index d838a18..83e70c8 100644
--- a/ops/matvec-inl.h
+++ b/ops/matvec-inl.h
@@ -57,12 +57,13 @@ HWY_INLINE float Dot(const ConstMat& w, size_t w_ofs,
                      const VT* vec_aligned, size_t num) {
   const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride());
   return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num);
 }

-// For callers that pass `MatPtrT`.
+// For callers that pass `MatPtrT`, which is not necessarily packed - callers
+// should use Stride() to compute `w_ofs`.
 template 
 HWY_INLINE float Dot(const MatPtrT& w, size_t w_ofs,
                      const VT* vec_aligned, size_t num) {
   const hn::ScalableTag d;
-  return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
+  return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num);
 }

 // ArrayT is either MatPtrT or ConstMat.
diff --git a/util/mat.h b/util/mat.h
index d039bdf..ce134d1 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -257,6 +257,10 @@ class MatPtrT : public MatPtr {
   const MatT* Row(size_t row) const { return this->RowT(row); }
   MatT* Row(size_t row) { return this->RowT(row); }

+  PackedSpan<const MatT> PaddedSpan() const {
+    return MakeConstSpan(Row(0), Rows() * Stride());
+  }
+
   // For `compress-inl.h` functions, which assume contiguous streams and thus
   // require packed layout.
   PackedSpan<const MatT> Span() const {
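The new PaddedSpan() exists because MatPtrT rows may be padded: row r starts at r * Stride() with Stride() >= Cols(), so a span indexed by stride-based offsets must cover Rows() * Stride() elements rather than the packed Rows() * Cols() that Span() assumes. A small standalone illustration of the difference, using simplified types rather than the real PackedSpan:

#include <cassert>
#include <cstddef>
#include <vector>

// Row-padded matrix: rows of `cols` valid elements spaced `stride` apart.
struct Padded {
  std::vector<float> data;
  size_t rows, cols, stride;
  Padded(size_t r, size_t c, size_t s) : data(r * s), rows(r), cols(c), stride(s) {}
  // A packed span would hold rows * cols elements; offsets computed with
  // Stride() (as MatVec callers do) require the padded size instead.
  size_t PaddedSpanSize() const { return rows * stride; }
  size_t RowOffset(size_t r) const { return r * stride; }
};

int main() {
  Padded m(/*rows=*/3, /*cols=*/5, /*stride=*/8);
  // The last valid element of the last row lies beyond rows * cols, but
  // within rows * stride, so only the padded span covers it.
  assert(m.RowOffset(2) + m.cols <= m.PaddedSpanSize());
  assert(m.RowOffset(2) + m.cols > m.rows * m.cols);
  return 0;
}
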