Split W1/W2 as a load-time preprocess.

Remove kOnlyAllocate - no longer used. Rename ReadOrAllocate -> ReadFromBlobs.
Rename Reshape -> Fixup to reflect the new scope.
Remove no longer used ShrinkRows.

This simplifies gemma-inl and is a prerequisite for removing ConstMat
(whose .ofs was previously used for merged tensors)

PiperOrigin-RevId: 758214083
This commit is contained in:
Jan Wassenberg 2025-05-13 07:39:16 -07:00 committed by Copybara-Service
parent 2038dfd9cc
commit 8a312e9b89
8 changed files with 169 additions and 107 deletions

View File

@ -100,7 +100,7 @@ TEST(OptimizeTest, GradientDescent) {
};
gemma.MutableWeights().RandInit(1.0f, gen);
gemma.MutableWeights().Reshape(pool);
gemma.MutableWeights().Fixup(pool);
printf("Initial weights:\n");
gemma.MutableWeights().LogWeightStatsF32();
@ -129,7 +129,7 @@ TEST(OptimizeTest, GradientDescent) {
CrossEntropyLossBackwardPass(prompt, *gemma.Weights().GetF32(), forward,
*grad.GetF32(), backward, inv_timescale,
pool);
gemma.MutableWeights().Reshape(pool);
gemma.MutableWeights().Fixup(pool);
num_ok += verify(prompt) ? 1 : 0;
}
total_loss /= kBatchSize;

View File

@ -246,10 +246,6 @@ class GemmaAttention {
const size_t heads = layer_config_.heads;
const size_t kv_heads = layer_config_.kv_heads;
using WeightT = typename decltype(layer_weights_.qkv_einsum_w)::T;
ConstMat<WeightT> w_q1(layer_weights_.qkv_einsum_w.HasPtr()
? layer_weights_.qkv_einsum_w
: layer_weights_.qkv_einsum_w1);
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
// We must shrink to the actual size because MatMul verifies
@ -257,23 +253,16 @@ class GemmaAttention {
// rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
// computed in the second MatMul.
const size_t w1_rows = heads * layer_config_.QStride();
w_q1.ShrinkRows(w1_rows);
MatMul(activations_.pre_att_rms_out, w_q1,
HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows);
MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
/*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q));
if (is_mha_) {
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
} else {
decltype(w_q1) w_q2(layer_weights_.qkv_einsum_w.HasPtr()
? layer_weights_.qkv_einsum_w
: layer_weights_.qkv_einsum_w2);
if (layer_weights_.qkv_einsum_w.HasPtr()) {
// Skip first half of the matrix.
w_q2.ofs = w_q2.Row(w1_rows);
}
// KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
w_q2.ShrinkRows(w_rows_kv_cols);
HWY_DASSERT(layer_weights_.qkv_einsum_w2.Rows() == w_rows_kv_cols);
// Single query and no wraparound means we can use a matmul and write
// directly into the KV cache with a stride of cache_pos_size_.
@ -284,7 +273,7 @@ class GemmaAttention {
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
RowPtrF kv_rows(kv, w_rows_kv_cols);
kv_rows.SetStride(cache_pos_size_);
MatMul(activations_.pre_att_rms_out, w_q2,
MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
/*add=*/nullptr, *activations_.env, kv_rows);
} else {
// Proceed row by row because there will be wraparound.
@ -299,7 +288,8 @@ class GemmaAttention {
const size_t kv_offset =
cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
MatVec(w_q2, w_q2.ofs, w_rows_kv_cols, model_dim, x, kv, pool_);
MatVec(layer_weights_.qkv_einsum_w2, 0, w_rows_kv_cols, model_dim, x,
kv, pool_);
}
}
} // !is_mha_
@ -825,32 +815,11 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
const float* output_bias =
add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
// Define slightly more readable names for the weights and activations.
auto hidden_activations = RowPtrFromMat(activations.C1);
auto multiplier = RowPtrFromMat(activations.C2);
auto ffw_out = RowPtrFromMat(activations.ffw_out);
using WeightT = typename decltype(layer_weights->gating_einsum_w)::T;
// gating_einsum_w holds two half-matrices. We plan to change the importer to
// avoid this confusion by splitting into gating_einsum_w1 and
// gating_einsum_w2. TODO: move into Reshape().
const bool split = layer_weights->gating_einsum_w.HasPtr();
ConstMat<WeightT> w1(split ? layer_weights->gating_einsum_w
: layer_weights->gating_einsum_w1);
ConstMat<WeightT> w2(split ? layer_weights->gating_einsum_w
: layer_weights->gating_einsum_w2);
if (split) {
w2.ofs = w2.Row(ffh_hidden_dim);
// Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
w1.ShrinkRows(ffh_hidden_dim);
w2.ShrinkRows(ffh_hidden_dim);
}
// Compute the hidden layer activations.
MatMul(activations.pre_ffw_rms_out, w1, bias1, *activations.env,
hidden_activations);
MatMul(activations.pre_ffw_rms_out, w2, bias2, *activations.env, multiplier);
MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, bias1,
*activations.env, RowPtrFromMat(activations.C1));
MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, bias2,
*activations.env, RowPtrFromMat(activations.C2));
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_weights->layer_config.activation, activations.C1,
@ -858,7 +827,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
// Hidden layer -> output layer.
MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
ffw_out);
RowPtrFromMat(activations.ffw_out));
}
// Same as FFWNoVit, but with different layer_weights members and no second

