Major refactor to de-templatize gemma-inl and weights

This replaces per-weight-type instantiation of the entire inference code with instantiation of only the individual MatMul/norm calls.
Reduces binary size by 133KiB.

WeightsOwner is no longer required for type erasure, hence it is replaced with ModelWeightsPtrs.
Also remove the now-unused EmbedToken, which is superseded by EmbedMMToken.

PiperOrigin-RevId: 766497657
Jan Wassenberg 2025-06-02 23:00:47 -07:00 committed by Copybara-Service
parent cf4d7ceb82
commit 794a21a4e6
18 changed files with 676 additions and 877 deletions
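
To make "per-MatMul/norm" concrete: the diff below (see CallMatMul and CallUpcasted in gemma-inl.h) keeps weights behind a type-erased MatPtr and upcasts to the concrete element type only at each MatMul or norm call site, so just the leaf kernel is instantiated per weight type. A minimal, self-contained sketch of that pattern follows; TypeTag, ErasedMat, TypedMat and DispatchUpcasted are hypothetical stand-ins, not the repository's API.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical, simplified stand-ins for the real MatPtr/CallUpcasted machinery.
enum class TypeTag { kF32, kBF16 };

struct BF16 { uint16_t bits; };  // placeholder 16-bit brain-float storage

struct ErasedMat {   // type-erased view: runtime tag + untyped pointer
  TypeTag tag;
  const void* data;
  size_t num;
};

template <typename T>
struct TypedMat {    // concrete view recovered at the call site
  const T* data;
  size_t num;
};

static float ToFloat(float x) { return x; }
static float ToFloat(BF16 x) {
  const uint32_t bits = static_cast<uint32_t>(x.bits) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Upcasts to the concrete element type and calls `func` with a typed view.
// Only the lambda body (one kernel) is instantiated per type, not the whole
// inference stack; that is the essence of the change described above.
template <class Func>
auto DispatchUpcasted(const ErasedMat& m, const Func& func) {
  switch (m.tag) {
    case TypeTag::kF32:
      return func(TypedMat<float>{static_cast<const float*>(m.data), m.num});
    case TypeTag::kBF16:
    default:
      return func(TypedMat<BF16>{static_cast<const BF16*>(m.data), m.num});
  }
}

int main() {
  const float w[4] = {1, 2, 3, 4};
  const ErasedMat mat{TypeTag::kF32, w, 4};
  // Analogous to CallMatMul/CallUpcasted: the typed kernel is chosen here.
  const double sum = DispatchUpcasted(mat, [](const auto& typed) {
    double s = 0.0;
    for (size_t i = 0; i < typed.num; ++i) s += ToFloat(typed.data[i]);
    return s;
  });
  std::printf("sum = %g\n", sum);
  return 0;
}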


@ -469,10 +469,6 @@ cc_library(
name = "gemma_lib",
srcs = [
"gemma/gemma.cc",
"gemma/instantiations/bf16.cc",
"gemma/instantiations/f32.cc",
"gemma/instantiations/nuq.cc",
"gemma/instantiations/sfp.cc",
],
hdrs = [
"gemma/activations.h",


@ -61,10 +61,6 @@ set(SOURCES
gemma/gemma-inl.h
gemma/gemma.cc
gemma/gemma.h
gemma/instantiations/bf16.cc
gemma/instantiations/f32.cc
gemma/instantiations/nuq.cc
gemma/instantiations/sfp.cc
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/model_store.cc


@ -556,7 +556,7 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
// Decompresses from any type of `packed`, starting at (any) `packed_ofs`, to
// (any) `num` elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as
// required to round `num` up to one vector, if it is not already. The caller is
// responsible for scaling `raw` to the original range because `EmbedToken`
// responsible for scaling `raw` to the original range because `EmbedMMToken`
// also wants to scale the decompressed elements.
// `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`.
template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
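
A scalar sketch of the padding contract described in this comment, under the simplifying assumption that `Packed` is plain float (the hypothetical DecompressAndZeroPadSketch below is not the real SIMD routine): `num` values are written, zeroes are appended until the length is a multiple of the lane count, and scaling is left to the caller, e.g. EmbedMMToken.

#include <cstddef>
#include <vector>

// Copy `num` "decompressed" values, then zero-pad up to a multiple of the
// SIMD lane count. Multiplying by the tensor's scale is the caller's job.
std::vector<float> DecompressAndZeroPadSketch(const float* packed, size_t num,
                                              size_t lanes) {
  const size_t padded = (num + lanes - 1) / lanes * lanes;  // round up
  std::vector<float> raw(padded, 0.0f);                     // zero tail
  for (size_t i = 0; i < num; ++i) raw[i] = packed[i];      // "decompress"
  return raw;
}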


@ -83,6 +83,7 @@ struct Activations {
// For MatMul outputs, precompute their row pointers.
const auto init_row_ptrs = [&](MatPtrT<float>& mat) {
if (!mat.HasPtr()) return;
row_ptrs.push_back(hwy::AllocateAligned<uint8_t*>(mat.Rows()));
uint8_t** ptrs = row_ptrs.back().get();
for (size_t r = 0; r < mat.Rows(); ++r) {
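
A plain-C++ sketch of what this hunk precomputes (hypothetical PrecomputeRowPtrs; the real code stores hwy::AllocateAligned buffers in `row_ptrs` and reads rows from MatPtrT): one byte pointer per output row, so MatMul can write rows without recomputing strides.

#include <cstddef>
#include <cstdint>
#include <vector>

// Given a row-major buffer whose rows may be padded to `stride_bytes`, build
// the per-row pointer table once so later writes are just ptrs[r].
std::vector<uint8_t*> PrecomputeRowPtrs(uint8_t* base, size_t rows,
                                        size_t stride_bytes) {
  std::vector<uint8_t*> ptrs(rows);
  for (size_t r = 0; r < rows; ++r) ptrs[r] = base + r * stride_bytes;
  return ptrs;
}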


@ -53,31 +53,30 @@
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"
#ifndef GEMMA_TYPE
#if HWY_IDE
// Provide a definition so the IDE does not complain.
#define GEMMA_TYPE float
#else
#error "Only include from instantiations/*.cc, which must define GEMMA_TYPE"
#endif // HWY_IDE
#endif // GEMMA_TYPE
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <typename TA, typename TC>
MMPerKey* CallMatMul(const MatPtrT<TA>& A, const MatPtr& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
MatPtrT<TC>& C) {
return CallUpcasted(
&B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); });
}
// Different functions use different naming conventions for the number of
// tokens. Functions that are query-independent, such as RMSNorm*, call the
// count `num_interleaved`. Functions that are query-dependent, such as
// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
template <typename T>
HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
size_t num_tokens, size_t griffin_layer,
Activations& activations,
const LayerWeightsPtrs<T>* layer_weights,
const KVCaches& kv_caches) {
static HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
size_t num_tokens,
size_t griffin_layer,
Activations& activations,
const LayerWeightsPtrs* layer_weights,
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Griffin");
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE;
@ -101,17 +100,21 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
// TODO: MatMul
HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows());
HWY_DASSERT(num_interleaved == activations.griffin_y.Rows());
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT y = activations.griffin_y.Row(r);
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
TwoMatVecAdd(layer_weights->griffin.linear_x_w,
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
activations.pre_att_rms_out.Row(r),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
Gelu(y, model_dim);
}
CallUpcastedSame(
&layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w,
[&](const auto* wx, const auto* wy) {
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT y = activations.griffin_y.Row(r);
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
TwoMatVecAdd(
*wx, *wy, 0, model_dim, model_dim,
activations.pre_att_rms_out.Row(r),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
Gelu(y, model_dim);
}
});
// Conv1D.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
@ -165,14 +168,16 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
size_t head_offset = head * kHeadDim;
TwoOfsMatVecAddLoop(
layer_weights->griffin.gate_w, kMatrixSize * head,
kMatrixSize * (heads + head), kHeadDim, kHeadDim, x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
head_offset,
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
model_dim + head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) {
TwoOfsMatVecAddLoop(
*gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim,
kHeadDim, x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
head_offset,
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
model_dim + head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
});
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
@ -206,17 +211,19 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos,
// Final linear layer.
// TODO: MatMul
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
float* out_ptr = activations.att_sums.Row(r);
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
pool);
}
CallUpcasted(
&layer_weights->griffin.linear_out_w, [&](const auto* weights_t) {
for (size_t r = 0; r < num_interleaved; ++r) {
float* HWY_RESTRICT x = activations.griffin_x.Row(r);
float* out_ptr = activations.att_sums.Row(r);
MatVecAdd(*weights_t, 0, model_dim, model_dim, x,
layer_weights->griffin.linear_out_biases.PackedScale1(),
out_ptr, pool);
}
});
} // GriffinRecurrent
// Wrapper class; holds arguments in member variables to shorten call sites.
template <typename T>
class GemmaAttention {
// The attention window usually starts at 0 unless `pos` is larger than
// the attention window size, then it is `pos` - window_size + 1.
@ -257,8 +264,8 @@ class GemmaAttention {
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
/*add=*/nullptr, *activations_.env, activations_.q);
CallMatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1,
/*add=*/nullptr, *activations_.env, activations_.q);
// Set up MatMul row pointers for writing to KV, which consists of
// `kv_heads` pairs of (k, v) vectors. This safely handles wraparound
@ -279,8 +286,8 @@ class GemmaAttention {
kv_offset);
}
kv_rows.AttachRowPtrs(&activations_.env->storage.OutRow(0));
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
/*add=*/nullptr, *activations_.env, kv_rows);
CallMatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2,
/*add=*/nullptr, *activations_.env, kv_rows);
// Apply positional encodings for K (and copy KV to cache if MHA).
pool_.Run(0, kv_heads * num_interleaved,
@ -299,8 +306,11 @@ class GemmaAttention {
// Apply further processing to K.
if (layer_weights_.key_norm_scale.HasPtr()) {
RMSNormInplace(layer_weights_.key_norm_scale.PackedScale1(),
0, kv, qkv_dim);
CallUpcasted(&layer_weights_.key_norm_scale,
[&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0,
kv, qkv_dim);
});
}
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
});
@ -367,8 +377,10 @@ class GemmaAttention {
// Apply rope and scaling to Q.
if (layer_weights_.query_norm_scale.HasPtr()) {
RMSNormInplace(layer_weights_.query_norm_scale.PackedScale1(), 0, q,
qkv_dim);
CallUpcasted(&layer_weights_.query_norm_scale,
[&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, q, qkv_dim);
});
}
PositionalEncodingQK(q, pos, layer_, query_scale);
@ -456,8 +468,8 @@ class GemmaAttention {
layer_weights_.layer_config.softmax_attn_output_biases
? layer_weights_.attention_output_biases.PackedScale1()
: nullptr;
MatMulStatic(activations_.att_out, layer_weights_.att_weights, add,
*activations_.env, activations_.att_sums);
CallMatMul(activations_.att_out, layer_weights_.att_weights, add,
*activations_.env, activations_.att_sums);
}
public:
@ -467,7 +479,7 @@ class GemmaAttention {
GemmaAttention(const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, size_t num_tokens,
size_t layer, Activations& activations,
const LayerWeightsPtrs<T>* layer_weights,
const LayerWeightsPtrs* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
: GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
activations, layer_weights, div_seq_len, kv_caches,
@ -475,7 +487,7 @@ class GemmaAttention {
// Constructor with default initialization to 0 for queries_prefix_end.
GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer,
Activations& activations,
const LayerWeightsPtrs<T>* layer_weights,
const LayerWeightsPtrs* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches)
: GemmaAttention(queries_pos, nullptr, num_tokens, layer, activations,
layer_weights, div_seq_len, kv_caches,
@ -487,7 +499,7 @@ class GemmaAttention {
GemmaAttention(const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, size_t num_tokens,
size_t layer, Activations& activations,
const LayerWeightsPtrs<T>* layer_weights,
const LayerWeightsPtrs* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
ThreadingContext& ctx)
: GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer,
@ -507,7 +519,7 @@ class GemmaAttention {
GemmaAttention(const QueriesPos& queries_pos,
const QueriesPos* queries_prefix_end, size_t num_tokens,
size_t layer, Activations& activations,
const LayerWeightsPtrs<T>* layer_weights,
const LayerWeightsPtrs* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
ThreadingContext& ctx)
: queries_pos_(queries_pos),
@ -546,21 +558,20 @@ class GemmaAttention {
const size_t cache_pos_size_ = 0;
Activations& activations_;
const LayerWeightsPtrs<T>& layer_weights_;
const LayerWeightsPtrs& layer_weights_;
const hwy::Divisor& div_seq_len_;
const KVCaches& kv_caches_;
hwy::ThreadPool& pool_;
};
template <typename T>
HWY_NOINLINE void Attention(
static HWY_NOINLINE void Attention(
LayerAttentionType type, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, size_t num_tokens, size_t layer,
Activations& activations, const LayerWeightsPtrs<T>* layer_weights,
Activations& activations, const LayerWeightsPtrs* layer_weights,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
if (type == LayerAttentionType::kGemma) {
GemmaAttention<T>(queries_pos, queries_prefix_end, num_tokens, layer,
activations, layer_weights, div_seq_len, kv_caches)();
GemmaAttention(queries_pos, queries_prefix_end, num_tokens, layer,
activations, layer_weights, div_seq_len, kv_caches)();
} else {
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
// KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer,
@ -581,7 +592,6 @@ HWY_NOINLINE void Attention(
// This results in a much simpler implementation. However, to avoid duplicating
// code, we should still consider merging the two classes.
// TODO(keysers): Refactor to share code with GemmaAttention.
template <typename T>
class VitAttention {
// Computes Q, K, V for all heads, stored in activations_.q.
HWY_NOINLINE void ComputeQKV() {
@ -589,9 +599,9 @@ class VitAttention {
auto& qkv = activations_.q;
HWY_ASSERT(qkv.Rows() == num_tokens_);
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
MatMulStatic(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
layer_weights_.vit.qkv_einsum_b.PackedScale1(),
*activations_.env, qkv);
CallMatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
layer_weights_.vit.qkv_einsum_b.PackedScale1(),
*activations_.env, qkv);
}
// TODO(philculliton): transition fully to MatMul.
@ -631,7 +641,7 @@ class VitAttention {
});
// this produces C, a (num_tokens_, seq_len) matrix of dot products
MatMulStatic(Q, K, nullptr, *activations_.env, C);
CallMatMul(Q, K, nullptr, *activations_.env, C);
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT c = C.Row(task);
@ -697,15 +707,14 @@ class VitAttention {
// att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads.
MatMulStatic(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
*activations_.env, activations_.att_sums);
CallMatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
*activations_.env, activations_.att_sums);
}
public:
VitAttention(size_t num_tokens, size_t layer, Activations& activations,
const LayerWeightsPtrs<T>* layer_weights)
const LayerWeightsPtrs* layer_weights)
: num_tokens_(num_tokens),
layer_(layer),
activations_(activations),
layer_weights_(*layer_weights),
layer_config_(layer_weights->layer_config),
@ -723,9 +732,8 @@ class VitAttention {
private:
const size_t num_tokens_;
const size_t layer_;
Activations& activations_;
const LayerWeightsPtrs<T>& layer_weights_;
const LayerWeightsPtrs& layer_weights_;
const LayerConfig& layer_config_;
hwy::ThreadPool& pool_;
};
@ -775,9 +783,8 @@ void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2) {
}
}
template <typename T>
HWY_NOINLINE void FFWNoVit(Activations& activations,
const LayerWeightsPtrs<T>* layer_weights) {
static HWY_NOINLINE void FFWNoVit(Activations& activations,
const LayerWeightsPtrs* layer_weights) {
PROFILER_ZONE("Gen.FFW");
const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
@ -789,25 +796,24 @@ HWY_NOINLINE void FFWNoVit(Activations& activations,
add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
// Compute the hidden layer activations.
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1,
bias1, *activations.env, activations.C1);
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2,
bias2, *activations.env, activations.C2);
CallMatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1,
bias1, *activations.env, activations.C1);
CallMatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2,
bias2, *activations.env, activations.C2);
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
ActivationBatched(layer_weights->layer_config.activation, activations.C1,
&activations.C2);
// Hidden layer -> output layer.
MatMulStatic(activations.C1, layer_weights->linear_w, output_bias,
*activations.env, activations.ffw_out);
CallMatMul(activations.C1, layer_weights->linear_w, output_bias,
*activations.env, activations.ffw_out);
}
// Same as FFWNoVit, but with different layer_weights members and no second
// gating matrix.
template <typename T>
HWY_NOINLINE void FFWVit(Activations& activations,
const LayerWeightsPtrs<T>* layer_weights) {
static HWY_NOINLINE void FFWVit(Activations& activations,
const LayerWeightsPtrs* layer_weights) {
PROFILER_ZONE("Gen.FFW.ViT");
const bool add_bias = layer_weights->layer_config.ff_biases;
@ -817,15 +823,15 @@ HWY_NOINLINE void FFWVit(Activations& activations,
add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
// Compute the hidden layer activations.
MatMulStatic(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w,
bias1, *activations.env, activations.C1);
CallMatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
*activations.env, activations.C1);
// Activation (Gelu), store in C1.
ActivationBatched(layer_weights->layer_config.activation, activations.C1);
// Hidden layer -> output layer.
MatMulStatic(activations.C1, layer_weights->vit.linear_1_w, output_bias,
*activations.env, activations.ffw_out);
CallMatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
*activations.env, activations.ffw_out);
}
// `batch_idx` indicates which row of `x` to write to.
@ -836,67 +842,52 @@ HWY_NOINLINE void FFWVit(Activations& activations,
// spec) until we run out of image tokens. This allows for a multi-image prompt
// if -2 locations with appropriate begin/end image tokens are created by the
// calling application.
template <typename T>
HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt,
const ModelWeightsPtrs<T>& weights,
MatStorageT<float>& x,
const ImageTokens* image_tokens,
size_t& image_token_position) {
// Returns new image_token_position.
static HWY_NOINLINE size_t
EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
const ModelConfig& model_config, const ModelWeightsPtrs& weights,
MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
size_t image_token_position = 0) {
// Image tokens just need to be copied.
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM &&
if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
image_tokens != nullptr && token == -2 &&
image_token_position < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
x.Cols() * x.ElementBytes());
image_token_position++;
return;
return image_token_position + 1;
}
if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA &&
if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
x.Cols() * x.ElementBytes());
return;
return image_token_position;
}
const size_t model_dim = weights.weights_config.model_dim;
const size_t model_dim = model_config.model_dim;
const float emb_scaling = EmbeddingScaling(model_dim);
HWY_DASSERT(token >= 0);
HWY_DASSERT(token < static_cast<int>(weights.weights_config.vocab_size));
HWY_DASSERT(token < static_cast<int>(model_config.vocab_size));
const hn::ScalableTag<float> df;
// Using `Stride` to compute the offset works for both NUQ (because we use an
// offset and NUQ is never padded) and padded, because non-NUQ types are
// seekable, hence the offset can also skip any padding.
const size_t embedding_ofs =
token * weights.embedder_input_embedding.Stride();
HWY_ASSERT(weights.embedder_input_embedding.Cols() == model_dim);
const auto embedding_span = MakeSpan(weights.embedder_input_embedding.Row(0),
embedding_ofs + model_dim);
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
model_dim);
MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(),
x.Row(batch_idx), model_dim);
if (weights.weights_config.absolute_pe) {
CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) {
// Using `Stride` to compute the offset works for both NUQ (because we use
// an offset and NUQ is never padded) and padded, because non-NUQ types are
// seekable, hence the offset can also skip any padding.
const size_t embedding_ofs = token * weights_t->Stride();
HWY_ASSERT(weights_t->Cols() == model_dim);
const auto embedding_span =
MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
model_dim);
MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim);
});
if (model_config.absolute_pe) {
AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
}
}
// `batch_idx` indicates which row of `x` to write to.
// `pos` is the *token*'s position, not the start of the batch, because this is
// called for batches of tokens in prefill, but batches of queries in decode.
// This version of the function doesn't track internal image token position.
template <typename T>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt,
const ModelWeightsPtrs<T>& weights,
MatStorageT<float>& x,
const ImageTokens* image_tokens) {
size_t image_token_position = 0;
EmbedMMToken<T>(token, batch_idx, pos, pos_in_prompt, weights, x,
image_tokens, image_token_position);
return image_token_position;
}
template <typename T2, class LayerWeights>
@ -908,8 +899,8 @@ HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
AddFromBatched(other, x);
}
template <typename WeightT, typename InOutT>
void PostNorm(PostNormType post_norm, const MatPtrT<WeightT>& weights,
template <typename InOutT>
void PostNorm(PostNormType post_norm, const MatPtr& weights,
MatPtrT<InOutT>& inout) {
HWY_DASSERT(weights.Rows() == 1);
if (post_norm == PostNormType::Scale) {
@ -917,14 +908,11 @@ void PostNorm(PostNormType post_norm, const MatPtrT<WeightT>& weights,
}
}
template <typename T>
HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
size_t num_tokens, size_t cache_layer_idx,
const LayerWeightsPtrs<T>* layer_weights,
Activations& activations,
const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches) {
static HWY_NOINLINE void TransformerLayer(
const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end,
size_t num_tokens, size_t cache_layer_idx,
const LayerWeightsPtrs* layer_weights, Activations& activations,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
auto type = layer_weights->layer_config.type;
RMSNormBatched(activations.x, layer_weights->pre_attention_norm_scale,
@ -960,10 +948,9 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
// github.com/google-research/big_vision/blob/main/big_vision/models/vit.py
// TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and
// try merging this with TransformerLayer.
template <typename T>
HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
const LayerWeightsPtrs<T>* layer_weights,
Activations& activations) {
static HWY_NOINLINE void VitTransformerLayer(
size_t num_tokens, size_t layer, const LayerWeightsPtrs* layer_weights,
Activations& activations) {
const size_t model_dim = activations.weights_config.model_dim;
auto type = layer_weights->layer_config.type;
HWY_DASSERT(type == LayerAttentionType::kVit);
@ -982,7 +969,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
// y ~ att_sums
VitAttention<T>(num_tokens, layer, activations, layer_weights)();
VitAttention(num_tokens, layer, activations, layer_weights)();
// x = out["+sa"] = x + y
AddFromBatched(activations.att_sums, x);
@ -1005,13 +992,13 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
using QueriesMutablePos = hwy::Span<size_t>;
// Populates KV cache for batches of tokens from one query at a time.
template <typename T>
HWY_NOINLINE void Prefill(
static HWY_NOINLINE void Prefill(
const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const ModelWeightsPtrs<T>& weights,
Activations& activations, const RuntimeConfig& runtime_config,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) {
const size_t query_idx_start, const ModelConfig& config,
const ModelWeightsPtrs& weights, Activations& activations,
const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(queries_pos.size() == num_queries);
@ -1072,14 +1059,14 @@ HWY_NOINLINE void Prefill(
const size_t pos = queries_pos[qi] + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = queries_prompt[qi][pos_in_prompt];
EmbedMMToken(token, ti, pos, pos_in_prompt, weights, activations.x,
runtime_config.image_tokens, image_token_position);
image_token_position = EmbedMMToken(
token, ti, pos, pos_in_prompt, config, weights, activations.x,
runtime_config.image_tokens, image_token_position);
}
// Transformer with one batch of tokens from a single query.
for (size_t layer = 0;
layer < weights.weights_config.layer_configs.size(); ++layer) {
const LayerWeightsPtrs<T>* layer_weights = weights.GetLayer(layer);
for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer);
TransformerLayer(single_query_pos, single_query_prefix_end, tbatch_size,
layer, layer_weights, activations, div_seq_len,
single_kv_cache);
@ -1115,13 +1102,13 @@ HWY_NOINLINE void Prefill(
// Gets the patches of the image and embeds them with the image embedding
// kernel. The result is stored in activations.x.
template <typename T>
HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelWeightsPtrs<T>& weights,
Activations& activations) {
const size_t model_dim = weights.weights_config.vit_config.model_dim;
const size_t patch_width = weights.weights_config.vit_config.patch_width;
const size_t seq_len = weights.weights_config.vit_config.seq_len;
static HWY_NOINLINE void EmbedImagePatches(const Image& image,
const ModelConfig& model_config,
const ModelWeightsPtrs& weights,
Activations& activations) {
const size_t model_dim = model_config.vit_config.model_dim;
const size_t patch_width = model_config.vit_config.patch_width;
const size_t seq_len = model_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
@ -1138,7 +1125,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
// MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
// kPatchSize), MatPadding::kPacked);
// [Get patches]
// MatMulStatic(
// CallMatMul(
// image_patches,
// weights.vit_img_embedding_kernel,
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
@ -1147,61 +1134,64 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
// which is not the case here. We should relax that requirement on MatMul and
// then use the above. For now, we rely on MatVecAdd instead.
for (size_t i = 0; i < seq_len; ++i) {
MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
image_patches[i].get(),
weights.vit_img_embedding_bias.PackedScale1(),
activations.x.Row(i), activations.env->ctx.pools.Pool(0));
}
CallUpcasted(&weights.vit_img_embedding_kernel, [&](const auto* embedding_t) {
for (size_t i = 0; i < seq_len; ++i) {
MatVecAdd(*embedding_t, 0, model_dim, patch_size, image_patches[i].get(),
weights.vit_img_embedding_bias.PackedScale1(),
activations.x.Row(i), activations.env->ctx.pools.Pool(0));
}
});
// Add position embeddings.
AddFromBatched(weights.vit_img_pos_embedding, activations.x);
}
// Prefills the image tokens with the ViT encoder.
template <typename T>
HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
Activations& activations) {
static HWY_NOINLINE void PrefillVit(const ModelConfig& model_config,
const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config,
const Image& image,
ImageTokens& image_tokens,
Activations& activations) {
PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = weights.weights_config.vit_config.seq_len;
const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
const size_t num_tokens = model_config.vit_config.seq_len;
const size_t vit_model_dim = model_config.vit_config.model_dim;
HWY_ASSERT(num_tokens == activations.x.Rows());
// Embed the image patches.
EmbedImagePatches(image, weights, activations);
EmbedImagePatches(image, model_config, weights, activations);
// Go through all layers.
for (size_t layer = 0;
layer < weights.weights_config.vit_config.layer_configs.size();
for (size_t layer = 0; layer < model_config.vit_config.layer_configs.size();
++layer) {
const LayerWeightsPtrs<T>* layer_weights = weights.VitLayer(layer);
const LayerWeightsPtrs* layer_weights = weights.VitLayer(layer);
VitTransformerLayer(num_tokens, layer, layer_weights, activations);
}
// Final Layernorm.
LayerNormBatched(activations.x, weights.vit_encoder_norm_scale,
weights.vit_encoder_norm_bias, activations.x);
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
if (model_config.wrapping == PromptWrapping::GEMMA_VLM) {
activations.x = AvgPool4x4(activations.x);
// Apply soft embedding norm before input projection.
RMSNormInplace(weights.mm_embed_norm.PackedScale1(), 0,
activations.x.Row(0), vit_model_dim);
CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) {
RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0),
vit_model_dim);
});
}
// Apply head embedding into image_tokens of size of the LLM kModelDim.
MatMulStatic(activations.x, weights.vit_img_head_kernel,
weights.vit_img_head_bias.PackedScale1(), *activations.env,
image_tokens);
CallMatMul(activations.x, weights.vit_img_head_kernel,
weights.vit_img_head_bias.PackedScale1(), *activations.env,
image_tokens);
}
// Generates one token for each query. `queries_token` is the previous token
// from each query, and `queries_pos` are their position in the sequence.
template <typename T>
HWY_NOINLINE void Transformer(
static HWY_NOINLINE void Transformer(
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
const QueriesPos& queries_prefix_end, const ModelWeightsPtrs<T>& weights,
Activations& activations, const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches, const LayersOutputFunc& layers_output,
const QueriesPos& queries_prefix_end, const ModelConfig& config,
const ModelWeightsPtrs& weights, Activations& activations,
const hwy::Divisor& div_seq_len, const KVCaches& kv_caches,
const LayersOutputFunc& layers_output,
const ActivationsObserverFunc& activations_observer) {
const size_t num_queries = queries_token.size();
HWY_DASSERT(queries_pos.size() == num_queries);
@ -1215,15 +1205,13 @@ HWY_NOINLINE void Transformer(
}
}
size_t image_token_position = 0;
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
/*pos_in_prompt=*/0, weights, activations.x,
/*image_tokens=*/nullptr, image_token_position);
/*pos_in_prompt=*/0, config, weights, activations.x);
}
for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) {
const LayerWeightsPtrs<T>* layer_weights = weights.GetLayer(layer);
const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer);
TransformerLayer(queries_pos, queries_prefix_end, /*num_tokens=*/1, layer,
layer_weights, activations, div_seq_len, kv_caches);
@ -1302,25 +1290,25 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
};
}
template <typename T>
// Runs one decode step for all the queries in the batch. Returns true if all
// queries are at <end_of_sentence>.
bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const size_t query_idx_start, const KVCaches& kv_caches,
const QueriesPos& queries_prefix_end,
const hwy::Divisor div_seq_len, const size_t vocab_size,
const SampleFunc& sample_token, Activations& activations,
TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
TimingInfo& timing_info,
const QueriesMutablePos& queries_mutable_pos) {
static bool DecodeStepT(const ModelConfig& config,
const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const size_t query_idx_start, const KVCaches& kv_caches,
const QueriesPos& queries_prefix_end,
const hwy::Divisor div_seq_len, const size_t vocab_size,
const SampleFunc& sample_token,
Activations& activations, TokenStreamer& token_streamer,
std::vector<int>& gen_tokens, TimingInfo& timing_info,
const QueriesMutablePos& queries_mutable_pos) {
const size_t num_queries = queries_prompt.size();
// Decode generates one token per query and increments
// queries_mutable_pos.
Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
queries_prefix_end, weights, activations, div_seq_len, kv_caches,
runtime_config.layers_output,
queries_prefix_end, config, weights, activations, div_seq_len,
kv_caches, runtime_config.layers_output,
runtime_config.activations_observer);
// queries_pos are incremented by Transformer.
@ -1329,13 +1317,13 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
{
PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations.
MatMulStatic(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, *activations.env, activations.logits);
CallMatMul(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, *activations.env, activations.logits);
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Row(query_idx);
MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
MaybeLogitsSoftCap(config.final_cap, logits, vocab_size);
const TokenAndProb tp = sample_token(logits, vocab_size);
timing_info.NotifyGenerated();
@ -1359,15 +1347,16 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
// `StreamFunc` gets the global query index, not relative to the batch.
//
// `kv_caches` is for the batch, size must match `queries_prompt`.
template <typename T>
void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
Activations& activations, const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos_in,
const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const KVCaches& kv_caches,
TimingInfo& timing_info) {
static void GenerateT(const ModelConfig& config,
const ModelWeightsPtrs& weights, Activations& activations,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos_in,
const QueriesPos& queries_prefix_end,
const size_t query_idx_start, const KVCaches& kv_caches,
TimingInfo& timing_info) {
HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
// Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t i = 0; i < kv_caches.size(); ++i) {
if (queries_pos_in[i] == 0) {
@ -1384,7 +1373,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
// Sanity check: prompts should not be empty, nor start with EOS.
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
const PromptTokens& prompt = queries_prompt[query_idx];
HWY_ASSERT(prompt.size() != 0 && prompt[0] != model.Config().eos_id);
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
}
const size_t num_queries = queries_prompt.size();
@ -1395,7 +1384,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
size_t max_prompt_size = MaxQueryLength(queries_prompt);
size_t max_generated_tokens = runtime_config.max_generated_tokens;
RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
RangeChecks(config, max_generated_tokens, max_prompt_size);
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
// Prefill stops before min_prompt_size - 1 because the last prompt
@ -1403,8 +1392,8 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
timing_info.prefill_start = hwy::platform::Now();
// Note that Prefill calls activations.SetBatchSize, so we reset it below.
Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
query_idx_start, weights, activations, runtime_config, div_seq_len,
kv_caches);
query_idx_start, config, weights, activations, runtime_config,
div_seq_len, kv_caches);
// Compute the number of tokens that were prefilled and notify timing_info.
size_t prefilled_tokens = 0;
for (size_t qi = 0; qi < num_queries; ++qi) {
@ -1419,7 +1408,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
std::vector<int> gen_tokens(num_queries);
// Stream the last prompt token from each query and fill gen_tokens.
TokenStreamer token_streamer(runtime_config, model.Config());
TokenStreamer token_streamer(runtime_config, config);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
size_t last_token_pos_in_prompt =
queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
@ -1430,53 +1419,49 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
}
{
const size_t vocab_size = model.Config().vocab_size;
const size_t vocab_size = config.vocab_size;
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
bool all_queries_eos = DecodeStepT<T>(
model.Config(), weights, runtime_config, queries_prompt,
query_idx_start, kv_caches, queries_prefix_end, div_seq_len,
vocab_size, sample_token, activations, token_streamer, gen_tokens,
timing_info, queries_mutable_pos);
bool all_queries_eos = DecodeStepT(
config, weights, runtime_config, queries_prompt, query_idx_start,
kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token,
activations, token_streamer, gen_tokens, timing_info,
queries_mutable_pos);
if (all_queries_eos) break;
} // foreach token to generate
timing_info.NotifyGenerateDone();
}
}
template <typename T>
void GenerateSingleT(const ModelStore& model,
const ModelWeightsPtrs<T>& weights,
const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, MatMulEnv* env,
TimingInfo& timing_info) {
static HWY_MAYBE_UNUSED void GenerateSingleT(
const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
TimingInfo& timing_info) {
constexpr size_t kNumQueries = 1;
const size_t qbatch_start = 0;
const size_t max_batch_size =
HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
// TODO: move into Gemma?
Activations activations(model.Config(), max_batch_size, env);
Activations activations(config, max_batch_size, env);
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
QueriesPos queries_pos(&pos, kNumQueries);
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
const KVCaches kv_caches{&kv_cache, kNumQueries};
GenerateT<T>(model, weights, activations, runtime_config, queries_prompt,
queries_pos, queries_prefix_end, qbatch_start, kv_caches,
timing_info);
GenerateT(config, weights, activations, runtime_config, queries_prompt,
queries_pos, queries_prefix_end, qbatch_start, kv_caches,
timing_info);
}
template <typename T>
void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, MatMulEnv* env,
TimingInfo& timing_info) {
static HWY_MAYBE_UNUSED void GenerateBatchT(
const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
MatMulEnv* env, TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() >= num_queries);
@ -1484,7 +1469,7 @@ void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
const size_t max_batch_size =
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
Activations activations(model.Config(), max_batch_size, env);
Activations activations(config, max_batch_size, env);
for (size_t qbatch_start = 0; qbatch_start < num_queries;
qbatch_start += max_qbatch_size) {
@ -1497,68 +1482,31 @@ void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT<T>(model, weights, activations, runtime_config, qbatch_prompts,
qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
timing_info);
GenerateT(config, weights, activations, runtime_config, qbatch_prompts,
qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv,
timing_info);
}
}
template <typename T>
void GenerateImageTokensT(const ModelStore& model,
const ModelWeightsPtrs<T>& weights,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
MatMulEnv* env) {
if (model.Config().vit_config.layer_configs.empty()) {
static HWY_MAYBE_UNUSED void GenerateImageTokensT(
const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, MatMulEnv* env) {
if (config.vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = GetVitConfig(model.Config());
ModelConfig vit_config = GetVitConfig(config);
prefill_runtime_config.prefill_tbatch_size =
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, vit_config.seq_len, env);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(weights, prefill_runtime_config, image, image_tokens,
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
#if HWY_ONCE
// These are extern functions defined by instantiations/*.cc, which include this
// 'header' after defining `GEMMA_TYPE`.
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT<GEMMA_TYPE>)
(model, weights, runtime_config, prompt, pos, prefix_end, kv_cache, env,
timing_info);
}
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
MatMulEnv* env, TimingInfo& timing_info) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT<GEMMA_TYPE>)
(model, weights, runtime_config, queries_prompt, queries_pos,
queries_prefix_end, kv_caches, env, timing_info);
}
void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, MatMulEnv* env) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_TYPE>)
(model, weights, runtime_config, image, image_tokens, env);
}
#endif // HWY_ONCE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
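
One behavioral change in the hunks above that is easy to miss: EmbedMMToken now returns the updated image-token position instead of mutating a by-reference parameter, and Prefill threads that value through its token loop. A simplified sketch of that calling pattern, with hypothetical EmbedOne/EmbedBatch stand-ins:

#include <cstddef>
#include <vector>

// Hypothetical stand-in for EmbedMMToken's new contract: embed one prompt
// token and return the (possibly advanced) image-token position.
static size_t EmbedOne(int token, size_t image_token_position,
                       size_t num_image_tokens) {
  if (token == -2 && image_token_position < num_image_tokens) {
    // ... copy image_tokens[image_token_position] into x (elided) ...
    return image_token_position + 1;  // consumed one image token
  }
  // ... embed a text token from the vocabulary table (elided) ...
  return image_token_position;        // unchanged
}

// Prefill-style loop: the position is threaded through by value, mirroring
// `image_token_position = EmbedMMToken(...)` in the hunk above.
static size_t EmbedBatch(const std::vector<int>& tokens,
                         size_t num_image_tokens) {
  size_t image_token_position = 0;
  for (const int token : tokens) {
    image_token_position =
        EmbedOne(token, image_token_position, num_image_tokens);
  }
  return image_token_position;
}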


@ -13,21 +13,33 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Defines Gemma member functions; the actual implementations are in
// gemma-inl.h, included from instantiations/*.cc.
// Defines Gemma member functions which dynamic-dispatch into the SIMD
// implementations in gemma-inl.h.
#include "gemma/gemma.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/gemma-inl.h"
#ifndef GEMMA_CC_ONCE
#define GEMMA_CC_ONCE
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <memory>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/types.h"
#include "gemma/configs.h"
#include "gemma/model_store.h"
#include "gemma/tokenizer.h"
@ -39,7 +51,13 @@
#include "util/threading_context.h"
#include "hwy/base.h"
#endif // GEMMA_CC_ONCE
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(GenerateSingleT);
HWY_EXPORT(GenerateBatchT);
HWY_EXPORT(GenerateImageTokensT);
// Internal init must run before I/O. This helper function takes care of that,
// plus calling `SetArgs`.
@ -52,12 +70,13 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
: env_(env),
reader_(new BlobReader(loader.weights)),
model_(*reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config().weight),
reader_(loader.weights),
model_(reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model) {
weights_.ReadFromBlobs(model_, *reader_, loader.map, env_.ctx.pools.Pool());
reader_.reset();
weights_.ReadFromBlobs(model_, reader_, loader.map, mat_owners_,
env.ctx.pools.Pool());
reader_.CloseFile();
}
Gemma::~Gemma() = default;
@ -70,42 +89,14 @@ void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
writer, env_.ctx.pools.Pool(), weights_path);
}
// There are >=3 types of the inference code. To reduce compile time,
// we shard them across multiple translation units in instantiations/*.cc.
// This declares the functions defined there. We use overloading because
// explicit instantiations are still too slow to compile.
// TODO: we want to move toward type-erasing, where we check the tensor type at
// each usage. Then we would have a single function, passing `WeightsOwner`
// instead of `WeightsPtrs<T>`.
#define GEMMA_DECLARE(WEIGHT_TYPE) \
extern void GenerateSingle( \
const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
const RuntimeConfig& runtime_config, const PromptTokens& prompt, \
size_t pos, size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, \
TimingInfo& timing_info); \
extern void GenerateBatch( \
const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \
extern void GenerateImageTokens( \
const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
const RuntimeConfig& runtime_config, const Image& image, \
ImageTokens& image_tokens, MatMulEnv* env);
GEMMA_DECLARE(float)
GEMMA_DECLARE(BF16)
GEMMA_DECLARE(NuqStream)
GEMMA_DECLARE(SfpStream)
void Gemma::Generate(const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, TimingInfo& timing_info) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
weights_.CallT([&](auto& weights) {
GenerateSingle(model_, *weights, runtime_config, prompt, pos, prefix_end,
kv_cache, &env_, timing_info);
});
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_,
runtime_config, prompt, pos, prefix_end,
kv_cache, &env_, timing_info);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
@ -127,11 +118,9 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
weights_.CallT([&](auto& weights) {
gcpp::GenerateBatch(model_, *weights, runtime_config, queries_prompt,
queries_pos, mutable_queries_prefix_end, kv_caches,
&env_, timing_info);
});
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
model_.Config(), weights_, runtime_config, queries_prompt, queries_pos,
mutable_queries_prefix_end, kv_caches, &env_, timing_info);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
@ -141,12 +130,11 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
ImageTokens& image_tokens) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
weights_.CallT([&](auto& weights) {
gcpp::GenerateImageTokens(model_, *weights, runtime_config, image,
image_tokens, &env_);
});
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
model_.Config(), weights_, runtime_config, image, image_tokens, &env_);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -18,7 +18,7 @@
#include <stdio.h>
#include <memory>
#include <vector>
// IWYU pragma: begin_exports
#include "gemma/activations.h"
@ -113,11 +113,9 @@ class Gemma {
// TODO: rename to Config()
const ModelConfig& GetModelConfig() const { return model_.Config(); }
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
const WeightsOwner& Weights() const { return weights_; }
const ModelWeightsPtrs& Weights() const { return weights_; }
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
// For tests.
WeightsOwner& MutableWeights() { return weights_; }
void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
// `pos` is the position in the KV cache. Users are responsible for
@ -154,9 +152,10 @@ class Gemma {
private:
MatMulEnv& env_;
std::unique_ptr<BlobReader> reader_; // null for second ctor
BlobReader reader_;
ModelStore model_;
WeightsOwner weights_;
std::vector<MatOwner> mat_owners_;
ModelWeightsPtrs weights_;
GemmaChatTemplate chat_template_;
};


@ -1,21 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/instantiations/bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_TYPE hwy::bfloat16_t
#include "gemma/gemma-inl.h"


@ -1,20 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/instantiations/f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_TYPE float
#include "gemma/gemma-inl.h"


@ -1,20 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/instantiations/nuq.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_TYPE NuqStream
#include "gemma/gemma-inl.h"


@ -1,20 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/instantiations/sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_TYPE SfpStream
#include "gemma/gemma-inl.h"


@ -372,6 +372,11 @@ bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
// `Compress()` output is always packed because it assumes a 1D array.
HWY_ASSERT(mat.IsPacked());
// Update fields. Name already matched, otherwise we would not find it.
// For MatPtr tensors, the type will be `kUnknown`. If it was a `MatPtrT`,
// ensure the type set via code matches the file.
HWY_ASSERT_M(
mat.GetType() == Type::kUnknown || mat.GetType() == file_mat->GetType(),
mat.Name());
mat.SetType(file_mat->GetType());
if (scales_.empty()) {
// `file_mat->Scale()` is either read from file, or we have pre-2025 format


@ -1,6 +1,7 @@
#include "gemma/tensor_info.h"
#include <stddef.h>
#include <stdint.h>
#include <string>


@ -21,13 +21,14 @@ TEST(TensorInfoRegistryTest, Find) {
config.Specifier().c_str());
const TensorInfoRegistry tensors(config);
// Each tensor in the model should be known/found.
ModelWeightsPtrs<SfpStream> weights(config);
ModelWeightsPtrs weights(config);
weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) {
const TensorInfo* info = tensors.Find(t.mat.Name());
HWY_ASSERT_M(info, t.mat.Name());
// Test that the `MatPtr` can be constructed from the TensorInfo,
// and that the dimensions match.
MatPtrT<SfpStream> mat_ptr(t.mat.Name(), tensors);
const MatPtr mat_ptr(t.mat.Name(), Type::kUnknown,
ExtentsFromInfo(tensors.Find(t.mat.Name())));
EXPECT_STREQ(t.mat.Name(), mat_ptr.Name()) << t.mat.Name();
EXPECT_EQ(t.mat.Rows(), mat_ptr.Rows()) << t.mat.Name();
EXPECT_EQ(t.mat.Cols(), mat_ptr.Cols()) << t.mat.Name();


@ -20,7 +20,7 @@
#include <stdio.h>
#include <stdlib.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <vector>
@ -42,15 +42,127 @@
namespace gcpp {
static void InitAttWeightsNUQ(const LayerConfig& layer_config,
MatPtrT<NuqStream>& attn_vec_einsum_w,
MatPtrT<NuqStream>& att_weights,
std::vector<MatOwner>& mat_owners) {
// Copies att_weights from `attn_vec_einsum_w`.
void LayerWeightsPtrs::InitAttWeights(std::vector<MatOwner>& mat_owners) {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// Files must have one or the other.
HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr());
// Done if we already read the transposed tensor.
if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;
// NUQ is handled by a specialization in weights.cc.
HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);
const size_t model_dim = layer_config.model_dim;
const size_t heads = layer_config.heads;
const size_t qkv_dim = layer_config.qkv_dim;
// Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim].
att_weights.SetType(attn_vec_einsum_w.GetType());
HWY_ASSERT(att_weights.Rows() == model_dim);
HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.push_back(MatOwner());
mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd);
}
const size_t T_bytes = att_weights.ElementBytes();
for (size_t m = 0; m < model_dim; ++m) {
uint8_t* HWY_RESTRICT out_row = att_weights.RowBytes(m);
for (size_t h = 0; h < heads; ++h) {
hwy::CopyBytes(attn_vec_einsum_w.RowBytes(h * model_dim + m),
out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes);
}
}
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
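To make the reshape concrete: row `m` of `att_weights` is the concatenation, over heads `h`, of row `h * model_dim + m` of `attn_vec_einsum_w`. A standalone toy sketch of the same index mapping with plain float arrays (illustrative only, not library code):

#include <cstddef>
#include <cstring>
#include <vector>

// Reshapes [heads, model_dim, qkv_dim] into [model_dim, heads * qkv_dim],
// mirroring the copy loop in InitAttWeights (no padding/stride for brevity).
int main() {
  const size_t heads = 2, model_dim = 3, qkv_dim = 4;
  std::vector<float> in(heads * model_dim * qkv_dim);
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i);
  std::vector<float> out(model_dim * heads * qkv_dim);
  for (size_t m = 0; m < model_dim; ++m) {
    for (size_t h = 0; h < heads; ++h) {
      std::memcpy(&out[m * (heads * qkv_dim) + h * qkv_dim],
                  &in[(h * model_dim + m) * qkv_dim], qkv_dim * sizeof(float));
    }
  }
  return 0;
}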
// For FFN. Fast, only updates pointers.
void LayerWeightsPtrs::SplitW1() {
// Used for Gemma and Griffin layers; FFWVit uses different tensors.
if (layer_config.type == LayerAttentionType::kVit) return;
// Files have both or neither of w1 and w2.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
// w is mutually exclusive with w1 and w2 in the file.
HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that they are not
// necessarily the same type.
if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
// Cols are the model_dim but we don't have ModelConfig here.
HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols());
HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols());
const size_t stride = gating_einsum_w.Stride();
gating_einsum_w1.SetPtr(gating_einsum_w.RowBytes(0), stride);
gating_einsum_w2.SetPtr(gating_einsum_w.RowBytes(ff_hidden_dim), stride);
gating_einsum_w1.SetType(gating_einsum_w.GetType());
gating_einsum_w2.SetType(gating_einsum_w.GetType());
gating_einsum_w1.SetScale(gating_einsum_w.Scale());
gating_einsum_w2.SetScale(gating_einsum_w.Scale());
gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
}
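The split above copies no data: w1 and w2 simply point at row 0 and row `ff_hidden_dim` of the stacked buffer and share its stride. A toy sketch of the same idea with raw floats (illustrative names only):

#include <cstddef>

int main() {
  constexpr size_t kRows = 8, kStride = 6;      // padded row stride >= cols
  static float stacked[kRows * kStride] = {};   // stands in for gating_einsum_w
  float* w1 = &stacked[0];                      // rows [0, kRows / 2)
  float* w2 = &stacked[(kRows / 2) * kStride];  // rows [kRows / 2, kRows)
  (void)w1;
  (void)w2;
  return 0;
}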
// For attention, which might not have a w2. Fast, only updates pointers.
void LayerWeightsPtrs::SplitAttW1() {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// w is mutually exclusive with w1 in the file.
HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that w2 does not exist for
// MHA, and otherwise might not be the same type.
if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
// Cols are the model_dim but we don't have ModelConfig here.
HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols());
HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols());
const size_t stride = qkv_einsum_w.Stride();
qkv_einsum_w1.SetPtr(qkv_einsum_w.RowBytes(0), stride);
qkv_einsum_w2.SetPtr(qkv_einsum_w.RowBytes(w1_rows), stride);
qkv_einsum_w1.SetType(qkv_einsum_w.GetType());
qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
}
// Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
void LayerWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners) {
// TODO(janwas): handle NUQ
InitAttWeights(mat_owners);
SplitW1();
SplitAttW1();
}
static void HWY_MAYBE_UNUSED InitAttWeightsNUQ(
const LayerConfig& layer_config, MatPtrT<NuqStream>& attn_vec_einsum_w,
MatPtrT<NuqStream>& att_weights, std::vector<MatOwner>& mat_owners) {
if (!attn_vec_einsum_w.HasPtr()) return;
HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ);
HWY_ASSERT(att_weights.HasPtr());
HWY_ASSERT(att_weights.GetType() == Type::kNUQ);
att_weights.SetType(Type::kNUQ);
const size_t model_dim = layer_config.model_dim;
const size_t heads = layer_config.heads;
@ -85,14 +197,53 @@ static void InitAttWeightsNUQ(const LayerConfig& layer_config,
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
static void SplitW1NUQ(const LayerConfig& layer_config) {
static void HWY_MAYBE_UNUSED SplitW1NUQ(const LayerConfig& layer_config) {
// TODO(janwas): implement.
}
template <>
void LayerWeightsPtrs<NuqStream>::Fixup(std::vector<MatOwner>& mat_owners) {
InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners);
SplitW1NUQ(layer_config);
// Zero-initializes only the allocated tensors in `*this`.
void ModelWeightsPtrs::ZeroInit() {
ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
gcpp::ZeroInit(t.mat);
});
}
// Copies only the allocated tensors in `*this` from tensors in `other`.
void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) {
ForEachTensor(const_cast<ModelWeightsPtrs*>(&other), nullptr,
[](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
CopyMat(*t.other_mat1, t.mat);
});
}
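`ZeroInit` and `CopyFrom` show the general pattern: visit every tensor and act only on those with data. The same pattern supports ad-hoc queries; a hedged sketch (not existing code) that counts allocated tensors:

#include <cstddef>

#include "gemma/weights.h"

// Counts how many tensors in `weights` currently have data attached.
size_t CountAllocated(gcpp::ModelWeightsPtrs& weights) {
  size_t count = 0;
  weights.ForEachTensor(nullptr, nullptr, [&](const gcpp::TensorArgs& t) {
    if (t.mat.HasPtr()) ++count;
  });
  return count;
}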
// For reshaping file tensors to the shape expected by the code. This would
// ideally already happen in the importer. Called by ReadFromBlobs.
void ModelWeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
hwy::ThreadPool& pool) {
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Fixup(mat_owners);
});
pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
VitLayer(layer)->Fixup(mat_owners);
});
}
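`Fixup` fans out over layers with `hwy::ThreadPool::Run`, which invokes the closure once per index in [begin, end) together with the worker thread id. A minimal standalone sketch of that API:

#include <cstddef>
#include <cstdint>
#include <cstdio>

#include "hwy/contrib/thread_pool/thread_pool.h"

int main() {
  hwy::ThreadPool pool(4);  // up to 4 worker threads
  pool.Run(0, 8, [](uint64_t task, size_t thread) {
    std::printf("task %d ran on thread %d\n", static_cast<int>(task),
                static_cast<int>(thread));
  });
  return 0;
}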
std::vector<uint32_t> ModelWeightsPtrs::AddTensorDataToWriter(
BlobWriter& writer) const {
std::vector<uint32_t> serialized_mat_ptrs;
// ForEachTensor is non-const but the lambda does not modify *this.
const_cast<ModelWeightsPtrs*>(this)->ForEachTensor(
nullptr, nullptr, [&](const TensorArgs& t) {
if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes());
t.mat.AppendTo(serialized_mat_ptrs);
});
return serialized_mat_ptrs;
}
struct TensorToRead {
@ -144,7 +295,7 @@ static Mode ChooseMode(uint64_t file_bytes, Tristate map) {
return Mode::kRead;
}
MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
static MapPtr MapFileOrNull(File& file, uint64_t file_bytes) {
const Allocator& allocator = ThreadingContext::Get().allocator;
if (file_bytes % allocator.BasePageBytes() == 0) {
MapPtr mapped = file.Map();
@ -175,8 +326,8 @@ static void MapAll(const std::vector<TensorToRead>& tensors,
}
}
std::vector<IOBatch> MakeBatches(const std::vector<TensorToRead>& tensors,
const uint64_t file_bytes) {
static std::vector<IOBatch> MakeBatches(
const std::vector<TensorToRead>& tensors, const uint64_t file_bytes) {
PROFILER_ZONE("Startup.Weights.MakeBatches");
// Batches must be contiguous but blobs are padded, hence at least one
// batch per tensor, and more when tensor rows exceed the batch size.
@ -254,72 +405,34 @@ static void MapOrReadAll(const std::vector<TensorToRead>& tensors,
ReadBatches(reader, batches, pool);
}
void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
Tristate map, hwy::ThreadPool& pool) {
void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model,
BlobReader& reader, Tristate map,
std::vector<MatOwner>& mat_owners,
hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from.
std::vector<TensorToRead> tensors;
AllocatePointer(model.Config());
// Enumerate all weights (negligible cost).
CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
const MatPadding padding = (t.flags & TensorArgs::kPacked)
? MatPadding::kPacked
: MatPadding::kOdd;
size_t key_idx;
if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
tensors.push_back({.mat = &t.mat,
.range = reader.Range(key_idx),
.padding = padding});
return;
}
if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found.
HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
});
ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
const MatPadding padding = (t.flags & TensorArgs::kPacked)
? MatPadding::kPacked
: MatPadding::kOdd;
size_t key_idx;
if (model.FindAndUpdateMatPtr(t.mat, key_idx)) {
tensors.push_back(
{.mat = &t.mat, .range = reader.Range(key_idx), .padding = padding});
return;
}
if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found.
HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
});
MapOrReadAll(tensors, reader, map, mat_owners_, pool);
MapOrReadAll(tensors, reader, map, mat_owners, pool);
Fixup(pool);
}
// Allocates `*_weights_`, but not yet the tensors inside. This is split out
// of `CallT` because that is const, hence it would pass a const& of the
// `unique_ptr` to its lambda, but we want to reset the pointer.
void WeightsOwner::AllocatePointer(const ModelConfig& config) {
switch (weight_type_) {
case Type::kSFP:
sfp_weights_.reset(new ModelWeightsPtrs<SfpStream>(config));
break;
case Type::kNUQ:
nuq_weights_.reset(new ModelWeightsPtrs<NuqStream>(config));
break;
case Type::kBF16:
bf16_weights_.reset(new ModelWeightsPtrs<BF16>(config));
break;
default:
HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_));
{
PROFILER_ZONE("Startup.Fixup");
Fixup(mat_owners, pool);
}
}
void WeightsOwner::Fixup(hwy::ThreadPool& pool) {
PROFILER_ZONE("Startup.Fixup");
CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); });
}
std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
BlobWriter& writer) const {
std::vector<uint32_t> serialized_mat_ptrs;
CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return;
HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name());
writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes());
t.mat.AppendTo(serialized_mat_ptrs);
});
});
return serialized_mat_ptrs;
}
} // namespace gcpp


@ -19,18 +19,14 @@
#include <stddef.h>
#include <stdint.h>
#include <complex>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <vector>
#include "compression/types.h" // IsF32
#include "compression/types.h"
#include "gemma/configs.h" // ModelConfig
#include "gemma/model_store.h" // ModelStore
#include "gemma/tensor_info.h" // TensorInfoRegistry
#include "io/blob_store.h" // BlobWriter
#include "ops/matmul.h" // MatMulEnv
#include "util/mat.h" // MatPtr
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -73,142 +69,141 @@ struct TensorArgs {
TensorArgs(mat, other1 ? &other1->mat : nullptr, \
other2 ? &other2->mat : nullptr, TensorArgs::flag)
// Per-layer weight metadata and pointers. The tensor data is owned by
// `WeightsOwner`. Note that this class could be type-erased: member functions
// do not actually use the `Weight` template argument. See `WeightsPtrs`.
// `TensorInfoRegistry` (constructed from `ModelConfig`) is the source of truth
// for all tensor shapes.
template <class Weight>
struct LayerWeightsPtrs {
static inline std::string Concat(const char* base_name,
const std::string& suffix) {
return std::string(base_name) + suffix;
// Finds tensors by name in `TensorInfoRegistry` (constructed from
// `ModelConfig`) and constructs `MatPtr` metadata with those shapes.
class MatFinder {
public:
MatFinder(const std::string& suffix, const TensorInfoRegistry& tensors)
: suffix_(suffix), tensors_(tensors) {}
// Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`.
MatPtr operator()(const std::string& base_name) const {
const std::string name = std::string(base_name) + suffix_;
return MatPtr(name.c_str(), Type::kUnknown,
ExtentsFromInfo(tensors_.Find(name)));
}
private:
const std::string suffix_;
const TensorInfoRegistry& tensors_;
};
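A short usage sketch of `MatFinder` outside a constructor; it assumes a valid `ModelConfig` and uses the `LayerSuffix` helper and `gemma/weights.h` include path as elsewhere in this change:

#include "gemma/configs.h"
#include "gemma/tensor_info.h"
#include "gemma/weights.h"
#include "util/mat.h"

// Builds metadata for layer 0's "qkv_ein" tensor without allocating data.
gcpp::MatPtr MakeQkvMeta(const gcpp::ModelConfig& config) {
  const gcpp::TensorInfoRegistry tensors(config);
  const gcpp::MatFinder finder(gcpp::LayerSuffix(0), tensors);
  return finder("qkv_ein");  // Type stays kUnknown until read or copied.
}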
// Per-layer weight metadata and pointers. The tensor data is owned by
// the `MatOwner`s passed to `ModelWeightsPtrs::ReadFromBlobs`.
struct LayerWeightsPtrs {
// Initializes tensor metadata without allocating.
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
const TensorInfoRegistry& tensors)
: suffix_(LayerSuffix(layer_idx)),
qkv_einsum_w(Concat("qkv_ein", suffix_), tensors),
qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors),
qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors),
attention_output_biases(Concat("attn_ob", suffix_), tensors),
griffin(
{.linear_x_w = {Concat("gr_lin_x_w", suffix_), tensors},
.linear_x_biases = {Concat("gr_lin_x_b", suffix_), tensors},
.linear_y_w = {Concat("gr_lin_y_w", suffix_), tensors},
.linear_y_biases = {Concat("gr_lin_y_b", suffix_), tensors},
.linear_out_w = {Concat("gr_lin_out_w", suffix_), tensors},
.linear_out_biases = {Concat("gr_lin_out_b", suffix_), tensors},
.conv_w = {Concat("gr_conv_w", suffix_), tensors},
.conv_biases = {Concat("gr_conv_b", suffix_), tensors},
.gate_w = {Concat("gr_gate_w", suffix_), tensors},
.gate_biases = {Concat("gr_gate_b", suffix_), tensors},
.a = {Concat("gr_a", suffix_), tensors}}),
: finder_(LayerSuffix(layer_idx), tensors),
qkv_einsum_w(finder_("qkv_ein")),
qkv_einsum_w1(finder_("qkv1_w")),
qkv_einsum_w2(finder_("qkv2_w")),
attention_output_biases(finder_("attn_ob")),
griffin({.linear_x_w = finder_("gr_lin_x_w"),
.linear_x_biases = finder_("gr_lin_x_b"),
.linear_y_w = finder_("gr_lin_y_w"),
.linear_y_biases = finder_("gr_lin_y_b"),
.linear_out_w = finder_("gr_lin_out_w"),
.linear_out_biases = finder_("gr_lin_out_b"),
.conv_w = finder_("gr_conv_w"),
.conv_biases = finder_("gr_conv_b"),
.gate_w = finder_("gr_gate_w"),
.gate_biases = finder_("gr_gate_b"),
.a = finder_("gr_a")}),
// MultiHeadDotProductAttention.
vit({.attn_out_w = {Concat("attn_out_w", suffix_), tensors},
.attn_out_b = {Concat("attn_out_b", suffix_), tensors},
.qkv_einsum_w = {Concat("qkv_ein_w", suffix_), tensors},
.qkv_einsum_b = {Concat("qkv_ein_b", suffix_), tensors},
.linear_0_w = {Concat("linear_0_w", suffix_), tensors},
.linear_0_b = {Concat("linear_0_b", suffix_), tensors},
.linear_1_w = {Concat("linear_1_w", suffix_), tensors},
.linear_1_b = {Concat("linear_1_b", suffix_), tensors},
.layer_norm_0_bias = {Concat("ln_0_bias", suffix_), tensors},
.layer_norm_0_scale = {Concat("ln_0_scale", suffix_), tensors},
.layer_norm_1_bias = {Concat("ln_1_bias", suffix_), tensors},
.layer_norm_1_scale = {Concat("ln_1_scale", suffix_), tensors}}),
gating_einsum_w(Concat("gating_ein", suffix_), tensors),
gating_einsum_w1(Concat("gating1_w", suffix_), tensors),
gating_einsum_w2(Concat("gating2_w", suffix_), tensors),
linear_w(Concat("linear_w", suffix_), tensors),
pre_attention_norm_scale(Concat("pre_att_ns", suffix_), tensors),
pre_ffw_norm_scale(Concat("pre_ff_ns", suffix_), tensors),
post_attention_norm_scale(Concat("post_att_ns", suffix_), tensors),
post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors),
ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors),
ffw_output_biases(Concat("ffw_out_b", suffix_), tensors),
vit({.attn_out_w = finder_("attn_out_w"),
.attn_out_b = finder_("attn_out_b"),
.qkv_einsum_w = finder_("qkv_ein_w"),
.qkv_einsum_b = finder_("qkv_ein_b"),
.linear_0_w = finder_("linear_0_w"),
.linear_0_b = finder_("linear_0_b"),
.linear_1_w = finder_("linear_1_w"),
.linear_1_b = finder_("linear_1_b"),
.layer_norm_0_bias = finder_("ln_0_bias"),
.layer_norm_0_scale = finder_("ln_0_scale"),
.layer_norm_1_bias = finder_("ln_1_bias"),
.layer_norm_1_scale = finder_("ln_1_scale")}),
gating_einsum_w(finder_("gating_ein")),
gating_einsum_w1(finder_("gating1_w")),
gating_einsum_w2(finder_("gating2_w")),
linear_w(finder_("linear_w")),
pre_attention_norm_scale(finder_("pre_att_ns")),
pre_ffw_norm_scale(finder_("pre_ff_ns")),
post_attention_norm_scale(finder_("post_att_ns")),
post_ffw_norm_scale(finder_("post_ff_ns")),
ffw_gating_biases(finder_("ffw_gat_b")),
ffw_output_biases(finder_("ffw_out_b")),
attn_vec_einsum_w(Concat("att_ein", suffix_), tensors),
att_weights(Concat("att_w", suffix_), tensors),
attn_vec_einsum_w(finder_("att_ein")),
att_weights(finder_("att_w")),
key_norm_scale(Concat("key_norm", suffix_), tensors),
query_norm_scale(Concat("query_norm", suffix_), tensors),
key_norm_scale(finder_("key_norm")),
query_norm_scale(finder_("query_norm")),
layer_config(config) {
}
~LayerWeightsPtrs() = default;
const std::string suffix_;
// If weights are f32, also f32; otherwise at least bf16. Useful for ops that
// do not yet support smaller compressed types, or require at least bf16. When
// weights are f32, we also want such tensors to be f32.
// If weights are complex, this is also complex.
using WeightF32OrBF16 =
hwy::If<hwy::IsSame<Weight, std::complex<double>>(), std::complex<double>,
hwy::If<hwy::IsSame<Weight, double>(), double,
hwy::If<IsF32<Weight>(), float, BF16>>>;
const MatFinder finder_;
// Files either have qkv_einsum_w with 2 stacked matrices or separate
// w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
MatPtrT<Weight> qkv_einsum_w;
MatPtrT<Weight> qkv_einsum_w1;
MatPtrT<Weight> qkv_einsum_w2;
MatPtr qkv_einsum_w;
MatPtr qkv_einsum_w1;
MatPtr qkv_einsum_w2;
MatPtrT<float> attention_output_biases;
struct {
MatPtrT<Weight> linear_x_w;
MatPtr linear_x_w;
MatPtrT<float> linear_x_biases;
MatPtrT<Weight> linear_y_w;
MatPtr linear_y_w;
MatPtrT<float> linear_y_biases;
MatPtrT<Weight> linear_out_w;
MatPtr linear_out_w;
MatPtrT<float> linear_out_biases;
MatPtrT<float> conv_w;
MatPtrT<float> conv_biases;
MatPtrT<Weight> gate_w;
MatPtr gate_w;
MatPtrT<float> gate_biases;
MatPtrT<float> a;
} griffin;
struct {
// MultiHeadDotProductAttention.
MatPtrT<WeightF32OrBF16> attn_out_w;
MatPtr attn_out_w; // at least BF16.
MatPtrT<float> attn_out_b;
MatPtrT<WeightF32OrBF16> qkv_einsum_w;
MatPtr qkv_einsum_w; // at least BF16.
MatPtrT<float> qkv_einsum_b;
// MlpBlock.
MatPtrT<WeightF32OrBF16> linear_0_w;
MatPtr linear_0_w; // at least BF16.
MatPtrT<float> linear_0_b;
MatPtrT<WeightF32OrBF16> linear_1_w;
MatPtr linear_1_w; // at least BF16.
MatPtrT<float> linear_1_b;
// LayerNorm.
MatPtrT<WeightF32OrBF16> layer_norm_0_bias;
MatPtrT<WeightF32OrBF16> layer_norm_0_scale;
MatPtrT<WeightF32OrBF16> layer_norm_1_bias;
MatPtrT<WeightF32OrBF16> layer_norm_1_scale;
MatPtr layer_norm_0_bias; // at least BF16.
MatPtr layer_norm_0_scale; // at least BF16.
MatPtr layer_norm_1_bias; // at least BF16.
MatPtr layer_norm_1_scale; // at least BF16.
} vit;
// Files either have gating_einsum_w with 2 stacked matrices or separate
// w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h.
MatPtrT<Weight> gating_einsum_w;
MatPtrT<Weight> gating_einsum_w1;
MatPtrT<Weight> gating_einsum_w2;
MatPtrT<Weight> linear_w;
// > W8 is likely helpful.
MatPtrT<WeightF32OrBF16> pre_attention_norm_scale;
MatPtrT<WeightF32OrBF16> pre_ffw_norm_scale;
MatPtrT<WeightF32OrBF16> post_attention_norm_scale;
MatPtrT<WeightF32OrBF16> post_ffw_norm_scale;
// w1/w2 tensors. `Fixup` ensures w1/w2 are ready for use by gemma-inl.h.
MatPtr gating_einsum_w;
MatPtr gating_einsum_w1;
MatPtr gating_einsum_w2;
MatPtr linear_w;
MatPtr pre_attention_norm_scale; // at least BF16.
MatPtr pre_ffw_norm_scale; // at least BF16.
MatPtr post_attention_norm_scale; // at least BF16.
MatPtr post_ffw_norm_scale; // at least BF16.
MatPtrT<float> ffw_gating_biases;
MatPtrT<float> ffw_output_biases;
MatPtrT<Weight> attn_vec_einsum_w; // Use att_weights instead of this.
MatPtrT<Weight> att_weights; // Use this instead of attn_vec_einsum_w.
MatPtr attn_vec_einsum_w; // Use att_weights instead of this.
MatPtr att_weights; // Use this instead of attn_vec_einsum_w.
MatPtrT<WeightF32OrBF16> key_norm_scale;
MatPtrT<WeightF32OrBF16> query_norm_scale;
MatPtr key_norm_scale; // at least BF16.
MatPtr query_norm_scale; // at least BF16.
const LayerConfig& layer_config;
@ -217,8 +212,8 @@ struct LayerWeightsPtrs {
// can also iterate over pairs or triples of tensors for `AdamUpdateMV`.
// Public because also called by `ModelWeightsPtrs`.
template <class Func>
void ForEachTensor(LayerWeightsPtrs<Weight>* other1,
LayerWeightsPtrs<Weight>* other2, Func func) {
void ForEachTensor(LayerWeightsPtrs* other1, LayerWeightsPtrs* other2,
Func func) {
if (layer_config.type == LayerAttentionType::kVit) {
// MHA.
func(TENSOR_ARGS(vit.attn_out_w, kMustRead));
@ -301,208 +296,98 @@ struct LayerWeightsPtrs {
// Must be called after reading weights via `ForEachTensor`.
// TODO: exporters should bake this into the weights already.
// WARNING: called from multiple threads; `mat_owners` requires a lock.
void Fixup(std::vector<MatOwner>& mat_owners) {
InitAttWeights(mat_owners);
SplitW1();
SplitAttW1();
}
void Fixup(std::vector<MatOwner>& mat_owners);
private:
// Copies att_weights from `attn_vec_einsum_w`.
void InitAttWeights(std::vector<MatOwner>& mat_owners) {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// Files must have one or the other.
HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr());
// Done if we already read the transposed tensor.
if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return;
// NUQ is handled by a specialization in weights.cc.
HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ);
const size_t model_dim = layer_config.model_dim;
const size_t heads = layer_config.heads;
const size_t qkv_dim = layer_config.qkv_dim;
// Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim].
HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType());
HWY_ASSERT(att_weights.Rows() == model_dim);
HWY_ASSERT(att_weights.Cols() == heads * qkv_dim);
HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim);
HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim);
{
static std::mutex m;
std::lock_guard<std::mutex> lock(m);
mat_owners.push_back(MatOwner());
mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd);
}
const size_t T_bytes = att_weights.ElementBytes();
for (size_t m = 0; m < model_dim; ++m) {
uint8_t* HWY_RESTRICT out_row =
reinterpret_cast<uint8_t*>(att_weights.Row(m));
for (size_t h = 0; h < heads; ++h) {
hwy::CopyBytes(attn_vec_einsum_w.Row(h * model_dim + m),
out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes);
}
}
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
void InitAttWeights(std::vector<MatOwner>& mat_owners);
// For FFN. Fast, only updates pointers.
void SplitW1() {
// Used for Gemma and Griffin layers; FFWVit uses different tensors.
if (layer_config.type == LayerAttentionType::kVit) return;
// Files have both or neither of w1 and w2.
HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr());
// w is mutually exclusive with w1 and w2 in the file.
HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that they are not
// necessarily the same type.
if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim);
HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim);
HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim);
// Cols are the model_dim but we don't have ModelConfig here.
HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols());
HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols());
const size_t stride = gating_einsum_w.Stride();
gating_einsum_w1.SetPtr(gating_einsum_w.Row(0), stride);
gating_einsum_w2.SetPtr(gating_einsum_w.Row(ff_hidden_dim), stride);
gating_einsum_w1.SetType(gating_einsum_w.GetType());
gating_einsum_w2.SetType(gating_einsum_w.GetType());
gating_einsum_w1.SetScale(gating_einsum_w.Scale());
gating_einsum_w2.SetScale(gating_einsum_w.Scale());
gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols());
}
void SplitW1();
// For attention, which might not have a w2. Fast, only updates pointers.
void SplitAttW1() {
// We only use this tensor for Gemma layers.
if (layer_config.type != LayerAttentionType::kGemma) return;
// w is mutually exclusive with w1 in the file.
HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr());
// Done if we already read split tensors. Note that w2 does not exist for
// MHA, and otherwise might not be the same type.
if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return;
const size_t w1_rows = layer_config.heads * layer_config.qkv_dim;
const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim;
HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows);
HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows);
HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows);
// Cols are the model_dim but we don't have ModelConfig here.
HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols());
HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols());
const size_t stride = qkv_einsum_w.Stride();
qkv_einsum_w1.SetPtr(qkv_einsum_w.Row(0), stride);
qkv_einsum_w2.SetPtr(qkv_einsum_w.Row(w1_rows), stride);
qkv_einsum_w1.SetType(qkv_einsum_w.GetType());
qkv_einsum_w2.SetType(qkv_einsum_w.GetType());
qkv_einsum_w1.SetScale(qkv_einsum_w.Scale());
qkv_einsum_w2.SetScale(qkv_einsum_w.Scale());
qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols());
}
void SplitAttW1();
};
// Holds layer-independent weight metadata and pointers plus per-layer
// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`. As with
// `LayerWeightsPtrs`, this class could be type-erased: member functions do not
// actually use the `Weight` template argument. The template does allow user
// code to dispatch only once. However, most tensors are large enough that
// dispatch at each usage would be feasible.
// `LayerWeightsPtrs`. The tensor data is owned by the `MatOwner`s passed to
// `ReadFromBlobs`.
// TODO: move `gemma-inl.h` toward dispatch at each usage.
// TODO: rename to WeightsPtrs.
template <class Weight>
struct ModelWeightsPtrs {
using WeightT = Weight;
explicit ModelWeightsPtrs(const ModelConfig& config)
: tensors_(config),
// No suffix, these are per-model.
embedder_input_embedding("c_embedding", tensors_),
final_norm_scale("c_final_norm", tensors_),
vit_encoder_norm_bias("enc_norm_bias", tensors_),
vit_encoder_norm_scale("enc_norm_scale", tensors_),
vit_img_embedding_bias("img_emb_bias", tensors_),
vit_img_embedding_kernel("img_emb_kernel", tensors_),
vit_img_pos_embedding("img_pos_emb", tensors_),
vit_img_head_bias("img_head_bias", tensors_),
vit_img_head_kernel("img_head_kernel", tensors_),
mm_embed_norm("mm_embed_norm", tensors_),
weights_config(config) {
c_layers.reserve(config.layer_configs.size());
for (size_t idx = 0; idx < config.layer_configs.size(); ++idx) {
const LayerConfig& layer_config = config.layer_configs[idx];
: config_(config),
tensors_(config_),
finder_("", tensors_), // no suffix because these are per-model.
embedder_input_embedding(finder_("c_embedding")),
final_norm_scale(finder_("c_final_norm")),
vit_encoder_norm_bias(finder_("enc_norm_bias")),
vit_encoder_norm_scale(finder_("enc_norm_scale")),
vit_img_embedding_bias(finder_("img_emb_bias")),
vit_img_embedding_kernel(finder_("img_emb_kernel")),
vit_img_pos_embedding(finder_("img_pos_emb")),
vit_img_head_bias(finder_("img_head_bias")),
vit_img_head_kernel(finder_("img_head_kernel")),
mm_embed_norm(finder_("mm_embed_norm")),
c_layers() {
c_layers.reserve(config_.layer_configs.size());
for (size_t idx = 0; idx < config_.layer_configs.size(); ++idx) {
const LayerConfig& layer_config = config_.layer_configs[idx];
c_layers.emplace_back(idx, layer_config, tensors_);
}
for (size_t idx = 0; idx < config.vit_config.layer_configs.size(); ++idx) {
const LayerConfig& layer_config = config.vit_config.layer_configs[idx];
for (size_t idx = 0; idx < config_.vit_config.layer_configs.size(); ++idx) {
const LayerConfig& layer_config = config_.vit_config.layer_configs[idx];
vit_layers.emplace_back(idx, layer_config, tensors_);
}
}
~ModelWeightsPtrs() = default;
// = F32 if weights are F32, else BF16.
using WeightF32OrBF16 = typename LayerWeightsPtrs<Weight>::WeightF32OrBF16;
// Passed to all `MatPtrT` initializers, hence must be initialized first.
const ModelConfig& config_;
// Passed to finder_, hence must be initialized first.
const TensorInfoRegistry tensors_;
const MatFinder finder_;
// TODO: switch to SFP?
MatPtrT<WeightF32OrBF16> embedder_input_embedding;
MatPtrT<WeightF32OrBF16> final_norm_scale;
MatPtr embedder_input_embedding;
MatPtr final_norm_scale; // at least BF16.
// Vit parts.
MatPtrT<WeightF32OrBF16> vit_encoder_norm_bias;
MatPtrT<WeightF32OrBF16> vit_encoder_norm_scale;
MatPtr vit_encoder_norm_bias; // at least BF16.
MatPtr vit_encoder_norm_scale; // at least BF16.
MatPtrT<float> vit_img_embedding_bias;
MatPtrT<WeightF32OrBF16> vit_img_embedding_kernel;
MatPtrT<float> vit_img_pos_embedding;
MatPtr vit_img_embedding_kernel; // at least BF16.
MatPtr vit_img_pos_embedding; // F32?
// The head maps from VitConfig::model_dim (Vit final layer) to
// model_dim (LLM input).
MatPtrT<float> vit_img_head_bias;
MatPtrT<WeightF32OrBF16> vit_img_head_kernel;
MatPtr vit_img_head_kernel; // at least BF16.
MatPtrT<WeightF32OrBF16> mm_embed_norm;
MatPtr mm_embed_norm; // at least BF16.
const ModelConfig& weights_config;
std::vector<LayerWeightsPtrs> c_layers;
std::vector<LayerWeightsPtrs> vit_layers;
std::vector<LayerWeightsPtrs<Weight>> c_layers;
std::vector<LayerWeightsPtrs<Weight>> vit_layers;
const LayerWeightsPtrs<Weight>* GetLayer(size_t layer) const {
const LayerWeightsPtrs* GetLayer(size_t layer) const {
return &c_layers[layer];
}
LayerWeightsPtrs<Weight>* GetLayer(size_t layer) { return &c_layers[layer]; }
const LayerWeightsPtrs<Weight>* VitLayer(size_t layer) const {
return &vit_layers[layer];
}
LayerWeightsPtrs<Weight>* VitLayer(size_t layer) {
LayerWeightsPtrs* GetLayer(size_t layer) { return &c_layers[layer]; }
const LayerWeightsPtrs* VitLayer(size_t layer) const {
return &vit_layers[layer];
}
LayerWeightsPtrs* VitLayer(size_t layer) { return &vit_layers[layer]; }
// `other1` and `other2` are usually null, but can be used to copy from
// another set of weights. Public because also called by tests.
template <class Func>
void ForEachTensor(ModelWeightsPtrs<Weight>* other1,
ModelWeightsPtrs<Weight>* other2, Func func) {
LayerWeightsPtrs<Weight>* other_layer1 = nullptr;
LayerWeightsPtrs<Weight>* other_layer2 = nullptr;
void ForEachTensor(ModelWeightsPtrs* other1, ModelWeightsPtrs* other2,
Func func) {
LayerWeightsPtrs* other_layer1 = nullptr;
LayerWeightsPtrs* other_layer2 = nullptr;
func(TENSOR_ARGS(embedder_input_embedding, kMustRead));
func(TENSOR_ARGS(final_norm_scale, kMustRead));
if (!weights_config.vit_config.layer_configs.empty()) { // Vit parts.
if (!config_.vit_config.layer_configs.empty()) { // Vit parts.
func(TENSOR_ARGS(vit_encoder_norm_bias, kMustRead));
func(TENSOR_ARGS(vit_encoder_norm_scale, kMustRead));
func(TENSOR_ARGS(vit_img_embedding_bias, kMustRead));
@ -511,7 +396,7 @@ struct ModelWeightsPtrs {
func(TENSOR_ARGS(vit_img_head_bias, kMustRead));
func(TENSOR_ARGS(vit_img_head_kernel, kMustRead));
if (weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
if (config_.wrapping == PromptWrapping::GEMMA_VLM) {
func(TENSOR_ARGS(mm_embed_norm, kMustRead));
}
}
@ -522,8 +407,7 @@ struct ModelWeightsPtrs {
GetLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func);
}
HWY_ASSERT(weights_config.vit_config.layer_configs.empty() ==
vit_layers.empty());
HWY_ASSERT(config_.vit_config.layer_configs.empty() == vit_layers.empty());
for (size_t layer_idx = 0; layer_idx < vit_layers.size(); ++layer_idx) {
HWY_ASSERT(vit_layers[layer_idx].layer_config.type ==
LayerAttentionType::kVit);
@ -534,87 +418,24 @@ struct ModelWeightsPtrs {
} // `ForEachTensor`
// Zero-initializes only the allocated tensors in `*this`.
void ZeroInit() {
ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
gcpp::ZeroInit(t.mat);
});
}
void ZeroInit();
// Copies only the allocated tensors in `*this` from tensors in `other`.
void CopyFrom(const ModelWeightsPtrs<Weight>& other) {
ForEachTensor(const_cast<ModelWeightsPtrs<Weight>*>(&other), nullptr,
[](const TensorArgs& t) {
if (!t.mat.HasPtr()) return;
HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr());
CopyMat(*t.other_mat1, t.mat);
});
}
// For reshaping file tensors to the shape expected by the code. This would
// ideally already happen in the importer. Must be called after reading and
// updating the attention weights.
void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool) {
pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
GetLayer(layer)->Fixup(mat_owners);
});
pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) {
VitLayer(layer)->Fixup(mat_owners);
});
}
}; // `WeightsPtrs`
#undef TENSOR_ARGS
// Type-erased facade for `WeightsPtrs<T>`, stored inside the non-template
// `Gemma`. Also owns the underlying memory.
class WeightsOwner {
public:
// `weight_type` is obtained from `ModelConfig` in `ModelStore`.
WeightsOwner(Type weight_type) : weight_type_(weight_type) {}
void CopyFrom(const ModelWeightsPtrs& other);
// Reads tensor data from `BlobStore` or aborts on error. `map` is a user
// override for whether to map blobs or read them.
void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map,
hwy::ThreadPool& pool);
// Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
// calls `ForEachTensor`.
template <class Func, typename... TArgs>
decltype(auto) CallT(const Func& func, TArgs&&... args) const {
if (HWY_LIKELY(weight_type_ == Type::kSFP)) {
return func(sfp_weights_, std::forward<TArgs>(args)...);
} else if (weight_type_ == Type::kNUQ) {
return func(nuq_weights_, std::forward<TArgs>(args)...);
} else if (weight_type_ == Type::kBF16) {
return func(bf16_weights_, std::forward<TArgs>(args)...);
}
return HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_));
}
// For writers:
std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
// Adds one blob for each tensor's data and returns all serialized MatPtr.
std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;
private:
Type weight_type_;
// Allocates `*_weights_`, but not yet the tensors inside. This is split out
// of `CallT` so that can be const.
void AllocatePointer(const ModelConfig& config);
// Called by `ReadFromBlobs`.
void Fixup(hwy::ThreadPool& pool);
// Only one is non-null, determined by `weight_type_`.
std::unique_ptr<ModelWeightsPtrs<BF16>> bf16_weights_;
std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;
// Owns the memory referenced by all `MatPtr`.
std::vector<MatOwner> mat_owners_;
};
// For reshaping file tensors to the shape expected by the code. This would
// ideally already happen in the importer. Called by ReadFromBlobs.
void Fixup(std::vector<MatOwner>& mat_owners, hwy::ThreadPool& pool);
}; // `ModelWeightsPtrs`
#undef TENSOR_ARGS
} // namespace gcpp
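With `WeightsOwner` gone, a caller owns the `MatOwner`s itself and uses `ModelWeightsPtrs` directly. A hedged sketch of the new loading flow (construction of `model`, `reader`, `pool` and `weights` is assumed to happen elsewhere, e.g. in `Gemma`):

#include <vector>

#include "gemma/model_store.h"
#include "gemma/weights.h"
#include "io/blob_store.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// Reads all tensors into `weights`; `mat_owners` must outlive any use of them.
void LoadWeights(const gcpp::ModelStore& model, gcpp::BlobReader& reader,
                 gcpp::Tristate map, hwy::ThreadPool& pool,
                 gcpp::ModelWeightsPtrs& weights,
                 std::vector<gcpp::MatOwner>& mat_owners) {
  weights.ReadFromBlobs(model, reader, map, mat_owners, pool);
}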


@ -59,6 +59,8 @@ class BlobReader {
const File& file() const { return *file_; }
uint64_t file_bytes() const { return file_bytes_; }
void CloseFile() { file_.reset(); }
const std::vector<std::string>& Keys() const { return keys_; }
const BlobRange& Range(size_t key_idx) const {
@ -99,7 +101,7 @@ class BlobReader {
}
private:
const std::unique_ptr<File> file_;
std::unique_ptr<File> file_;
const uint64_t file_bytes_;
std::vector<std::string> keys_;
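The newly exposed `Keys()` makes it easy to enumerate blob names, e.g. to list a file's contents (a sketch; constructing the `BlobReader` is assumed to happen elsewhere):

#include <cstdio>
#include <string>

#include "io/blob_store.h"

void ListBlobs(const gcpp::BlobReader& reader) {
  for (const std::string& key : reader.Keys()) {
    std::printf("%s\n", key.c_str());
  }
}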


@ -131,6 +131,13 @@ class MatPtr : public IFields {
Type GetType() const { return type_; }
void SetType(Type type) {
type_ = type;
if (type == Type::kUnknown) {
// Temporarily invalid state: happens while weights.h constructs tensors via
// MatFinder, before ForEachTensor loads them and sets the actual type.
element_bytes_ = 0;
num_elements_ = 0;
return;
}
element_bytes_ = static_cast<uint32_t>(hwy::DivCeil(TypeBits(type), 8));
num_elements_ = static_cast<uint32_t>(ComputeNumElements(type, Extents()));
HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16);
@ -244,15 +251,17 @@ class MatPtrT : public MatPtr {
// Called by `MatStorageT`.
MatPtrT(const char* name, Extents2D extents)
: MatPtr(name, TypeEnum<MatT>(), extents) {}
// Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`. This is
// not a factory function because `weights.h` initializes members of type
// `MatPtrT<T>`, and `T` cannot be inferred at compile time from arguments.
MatPtrT(const std::string& name, const TensorInfoRegistry& info)
: MatPtrT<MatT>(name.c_str(), ExtentsFromInfo(info.Find(name))) {}
// Copying allowed because the metadata is small.
MatPtrT(const MatPtr& other) : MatPtr(other) {
HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
// Happens in weights.h when constructing via MatFinder, which does not
// know the type. Setting the type here avoids having to keep the
// initializer list and member type in sync.
if (GetType() == Type::kUnknown) {
SetType(TypeEnum<MatT>());
} else {
HWY_ASSERT(other.GetType() == TypeEnum<MatT>());
}
}
MatPtrT& operator=(const MatPtr& other) {
MatPtr::operator=(other);
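To illustrate the hand-off described above: a base `MatPtr` built without a concrete type becomes typed when copied into a `MatPtrT`. A sketch (the tensor name and extents are made up; `Extents2D(rows, cols)` as used elsewhere with `util/mat.h`):

#include "util/mat.h"

gcpp::MatPtrT<float> MakeTypedMeta() {
  // MatFinder-style construction: no concrete element type yet.
  const gcpp::MatPtr base("toy_tensor", gcpp::Type::kUnknown,
                          gcpp::Extents2D(4, 8));
  // The copy sets the type to kF32 instead of asserting a match.
  return gcpp::MatPtrT<float>(base);
}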
@ -289,27 +298,27 @@ class MatPtrT : public MatPtr {
}
};
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`. This supports all types used as weights.
// Calls `func` with `MatPtrT<T>*` plus the optional `args`. This supports all
// types used as weights.
template <class Func, typename... Args>
decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
Args&&... args) {
#if GEMMA_ENABLE_NUQ
if (base->GetType() == Type::kNUQ) {
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base),
std::forward<Args>(args)...);
const MatPtrT<NuqStream> mat(*base);
return func(&mat, std::forward<Args>(args)...);
}
#endif // GEMMA_ENABLE_NUQ
if (base->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base),
std::forward<Args>(args)...);
const MatPtrT<float> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else if (base->GetType() == Type::kBF16) {
return func(dynamic_cast<const MatPtrT<BF16>*>(base),
std::forward<Args>(args)...);
const MatPtrT<BF16> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else if (base->GetType() == Type::kSFP) {
return func(dynamic_cast<const MatPtrT<SfpStream>*>(base),
std::forward<Args>(args)...);
const MatPtrT<SfpStream> mat(*base);
return func(&mat, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}
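A small usage sketch of the value-based dispatch: the generic lambda receives a `const MatPtrT<T>*` for whichever `T` matches the stored type (the function name is illustrative):

#include <cstddef>
#include <cstdio>

#include "util/mat.h"

// Prints the payload bytes per row of `mat`, whatever its element type.
void PrintRowBytes(const gcpp::MatPtr& mat) {
  gcpp::CallUpcasted(&mat, [](const auto* typed) {
    std::printf("%s: %zu bytes per row\n", typed->Name(),
                static_cast<size_t>(typed->Cols() * typed->ElementBytes()));
  });
}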
@ -323,24 +332,24 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
#if GEMMA_ENABLE_NUQ
if (base1->GetType() == Type::kNUQ) {
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base1),
dynamic_cast<const MatPtrT<NuqStream>*>(base2),
std::forward<Args>(args)...);
const MatPtrT<NuqStream> mat1(*base1);
const MatPtrT<NuqStream> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
}
#endif // GEMMA_ENABLE_NUQ
if (base1->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base1),
dynamic_cast<const MatPtrT<float>*>(base2),
std::forward<Args>(args)...);
const MatPtrT<float> mat1(*base1);
const MatPtrT<float> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kBF16) {
return func(dynamic_cast<const MatPtrT<BF16>*>(base1),
dynamic_cast<const MatPtrT<BF16>*>(base2),
std::forward<Args>(args)...);
const MatPtrT<BF16> mat1(*base1);
const MatPtrT<BF16> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kSFP) {
return func(dynamic_cast<const MatPtrT<SfpStream>*>(base1),
dynamic_cast<const MatPtrT<SfpStream>*>(base2),
std::forward<Args>(args)...);
const MatPtrT<SfpStream> mat1(*base1);
const MatPtrT<SfpStream> mat2(*base2);
return func(&mat1, &mat2, std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
}