mirror of https://github.com/google/gemma.cpp.git
Split W1/W2 as a load-time preprocess.
Remove kOnlyAllocate - no longer used. Rename ReadOrAllocate -> ReadFromBlobs.
Rename Reshape -> Fixup to reflect the new scope. Remove the no-longer-used
ShrinkRows. This simplifies gemma-inl and is a prerequisite for removing
ConstMat (whose .ofs was previously used for merged tensors).

PiperOrigin-RevId: 758214083
This commit is contained in:
parent
2038dfd9cc
commit
8a312e9b89
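For readers skimming the diff: the "split" is conceptually just creating two row-range views into one stacked weight matrix, so W1 and W2 can be passed to MatMul directly. Below is a minimal standalone C++ sketch of that idea; MatView and all names in it are illustrative, not the gemma.cpp MatPtrT/ConstMat API.

#include <cstddef>
#include <vector>

// One stacked weight matrix: w1 occupies rows [0, rows1), w2 the remaining
// rows. Both views share the same storage and stride, so the "split" is a
// pointer/metadata update rather than a copy.
struct MatView {
  const float* row0;  // first row of the view
  std::size_t rows, cols, stride;
  const float* Row(std::size_t r) const { return row0 + r * stride; }
};

int main() {
  const std::size_t rows1 = 4, rows2 = 2, cols = 8, stride = cols;
  std::vector<float> stacked((rows1 + rows2) * stride, 0.0f);

  const MatView w1{stacked.data(), rows1, cols, stride};
  const MatView w2{stacked.data() + rows1 * stride, rows2, cols, stride};

  // Sanity check: the two views tile the stacked storage.
  const bool ok = (w1.Row(0) == stacked.data()) &&
                  (w2.Row(0) == stacked.data() + rows1 * stride);
  return ok ? 0 : 1;
}

The SplitW1()/SplitAttW1() helpers added in gemma/weights.h below do the equivalent via MatPtrT::SetPtr with the stacked tensor's stride, so no tensor data is copied at load time.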
@@ -100,7 +100,7 @@ TEST(OptimizeTest, GradientDescent) {
   };

   gemma.MutableWeights().RandInit(1.0f, gen);
-  gemma.MutableWeights().Reshape(pool);
+  gemma.MutableWeights().Fixup(pool);

   printf("Initial weights:\n");
   gemma.MutableWeights().LogWeightStatsF32();
@@ -129,7 +129,7 @@ TEST(OptimizeTest, GradientDescent) {
     CrossEntropyLossBackwardPass(prompt, *gemma.Weights().GetF32(), forward,
                                  *grad.GetF32(), backward, inv_timescale,
                                  pool);
-    gemma.MutableWeights().Reshape(pool);
+    gemma.MutableWeights().Fixup(pool);
     num_ok += verify(prompt) ? 1 : 0;
   }
   total_loss /= kBatchSize;
@@ -246,10 +246,6 @@ class GemmaAttention {
     const size_t heads = layer_config_.heads;
     const size_t kv_heads = layer_config_.kv_heads;

-    using WeightT = typename decltype(layer_weights_.qkv_einsum_w)::T;
-    ConstMat<WeightT> w_q1(layer_weights_.qkv_einsum_w.HasPtr()
-                               ? layer_weights_.qkv_einsum_w
-                               : layer_weights_.qkv_einsum_w1);
     // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
     // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
     // We must shrink to the actual size because MatMul verifies
@@ -257,23 +253,16 @@ class GemmaAttention {
     // rows are used. Otherwise, `QStride() == qkv_dim` and KV will be
     // computed in the second MatMul.
     const size_t w1_rows = heads * layer_config_.QStride();
-    w_q1.ShrinkRows(w1_rows);
-    MatMul(activations_.pre_att_rms_out, w_q1,
+    HWY_DASSERT(layer_weights_.qkv_einsum_w1.Rows() == w1_rows);
+    MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
            /*add=*/nullptr, *activations_.env, RowPtrFromMat(activations_.q));

     if (is_mha_) {
       // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
     } else {
-      decltype(w_q1) w_q2(layer_weights_.qkv_einsum_w.HasPtr()
-                              ? layer_weights_.qkv_einsum_w
-                              : layer_weights_.qkv_einsum_w2);
-      if (layer_weights_.qkv_einsum_w.HasPtr()) {
-        // Skip first half of the matrix.
-        w_q2.ofs = w_q2.Row(w1_rows);
-      }
       // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
       const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
-      w_q2.ShrinkRows(w_rows_kv_cols);
+      HWY_DASSERT(layer_weights_.qkv_einsum_w2.Rows() == w_rows_kv_cols);

       // Single query and no wraparound means we can use a matmul and write
       // directly into the KV cache with a stride of cache_pos_size_.
@@ -284,7 +273,7 @@ class GemmaAttention {
         float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
         RowPtrF kv_rows(kv, w_rows_kv_cols);
         kv_rows.SetStride(cache_pos_size_);
-        MatMul(activations_.pre_att_rms_out, w_q2,
+        MatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
                /*add=*/nullptr, *activations_.env, kv_rows);
       } else {
         // Proceed row by row because there will be wraparound.
@@ -299,7 +288,8 @@ class GemmaAttention {
           const size_t kv_offset =
               cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
           float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-          MatVec(w_q2, w_q2.ofs, w_rows_kv_cols, model_dim, x, kv, pool_);
+          MatVec(layer_weights_.qkv_einsum_w2, 0, w_rows_kv_cols, model_dim, x,
+                 kv, pool_);
         }
       }
     }  // !is_mha_
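The row bookkeeping in the attention hunks above (w1_rows for the Q matmul, w_rows_kv_cols for the KV matmul) covers the whole stacked qkv_einsum_w tensor. A toy standalone check of that arithmetic, using illustrative head counts rather than any particular Gemma config:

#include <cassert>
#include <cstddef>

int main() {
  // Illustrative values; a real config comes from LayerConfig.
  const std::size_t heads = 8, kv_heads = 4, qkv_dim = 256;
  const std::size_t q_stride = qkv_dim;                // non-MHA: QStride() == qkv_dim
  const std::size_t w1_rows = heads * q_stride;        // rows consumed by the Q matmul
  const std::size_t w2_rows = kv_heads * 2 * qkv_dim;  // [k, v] pairs for the KV matmul
  const std::size_t stacked_rows = (heads + kv_heads * 2) * qkv_dim;
  assert(w1_rows + w2_rows == stacked_rows);  // the split tiles the stacked tensor
  (void)stacked_rows;
  return 0;
}

This is why the diff can replace ShrinkRows/ofs bookkeeping with plain HWY_DASSERTs on the pre-split w1/w2 row counts.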
@@ -825,32 +815,11 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
   const float* output_bias =
       add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;

-  // Define slightly more readable names for the weights and activations.
-  auto hidden_activations = RowPtrFromMat(activations.C1);
-  auto multiplier = RowPtrFromMat(activations.C2);
-  auto ffw_out = RowPtrFromMat(activations.ffw_out);
-
-  using WeightT = typename decltype(layer_weights->gating_einsum_w)::T;
-
-  // gating_einsum_w holds two half-matrices. We plan to change the importer to
-  // avoid this confusion by splitting into gating_einsum_w1 and
-  // gating_einsum_w2. TODO: move into Reshape().
-  const bool split = layer_weights->gating_einsum_w.HasPtr();
-  ConstMat<WeightT> w1(split ? layer_weights->gating_einsum_w
-                             : layer_weights->gating_einsum_w1);
-  ConstMat<WeightT> w2(split ? layer_weights->gating_einsum_w
-                             : layer_weights->gating_einsum_w2);
-  if (split) {
-    w2.ofs = w2.Row(ffh_hidden_dim);
-    // Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
-    w1.ShrinkRows(ffh_hidden_dim);
-    w2.ShrinkRows(ffh_hidden_dim);
-  }
-
   // Compute the hidden layer activations.
-  MatMul(activations.pre_ffw_rms_out, w1, bias1, *activations.env,
-         hidden_activations);
-  MatMul(activations.pre_ffw_rms_out, w2, bias2, *activations.env, multiplier);
+  MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, bias1,
+         *activations.env, RowPtrFromMat(activations.C1));
+  MatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, bias2,
+         *activations.env, RowPtrFromMat(activations.C2));

   // Activation (Gelu) and maybe multiply by gate. Store activations in act.
   ActivationBatched(layer_weights->layer_config.activation, activations.C1,
@@ -858,7 +827,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,

   // Hidden layer -> output layer.
   MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
-         ffw_out);
+         RowPtrFromMat(activations.ffw_out));
 }

 // Same as FFWNoVit, but with different layer_weights members and no second
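As a rough illustration of the data flow behind the two gating MatMul calls plus ActivationBatched (C1 = x * W1^T, C2 = x * W2^T, then the activation of C1 gated by C2), here is a toy scalar version. It uses plain loops and a textbook Gelu, not the gemma.cpp kernels, and whether the gate multiply happens in the real code depends on the layer's activation config.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Textbook Gelu, for illustration only.
static float Gelu(float x) { return 0.5f * x * (1.0f + std::erf(x * 0.70710678f)); }

int main() {
  const std::size_t model_dim = 4, ff_hidden_dim = 3;
  std::vector<float> x(model_dim, 1.0f);
  // Row-major [ff_hidden_dim, model_dim] halves, i.e. the split W1/W2.
  std::vector<float> w1(ff_hidden_dim * model_dim, 0.1f);
  std::vector<float> w2(ff_hidden_dim * model_dim, 0.2f);

  std::vector<float> c1(ff_hidden_dim, 0.0f), c2(ff_hidden_dim, 0.0f);
  for (std::size_t r = 0; r < ff_hidden_dim; ++r) {
    for (std::size_t c = 0; c < model_dim; ++c) {
      c1[r] += w1[r * model_dim + c] * x[c];  // first MatMul -> C1
      c2[r] += w2[r * model_dim + c] * x[c];  // second MatMul -> C2
    }
  }
  for (std::size_t r = 0; r < ff_hidden_dim; ++r) {
    std::printf("hidden[%zu] = %f\n", r, Gelu(c1[r]) * c2[r]);  // activation, gated
  }
  return 0;
}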
@@ -57,7 +57,7 @@ Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
       model_(*reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config().weight),
       chat_template_(model_.Tokenizer(), model_.Config().model) {
-  weights_.ReadOrAllocate(model_, *reader_, env_.ctx.pools.Pool());
+  weights_.ReadFromBlobs(model_, *reader_, env_.ctx.pools.Pool());
   reader_.reset();
 }

@@ -37,13 +37,15 @@
 #include "hwy/profiler.h"
 #include "hwy/stats.h"

-// TODO: move into foreach_target; this is only used for NUQ Reshape.
+// TODO: move into foreach_target; this is only used for NUQ Fixup.
 #include "compression/compress-inl.h"

 namespace gcpp {

-template <>
-void LayerWeightsPtrs<NuqStream>::Reshape() {
+static void InitAttWeightsNUQ(const LayerConfig& layer_config,
+                              MatPtrT<NuqStream>& attn_vec_einsum_w,
+                              MatPtrT<NuqStream>& att_weights,
+                              MatOwners& mat_owners) {
   if (!attn_vec_einsum_w.HasPtr()) return;
   HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
@@ -83,6 +85,16 @@ void LayerWeightsPtrs<NuqStream>::Reshape() {
   att_weights.SetScale(attn_vec_einsum_w.Scale());
 }

+static void SplitW1NUQ(const LayerConfig& layer_config) {
+  // TODO(janwas): implement.
+}
+
+template <>
+void LayerWeightsPtrs<NuqStream>::Fixup(MatOwners& mat_owners) {
+  InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners);
+  SplitW1NUQ(layer_config);
+}
+
 // Aborts on error.
 static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
                       const std::vector<BlobRange>& ranges,
@@ -134,8 +146,8 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
   reader.ReadAll(pool);
 }

-void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
+void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                                   hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
   std::vector<MatPtr*> mats;
   std::vector<BlobRange> ranges;
@@ -148,10 +160,6 @@ void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
   // Enumerate all weights (negligible cost).
   CallT([&](const auto& weights) {
     weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
-      if (t.flags & TensorArgs::kOnlyAllocate) {
-        mat_owners_.AllocateFor(t.mat, padding);
-        return;
-      }
       size_t key_idx;
       if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
         mats.push_back(&t.mat);
@@ -165,7 +173,7 @@ void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,

   MapOrRead(mats, reader, ranges, mat_owners_, padding, pool);

-  Reshape(pool);
+  Fixup(pool);
 }

 // Allocates `*_weights_`, but not yet the tensors inside. This is split out
@@ -218,6 +226,7 @@ void WeightsOwner::LogWeightStatsF32() {
   HWY_ASSERT(weight_type_ == Type::kF32);  // Only for float weights.
   float_weights_->ForEachTensor(
       nullptr, nullptr, [&total_weights](const TensorArgs& t) {
+        if (!t.mat.HasPtr()) return;
        if (t.mat.Scale() != 1.0f) {
          printf("[scale=%f] ", t.mat.Scale());
        }
@@ -238,9 +247,9 @@ void WeightsOwner::LogWeightStatsF32() {
   printf("%-20s %12zu\n", "Total", total_weights);
 }

-void WeightsOwner::Reshape(hwy::ThreadPool& pool) {
-  PROFILER_ZONE("Startup.Reshape");
-  CallT([&pool](const auto& weights) { weights->Reshape(pool); });
+void WeightsOwner::Fixup(hwy::ThreadPool& pool) {
+  PROFILER_ZONE("Startup.Fixup");
+  CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); });
 }

 std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
@@ -248,7 +257,6 @@ std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
   std::vector<uint32_t> serialized_mat_ptrs;
   CallT([&](const auto& weights) {
     weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
-      if (t.flags & TensorArgs::kOnlyAllocate) return;
       if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
       HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
       writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes());
 gemma/weights.h | 165
@@ -21,6 +21,7 @@

 #include <complex>
 #include <memory>
+#include <mutex>  // NOLINT
 #include <random>
 #include <string>
 #include <vector>
@@ -43,10 +44,10 @@ struct TensorArgs {
   // `flags` is a combination of zero or more `Flags`.
   TensorArgs(MatPtr& mat, const MatPtr* other_mat1, const MatPtr* other_mat2,
              int flags)
-      : mat(mat), other_mat1(other_mat1), other_mat2(other_mat2), flags(flags) {
-    // Does not make sense to combine both flags.
-    HWY_ASSERT(flags != (kMaybeRead | kOnlyAllocate));
-  }
+      : mat(mat),
+        other_mat1(other_mat1),
+        other_mat2(other_mat2),
+        flags(flags) {}

   MatPtr& mat;
   const MatPtr* other_mat1;  // either/both can be nullptr.
@@ -59,8 +60,6 @@ struct TensorArgs {
     // Not an error if the tensor is not present in the file. For example,
     // the _w1/_w2 tensors are not always present.
     kMaybeRead = 1,
-    // Do not attempt to read, just allocate the tensor. Used for `Reshape`.
-    kOnlyAllocate = 2,
   };
   const int flags;
 };
@@ -87,7 +86,6 @@ struct LayerWeightsPtrs {
   LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
                    const TensorInfoRegistry& tensors)
       : suffix_(LayerSuffix(layer_idx)),
-        attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
         qkv_einsum_w(Concat("qkv_ein", suffix_), tensors),
         qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors),
         qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors),
@@ -127,9 +125,13 @@ struct LayerWeightsPtrs {
         post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors),
         ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors),
         ffw_output_biases(Concat("ffw_out_b", suffix_), tensors),
+
+        attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
         att_weights(Concat("att_w", suffix_), tensors),
+
         key_norm_scale(Concat("key_norm", suffix_), tensors),
         query_norm_scale(Concat("query_norm", suffix_), tensors),
+
         layer_config(config) {}
   ~LayerWeightsPtrs() = default;
@@ -144,10 +146,8 @@ struct LayerWeightsPtrs {
                      hwy::If<hwy::IsSame<Weight, double>(), double,
                              hwy::If<IsF32<Weight>(), float, BF16>>>;

-  MatPtrT<Weight> attn_vec_einsum_w;
-  // qkv_einsum_w holds 2 different matrices, which may be separated out.
-  // On reading, which is used depends on what is in the file.
-  // At inference, the one with a non-null ptr is used.
+  // Files either have qkv_einsum_w with 2 stacked matrices or separate
+  // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
   MatPtrT<Weight> qkv_einsum_w;
   MatPtrT<Weight> qkv_einsum_w1;
   MatPtrT<Weight> qkv_einsum_w2;
@@ -185,9 +185,8 @@ struct LayerWeightsPtrs {
     MatPtrT<WeightF32OrBF16> layer_norm_1_scale;
   } vit;

-  // gating_einsum_w holds 2 different matrices, which may be separated out.
-  // On reading, which is used depends on what is in the file.
-  // At inference, the one with a non-null ptr is used.
+  // Files either have gating_einsum_w with 2 stacked matrices or separate
+  // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
   MatPtrT<Weight> gating_einsum_w;
   MatPtrT<Weight> gating_einsum_w1;
   MatPtrT<Weight> gating_einsum_w2;
@@ -201,7 +200,8 @@ struct LayerWeightsPtrs {
   MatPtrT<float> ffw_gating_biases;
   MatPtrT<float> ffw_output_biases;

-  MatPtrT<Weight> att_weights;  // For Reshape(); kOnlyAllocate.
+  MatPtrT<Weight> attn_vec_einsum_w;  // Use att_weights instead of this.
+  MatPtrT<Weight> att_weights;  // Use this instead of attn_vec_einsum_w.

   MatPtrT<WeightF32OrBF16> key_norm_scale;
   MatPtrT<WeightF32OrBF16> query_norm_scale;
@@ -234,10 +234,10 @@ struct LayerWeightsPtrs {
       return;
     }
     if (layer_config.type == LayerAttentionType::kGemma) {
-      // Not read, will be filled by Reshape() from `attn_vec_einsum_w`.
-      func(TENSOR_ARGS(att_weights, kOnlyAllocate));
-      func(TENSOR_ARGS(attn_vec_einsum_w, kMustRead));
-      func(TENSOR_ARGS(qkv_einsum_w, kMustRead));
+      // Either read from file, or allocated during Fixup().
+      func(TENSOR_ARGS(att_weights, kMaybeRead));
+      func(TENSOR_ARGS(attn_vec_einsum_w, kMaybeRead));
+      func(TENSOR_ARGS(qkv_einsum_w, kMaybeRead));
       func(TENSOR_ARGS(qkv_einsum_w1, kMaybeRead));
       func(TENSOR_ARGS(qkv_einsum_w2, kMaybeRead));
     } else {
@@ -254,7 +254,7 @@ struct LayerWeightsPtrs {
       func(TENSOR_ARGS(griffin.a, kMustRead));
     }
     {
-      func(TENSOR_ARGS(gating_einsum_w, kMustRead));
+      func(TENSOR_ARGS(gating_einsum_w, kMaybeRead));
       func(TENSOR_ARGS(gating_einsum_w1, kMaybeRead));
       func(TENSOR_ARGS(gating_einsum_w2, kMaybeRead));
       func(TENSOR_ARGS(linear_w, kMustRead));
@@ -306,14 +306,25 @@ struct LayerWeightsPtrs {
     });
   }

-  // Initializes att_weights from `attn_vec_einsum_w`, hence this must be called
-  // after reading weights via `ForEachTensor`.
-  // TODO: update compression/convert_weights to bake this in.
-  void Reshape() {
-    // We only have/allocate this tensor for Gemma layers.
-    HWY_ASSERT(att_weights.HasPtr() ==
-               (layer_config.type == LayerAttentionType::kGemma));
-    if (!att_weights.HasPtr()) return;
+  // Must be called after reading weights via `ForEachTensor`.
+  // TODO: exporters should bake this into the weights already.
+  // WARNING: called from multiple threads; `mat_owners` requires a lock.
+  void Fixup(MatOwners& mat_owners) {
+    InitAttWeights(mat_owners);
+    SplitW1();
+    SplitAttW1();
+  }
+
+ private:
+  // Copies att_weights from `attn_vec_einsum_w`.
+  void InitAttWeights(MatOwners& mat_owners) {
+    // We only use this tensor for Gemma layers.
+    if (layer_config.type != LayerAttentionType::kGemma) return;
+
+    // Files must have one or the other, and backprop/ allocates both.
+    HWY_ASSERT(attn_vec_einsum_w.HasPtr() || att_weights.HasPtr());
+    // Done if we already read the transposed tensor.
+    if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;

     // NUQ is handled by a specialization in weights.cc.
     HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);
@@ -326,9 +337,15 @@ struct LayerWeightsPtrs {
     HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
     HWY_ASSERT(att_weights.Rows() == model_dim);
     HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
-    HWY_ASSERT(attn_vec_einsum_w.HasPtr());
     HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
     HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);
+
+    {
+      static std::mutex m;
+      std::lock_guard<std::mutex> lock(m);
+      mat_owners.AllocateFor(att_weights, MatPadding::kOdd);
+    }
+
     const size_t T_bytes = att_weights.ElementBytes();
     for (size_t m = 0; m < model_dim; ++m) {
       uint8_t* HWY_RESTRICT out_row =
@@ -340,6 +357,77 @@ struct LayerWeightsPtrs {
     }
     att_weights.SetScale(attn_vec_einsum_w.Scale());
   }

+  // For FFN. Fast, only updates pointers.
+  void SplitW1() {
+    // We only use this tensor for Gemma layers.
+    if (layer_config.type != LayerAttentionType::kGemma) return;
+
+    // Files have both or neither of w1 and w2, and backprop/ allocates both.
+    HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
+    // w is mutually exclusive with w1 and w2 in the file, but backprop/
+    // allocates both, so we can only rule out both being null.
+    HWY_ASSERT(gating_einsum_w.HasPtr() || gating_einsum_w1.HasPtr());
+    // Done if we already read split tensors. Note that they are not
+    // necessarily the same type.
+    if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
+
+    const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
+    HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
+    HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
+    HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
+    // Cols are the model_dim but we don't have ModelConfig here.
+    HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols());
+    HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols());
+
+    const size_t stride = gating_einsum_w.Stride();
+    gating_einsum_w1.SetPtr(gating_einsum_w.Row(0), stride);
+    gating_einsum_w2.SetPtr(gating_einsum_w.Row(ff_hidden_dim), stride);
+    gating_einsum_w1.SetType(gating_einsum_w.GetType());
+    gating_einsum_w2.SetType(gating_einsum_w.GetType());
+    gating_einsum_w1.SetScale(gating_einsum_w.Scale());
+    gating_einsum_w2.SetScale(gating_einsum_w.Scale());
+    // Do not invalidate gating_einsum_w: backprop/ calls this repeatedly.
+  }
+
+  // For attention, which might not have a w2. Fast, only updates pointers.
+  void SplitAttW1() {
+    // We only use this tensor for Gemma layers.
+    if (layer_config.type != LayerAttentionType::kGemma) return;
+
+    // w is mutually exclusive with w1 in the file, but backprop/ allocates
+    // both, so we can only rule out both being null.
+    HWY_ASSERT(qkv_einsum_w.HasPtr() || qkv_einsum_w1.HasPtr());
+    // Done if we already read split tensors. Note that w2 does not exist for
+    // MHA, and otherwise might not be the same type.
+    if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
+
+    const size_t w1_rows = layer_config.heads * layer_config.QStride();
+
+    if (layer_config.IsMHA()) {  // MHA only requires w1.
+      qkv_einsum_w1 = qkv_einsum_w;
+      HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
+      return;
+    }
+
+    const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
+
+    HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
+    HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
+    HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
+    // Cols are the model_dim but we don't have ModelConfig here.
+    HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols());
+    HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols());
+
+    const size_t stride = qkv_einsum_w.Stride();
+    qkv_einsum_w1.SetPtr(qkv_einsum_w.Row(0), stride);
+    qkv_einsum_w2.SetPtr(qkv_einsum_w.Row(w1_rows), stride);
+    qkv_einsum_w1.SetType(qkv_einsum_w.GetType());
+    qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
+    qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
+    qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
+    // Do not invalidate qkv_einsum_w: backprop/ calls this repeatedly.
+  }
 };

 // Holds layer-independent weight metadata and pointers plus per-layer
@@ -502,13 +590,13 @@ struct ModelWeightsPtrs {
   // For reshaping file tensors to the shape expected by the code. This would
   // ideally already happen in the importer. Must be called after reading and
   // updating the attention weights.
-  void Reshape(hwy::ThreadPool& pool) {
-    pool.Run(0, c_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
-      GetLayer(layer)->Reshape();
+  void Fixup(MatOwners& mat_owners, hwy::ThreadPool& pool) {
+    pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
+      GetLayer(layer)->Fixup(mat_owners);
     });

-    pool.Run(0, vit_layers.size(), [this](uint64_t layer, size_t /*thread*/) {
-      VitLayer(layer)->Reshape();
+    pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
+      VitLayer(layer)->Fixup(mat_owners);
     });
   }
 };  // `WeightsPtrs`
@@ -521,10 +609,9 @@ class WeightsOwner {
   // `weight_type` is obtained from `ModelConfig` in `ModelStore`.
   WeightsOwner(Type weight_type) : weight_type_(weight_type) {}

-  // Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`,
-  // allocates memory and reshapes. Aborts on error.
-  void ReadOrAllocate(const ModelStore& model, BlobReader& reader,
-                      hwy::ThreadPool& pool);
+  // Reads tensor data from `BlobStore` or aborts on error.
+  void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
+                     hwy::ThreadPool& pool);

   // Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
   // calls `ForEachTensor`.
@@ -562,7 +649,7 @@ class WeightsOwner {

   // Usually taken care of by `ReadOrAllocate`, but must also be called by
   // `optimize_test, which updates the attention weights from which this copies.
-  void Reshape(hwy::ThreadPool& pool);
+  void Fixup(hwy::ThreadPool& pool);

 private:
  Type weight_type_;
@@ -714,13 +714,6 @@ struct ConstMat {
   size_t Rows() const { return extents.rows; }
   size_t Cols() const { return extents.cols; }

-  // Shrinks the row-extent of this matrix view, i.e. reduces the view to a
-  // subrange of the original rows starting at row 0.
-  void ShrinkRows(size_t rows) {
-    HWY_ASSERT(rows <= extents.rows);
-    extents.rows = rows;
-  }
-
   const T* HWY_RESTRICT ptr;
   Extents2D extents;
   size_t stride;
@@ -57,12 +57,13 @@ HWY_INLINE float Dot(const ConstMat<WT>& w, size_t w_ofs, const VT* vec_aligned,
   const auto span = MakeSpan(w.ptr, w_ofs + w.extents.rows * w.Stride());
   return w.Scale() * Dot(d, span, w_ofs, vec_aligned, num);
 }
-// For callers that pass `MatPtrT`.
+// For callers that pass `MatPtrT`, which is not necessarily packed - callers
+// should use Stride() to compute `w_ofs`.
 template <typename WT, typename VT>
 HWY_INLINE float Dot(const MatPtrT<WT>& w, size_t w_ofs, const VT* vec_aligned,
                      size_t num) {
   const hn::ScalableTag<VT> d;
-  return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
+  return w.Scale() * Dot(d, w.PaddedSpan(), w_ofs, vec_aligned, num);
 }

 // ArrayT is either MatPtrT or ConstMat.
||||||
|
|
@ -257,6 +257,10 @@ class MatPtrT : public MatPtr {
|
||||||
const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
|
const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
|
||||||
MatT* Row(size_t row) { return this->RowT<MatT>(row); }
|
MatT* Row(size_t row) { return this->RowT<MatT>(row); }
|
||||||
|
|
||||||
|
PackedSpan<const MatT> PaddedSpan() const {
|
||||||
|
return MakeConstSpan(Row(0), Rows() * Stride());
|
||||||
|
}
|
||||||
|
|
||||||
// For `compress-inl.h` functions, which assume contiguous streams and thus
|
// For `compress-inl.h` functions, which assume contiguous streams and thus
|
||||||
// require packed layout.
|
// require packed layout.
|
||||||
PackedSpan<const MatT> Span() const {
|
PackedSpan<const MatT> Span() const {
|
||||||
|
|
|
||||||