View File

@ -57,7 +57,7 @@ Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
model_(*reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config().weight),
chat_template_(model_.Tokenizer(), model_.Config().model) {
weights_.ReadOrAllocate(model_, *reader_, env_.ctx.pools.Pool());
weights_.ReadFromBlobs(model_, *reader_, env_.ctx.pools.Pool());
reader_.reset();
}

View File

@ -37,13 +37,15 @@
#include "hwy/profiler.h"
#include "hwy/stats.h"
// TODO: move into foreach_target; this is only used for NUQ Reshape.
// TODO: move into foreach_target; this is only used for NUQ Fixup.
#include "compression/compress-inl.h"
namespace gcpp {
template <>
void LayerWeightsPtrs<NuqStream>::Reshape() {
static void InitAttWeightsNUQ(const LayerConfig& layer_config,
MatPtrT<NuqStream>& attn_vec_einsum_w,
MatPtrT<NuqStream>& att_weights,
MatOwners& mat_owners) {
if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
@ -83,6 +85,16 @@ void LayerWeightsPtrs<NuqStream>::Reshape() {
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
static void SplitW1NUQ(const LayerConfig& layer_config) {
// TODO(janwas): implement.
}
template <>
void LayerWeightsPtrs<NuqStream>::Fixup(MatOwners& mat_owners) {
InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners);
SplitW1NUQ(layer_config);
}
// Aborts on error.
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
const std::vector<BlobRange>& ranges,
@ -134,8 +146,8 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
reader.ReadAll(pool);
}
void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool) {
void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from.
std::vector<MatPtr*> mats;
std::vector<BlobRange> ranges;
@ -148,10 +160,6 @@ void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
// Enumerate all weights (negligible cost).
CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
if (t.flags & TensorArgs::kOnlyAllocate) {
mat_owners_.AllocateFor(t.mat, padding);
return;
}
size_t key_idx;
if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
mats.push_back(&t.mat);
@ -165,7 +173,7 @@ void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
MapOrRead(mats, reader, ranges, mat_owners_, padding, pool);
Reshape(pool);
Fixup(pool);
}
// Allocates `*_weights_`, but not yet the tensors inside. This is split out
@ -218,6 +226,7 @@ void WeightsOwner::LogWeightStatsF32() {
HWY_ASSERT(weight_type_ == Type::kF32); // Only for float weights.
float_weights_->ForEachTensor(
nullptr, nullptr, [&total_weights](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
if (t.mat.Scale() != 1.0f) {
printf("[scale=%f] ", t.mat.Scale());
}
@ -238,9 +247,9 @@ void WeightsOwner::LogWeightStatsF32() {
printf("%-20s %12zu\n", "Total", total_weights);
}
void WeightsOwner::Reshape(hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.Reshape");
CallT([&pool](const auto& weights) { weights->Reshape(pool); });
void WeightsOwner::Fixup(hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.Fixup");
CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); });
}
std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
@ -248,7 +257,6 @@ std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
std::vector<uint32_t> serialized_mat_ptrs;
CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
if (t.flags & TensorArgs::kOnlyAllocate) return;
if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes());

View File

@ -21,6 +21,7 @@
#include <complex>
#include <memory>
#include <mutex> // NOLINT
#include <random>
#include <string>
#include <vector>
@ -43,10 +44,10 @@ struct TensorArgs {
// `flags` is a combination of zero or more `Flags`.
TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
int flags)
: mat(mat), other_mat1(other_mat1), other_mat2(other_mat2), flags(flags) {
// Does not make sense to combine both flags.
HWY_ASSERT(flags != (kMaybeRead | kOnlyAllocate));
}
: mat(mat),
other_mat1(other_mat1),
other_mat2(other_mat2),
flags(flags) {}
MatPtr& mat;
const MatPtr* other_mat1; // either/both can be nullptr.
@ -59,8 +60,6 @@ struct TensorArgs {
// Not an error if the tensor is not present in the file. For example,
// the _w1/_w2 tensors are not always present.
kMaybeRead = 1,
// Do not attempt to read, just allocate the tensor. Used for `Reshape`.
kOnlyAllocate = 2,
};
const int flags;
};
@ -87,7 +86,6 @@ struct LayerWeightsPtrs {
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
const TensorInfoRegistry& tensors)
: suffix_(LayerSuffix(layer_idx)),
attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
qkv_einsum_w(Concat("qkv_ein", suffix_), tensors),
qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors),
qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors),
@ -127,9 +125,13 @@ struct LayerWeightsPtrs {
post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors),
ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors),
ffw_output_biases(Concat("ffw_out_b", suffix_), tensors),
attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
att_weights(Concat("att_w", suffix_), tensors),
key_norm_scale(Concat("key_norm", suffix_), tensors),
query_norm_scale(Concat("query_norm", suffix_), tensors),
layer_config(config) {}
~LayerWeightsPtrs() = default;
@ -144,10 +146,8 @@ struct LayerWeightsPtrs {
hwy::If<hwy::IsSame<Weight, double>(), double,
hwy::If<IsF32<Weight>(), float, BF16>>>;
MatPtrT<Weight> attn_vec_einsum_w;
// qkv_einsum_w holds 2 different matrices, which may be separated out.
// On reading, which is used depends on what is in the file.
// At inference, the one with a non-null ptr is used.
// Files either have qkv_einsum_w with 2 stacked matrices or separate
// w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
MatPtrT<Weight> qkv_einsum_w;
MatPtrT<Weight> qkv_einsum_w1;
MatPtrT<Weight> qkv_einsum_w2;
@ -185,9 +185,8 @@ struct LayerWeightsPtrs {
MatPtrT<WeightF32OrBF16> layer_norm_1_scale;
} vit;
// gating_einsum_w holds 2 different matrices, which may be separated out.
// On reading, which is used depends on what is in the file.
// At inference, the one with a non-null ptr is used.
// Files either have gating_einsum_w with 2 stacked matrices or separate
// w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
MatPtrT<Weight> gating_einsum_w;
MatPtrT<Weight> gating_einsum_w1;
MatPtrT<Weight> gating_einsum_w2;
@ -201,7 +200,8 @@ struct LayerWeightsPtrs {
MatPtrT<float> ffw_gating_biases;
MatPtrT<float> ffw_output_biases;
MatPtrT<Weight> att_weights; // For Reshape(); kOnlyAllocate.
MatPtrT<Weight> attn_vec_einsum_w; // Use att_weights instead of this.
MatPtrT<Weight> att_weights; // Use this instead of attn_vec_einsum_w.
MatPtrT<WeightF32OrBF16> key_norm_scale;
MatPtrT<WeightF32OrBF16> query_norm_scale;
@ -234,10 +234,10 @@ struct LayerWeightsPtrs {
return;
}
if (layer_config.type == LayerAttentionType::kGemma) {
// Not read, will be filled by Reshape() from `attn_vec_einsum_w`.
func(TENSOR_ARGS(att_weights, kOnlyAllocate));
func(TENSOR_ARGS(attn_vec_einsum_w, kMustRead));
func(TENSOR_ARGS(qkv_einsum_w, kMustRead));
// Either read from file, or allocated during Fixup().
func(TENSOR_ARGS(att_weights, kMaybeRead));
func(TENSOR_ARGS(attn_vec_einsum_w, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
} else {
@ -254,7 +254,7 @@ struct LayerWeightsPtrs {
func(TENSOR_ARGS(griffin.a, kMustRead));
}
{
func(TENSOR_ARGS(gating_einsum_w, kMustRead));
func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead));
func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead));
func(TENSOR_ARGS(linear_w, kMustRead));
@ -306,14 +306,25 @@ struct LayerWeightsPtrs {
});
}
// Initializes att_weights from `attn_vec_einsum_w`, hence this must be called
// after reading weights via `ForEachTensor`.
// TODO: update compression/convert_weights to bake this in.
void Reshape() {
// We only have/allocate this tensor for Gemma layers.
HWY_ASSERT(att_weights.HasPtr() ==
(layer_config.type == LayerAttentionType::kGemma));
if (!att_weights.HasPtr()) return;
// Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
void Fixup(MatOwners& mat_owners) {
InitAttWeights(mat_owners);
SplitW1();
SplitAttW1();
}
private:
// Copies att_weights from `attn_vec_einsum_w`.
void InitAttWeights(MatOwners& mat_owners) {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// Files must have one or the other, and backprop/ allocates both.
HWY_ASSERT(attn_vec_einsum_w.HasPtr() || att_weights.HasPtr());
// Done if we already read the transposed tensor.
if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;
// NUQ is handled by a specialization in weights.cc.
HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);
@ -326,9 +337,15 @@ struct LayerWeightsPtrs {
HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
HWY_ASSERT(att_weights.Rows() == model_dim);
HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
HWY_ASSERT(attn_vec_einsum_w.HasPtr());
HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.AllocateFor(att_weights, MatPadding::kOdd);
}
const size_t T_bytes = att_weights.ElementBytes();
for (size_t m = 0; m < model_dim; ++m) {
uint8_t* HWY_RESTRICT out_row =
@ -340,6 +357,77 @@ struct LayerWeightsPtrs {
}
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
// For FFN. Fast, only updates pointers.
void SplitW1() {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// Files have both or neither of w1 and w2, and backprop/ allocates both.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
// w is mutually exclusive with w1 and w2 in the file, but backprop/
// allocates both, so we can only rule out both being null.
HWY_ASSERT(gating_einsum_w.HasPtr() || gating_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that they are not
// necessarily the same type.
if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
// Cols are the model_dim but we don't have ModelConfig here.
HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols());
HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols());
const size_t stride = gating_einsum_w.Stride();
gating_einsum_w1.SetPtr(gating_einsum_w.Row(0), stride);
gating_einsum_w2.SetPtr(gating_einsum_w.Row(ff_hidden_dim), stride);
gating_einsum_w1.SetType(gating_einsum_w.GetType());
gating_einsum_w2.SetType(gating_einsum_w.GetType());
gating_einsum_w1.SetScale(gating_einsum_w.Scale());
gating_einsum_w2.SetScale(gating_einsum_w.Scale());
// Do not invalidate gating_einsum_w: backprop/ calls this repeatedly.
}
// For attention, which might not have a w2. Fast, only updates pointers.
void SplitAttW1() {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// w is mutually exclusive with w1 in the file, but backprop/ allocates
// both, so we can only rule out both being null.
HWY_ASSERT(qkv_einsum_w.HasPtr() || qkv_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that w2 does not exist for
// MHA, and otherwise might not be the same type.
if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
const size_t w1_rows = layer_config.heads * layer_config.QStride();
if (layer_config.IsMHA()) { // MHA only requires w1.
qkv_einsum_w1 = qkv_einsum_w;
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
return;
}
const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
// Cols are the model_dim but we don't have ModelConfig here.
HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols());
HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols());
const size_t stride = qkv_einsum_w.Stride();
qkv_einsum_w1.SetPtr(qkv_einsum_w.Row(0), stride);
qkv_einsum_w2.SetPtr(qkv_einsum_w.Row(w1_rows), stride);
qkv_einsum_w1.SetType(qkv_einsum_w.GetType());
qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
// Do not invalidate qkv_einsum_w: backprop/ calls this repeatedly.
}
};
// Holds layer-independent weight metadata and pointers plus per-layer
@ -502,13 +590,13 @@ struct ModelWeightsPtrs {
// For reshaping file tensors to the shape expected by the code. This would
// ideally already happen in the importer. Must be called after reading and
// updating the attention weights.
void Reshape(hwy::ThreadPool& pool) {
pool.Run(0, c_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Reshape();
void Fixup(MatOwners& mat_owners, hwy::ThreadPool& pool) {
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Fixup(mat_owners);
});
pool.Run(0, vit_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
VitLayer(layer)->Reshape();
pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
VitLayer(layer)->Fixup(mat_owners);
});
}
}; // `WeightsPtrs`
@ -521,10 +609,9 @@ class WeightsOwner {
// `weight_type` is obtained from `ModelConfig` in `ModelStore`.
WeightsOwner(Type weight_type) : weight_type_(weight_type) {}
// Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`,
// allocates memory and reshapes. Aborts on error.
void ReadOrAllocate(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool);
// Reads tensor data from `BlobStore` or aborts on error.
void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool);
// Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
// calls `ForEachTensor`.
@ -562,7 +649,7 @@ class WeightsOwner {
// Usually taken care of by `ReadOrAllocate`, but must also be called by
// `optimize_test, which updates the attention weights from which this copies.
void Reshape(hwy::ThreadPool& pool);
void Fixup(hwy::ThreadPool& pool);
private:
Type weight_type_;

View File

@ -714,13 +714,6 @@ struct ConstMat {
size_t Rows() const { return extents.rows; }
size_t Cols() const { return extents.cols; }
// Shrinks the row-extent of this matrix view, i.e. reduces the view to a
// subrange of the original rows starting at row 0.
void ShrinkRows(size_t rows) {
HWY_ASSERT(rows <= extents.rows);
extents.rows = rows;
}
const T* HWY_RESTRICT ptr;
Extents2D extents;
size_t stride;

View File

@ -57,12 +57,13 @@ HWY_INLINE float Dot(const ConstMat<WT>& w, size_t w_ofs, const VT* vec_aligned,
const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride());
return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num);
}
// For callers that pass `MatPtrT`.
// For callers that pass `MatPtrT`, which is not necessarily packed - callers
// should use Stride() to compute `w_ofs`.
template <typename WT, typename VT>
HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned,
size_t num) {
const hn::ScalableTag<VT> d;
return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num);
}
// ArrayT is either MatPtrT or ConstMat.

View File

@ -257,6 +257,10 @@ class MatPtrT : public MatPtr {
const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
MatT* Row(size_t row) { return this->RowT<MatT>(row); }
PackedSpan<const MatT> PaddedSpan() const {
return MakeConstSpan(Row(0), Rows() * Stride());
}
// For `compress-inl.h` functions, which assume contiguous streams and thus
// require packed layout.
PackedSpan<const MatT> Span() const {