diff --git a/BUILD.bazel b/BUILD.bazel index 8b8eb94..da42e58 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -469,10 +469,6 @@ cc_library( name = "gemma_lib", srcs = [ "gemma/gemma.cc", - "gemma/instantiations/bf16.cc", - "gemma/instantiations/f32.cc", - "gemma/instantiations/nuq.cc", - "gemma/instantiations/sfp.cc", ], hdrs = [ "gemma/activations.h", diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f6a8d9..c425cf1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,10 +61,6 @@ set(SOURCES gemma/gemma-inl.h gemma/gemma.cc gemma/gemma.h - gemma/instantiations/bf16.cc - gemma/instantiations/f32.cc - gemma/instantiations/nuq.cc - gemma/instantiations/sfp.cc gemma/kv_cache.cc gemma/kv_cache.h gemma/model_store.cc diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 57f0979..c5d5129 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -556,7 +556,7 @@ HWY_INLINE void Decompress2(DRaw d, const PackedSpan& packed, // Decompresses from any type of `packed`, starting at (any) `packed_ofs`, to // (any) `num` elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as // required to round `num` up to one vector, if it is not already. The caller is -// responsible for scaling `raw` to the original range because `EmbedToken` +// responsible for scaling `raw` to the original range because `EmbedMMToken` // also wants to scale the decompressed elements. // `TRaw` can be `float/BF16`, or `double` if `Packed` is `float`. template > diff --git a/gemma/activations.h b/gemma/activations.h index f666eb0..983f8ce 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -86,6 +86,7 @@ struct Activations { // For MatMul outputs, precompute their row pointers. const auto init_row_ptrs = [&](MatPtrT& mat) { + if (!mat.HasPtr()) return; row_ptrs.push_back(hwy::AllocateAligned(mat.Rows())); uint8_t** ptrs = row_ptrs.back().get(); for (size_t r = 0; r < mat.Rows(); ++r) { diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 3f53134..d215596 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -53,31 +53,30 @@ #include "ops/matvec-inl.h" #include "ops/ops-inl.h" -#ifndef GEMMA_TYPE -#if HWY_IDE -// Provide a definition so the IDE does not complain. -#define GEMMA_TYPE float -#else -#error "Only include from instantiations/*.cc, which must define GEMMA_TYPE" -#endif // HWY_IDE -#endif // GEMMA_TYPE - HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +template +MMPerKey* CallMatMul(const MatPtrT& A, const MatPtr& B, + const float* HWY_RESTRICT add, MatMulEnv& env, + MatPtrT& C) { + return CallUpcasted( + &B, [&](const auto* B_t) { return MatMulStatic(A, *B_t, add, env, C); }); +} + // Different functions use different naming conventions for the number of // tokens. Functions that are query-independent, such as RMSNorm*, call the // count `num_interleaved`. Functions that are query-dependent, such as // `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the // number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size. 
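The new `CallMatMul` wrapper above is the heart of this change: instead of compiling gemma-inl.h once per weight type via the deleted instantiations/*.cc, the element type of `B` is resolved at runtime by `CallUpcasted`, which hands a generic lambda a correctly typed `MatPtrT` pointer. A minimal, self-contained sketch of that dispatch pattern with simplified stand-in types (the real `MatPtr`/`MatPtrT`/`CallUpcasted` carry extents, strides, scales and more element types such as SFP and NUQ):

#include <cstddef>
#include <cstdio>

enum class Type { kF32, kBF16 };
struct BF16 { unsigned short bits; };  // stand-in for hwy::bfloat16_t

struct MatPtr {                    // type-erased: raw pointer plus runtime type tag
  Type type;
  const void* data;
  size_t cols;
};

template <typename T>
struct MatPtrT : public MatPtr {   // typed view over the same storage
  const T* Packed() const { return static_cast<const T*>(data); }
};

// Invokes `func` with a pointer whose static type matches the runtime tag, so
// one generic lambda body is compiled per element type instead of compiling
// whole per-type translation units (the deleted instantiations/*.cc).
template <class Func>
void CallUpcasted(const MatPtr* mat, const Func& func) {
  switch (mat->type) {
    case Type::kF32:
      func(static_cast<const MatPtrT<float>*>(mat));
      break;
    case Type::kBF16:
      func(static_cast<const MatPtrT<BF16>*>(mat));
      break;
  }
}

int main() {
  const float row[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  MatPtrT<float> w{};   // call sites below only see the MatPtr base class
  w.type = Type::kF32;
  w.data = row;
  w.cols = 4;
  const MatPtr* erased = &w;
  CallUpcasted(erased, [&](const auto* typed) {
    // `typed` is MatPtrT<float>* at runtime; the body also compiles for BF16.
    std::printf("cols=%zu elem_bytes=%zu\n", typed->cols,
                sizeof(typed->Packed()[0]));
  });
  return 0;
}

The same mechanism backs the `CallUpcasted`/`CallUpcastedSame` call sites introduced throughout gemma-inl.h below.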
-template -HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos, - size_t num_tokens, size_t griffin_layer, - Activations& activations, - const LayerWeightsPtrs* layer_weights, - const KVCaches& kv_caches) { +static HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos, + size_t num_tokens, + size_t griffin_layer, + Activations& activations, + const LayerWeightsPtrs* layer_weights, + const KVCaches& kv_caches) { PROFILER_ZONE("Gen.Griffin"); hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0); namespace hn = hwy::HWY_NAMESPACE; @@ -101,17 +100,21 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos, // TODO: MatMul HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows()); HWY_DASSERT(num_interleaved == activations.griffin_y.Rows()); - for (size_t r = 0; r < num_interleaved; ++r) { - float* HWY_RESTRICT y = activations.griffin_y.Row(r); - float* HWY_RESTRICT x = activations.griffin_x.Row(r); - TwoMatVecAdd(layer_weights->griffin.linear_x_w, - layer_weights->griffin.linear_y_w, 0, model_dim, model_dim, - activations.pre_att_rms_out.Row(r), - /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), - /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), - /*out0=*/x, /*out1=*/y, pool); - Gelu(y, model_dim); - } + CallUpcastedSame( + &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w, + [&](const auto* wx, const auto* wy) { + for (size_t r = 0; r < num_interleaved; ++r) { + float* HWY_RESTRICT y = activations.griffin_y.Row(r); + float* HWY_RESTRICT x = activations.griffin_x.Row(r); + TwoMatVecAdd( + *wx, *wy, 0, model_dim, model_dim, + activations.pre_att_rms_out.Row(r), + /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), + /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), + /*out0=*/x, /*out1=*/y, pool); + Gelu(y, model_dim); + } + }); // Conv1D. for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; @@ -165,14 +168,16 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos, pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { size_t head_offset = head * kHeadDim; - TwoOfsMatVecAddLoop( - layer_weights->griffin.gate_w, kMatrixSize * head, - kMatrixSize * (heads + head), kHeadDim, kHeadDim, x + head_offset, - /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() + - head_offset, - /*add1=*/layer_weights->griffin.gate_biases.PackedScale1() + - model_dim + head_offset, - /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); + CallUpcasted(&layer_weights->griffin.gate_w, [&](const auto* gate_w) { + TwoOfsMatVecAddLoop( + *gate_w, kMatrixSize * head, kMatrixSize * (heads + head), kHeadDim, + kHeadDim, x + head_offset, + /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() + + head_offset, + /*add1=*/layer_weights->griffin.gate_biases.PackedScale1() + + model_dim + head_offset, + /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); + }); Sigmoid(gate_x + head_offset, kHeadDim); Sigmoid(a + head_offset, kHeadDim); const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) @@ -206,17 +211,19 @@ HWY_NOINLINE void GriffinRecurrent(const QueriesPos& queries_pos, // Final linear layer. 
// TODO: MatMul - for (size_t r = 0; r < num_interleaved; ++r) { - float* HWY_RESTRICT x = activations.griffin_x.Row(r); - float* out_ptr = activations.att_sums.Row(r); - MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x, - layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr, - pool); - } + CallUpcasted( + &layer_weights->griffin.linear_out_w, [&](const auto* weights_t) { + for (size_t r = 0; r < num_interleaved; ++r) { + float* HWY_RESTRICT x = activations.griffin_x.Row(r); + float* out_ptr = activations.att_sums.Row(r); + MatVecAdd(*weights_t, 0, model_dim, model_dim, x, + layer_weights->griffin.linear_out_biases.PackedScale1(), + out_ptr, pool); + } + }); } // GriffinRecurrent // Wrapper class; holds arguments in member variables to shorten call sites. -template class GemmaAttention { // The attention window usually starts at 0 unless `pos` is larger than // the attention window size, then it is `pos` - window_size + 1. @@ -257,8 +264,8 @@ class GemmaAttention { // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim, // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows. - MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1, - /*add=*/nullptr, *activations_.env, activations_.q); + CallMatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w1, + /*add=*/nullptr, *activations_.env, activations_.q); // Set up MatMul row pointers for writing to KV, which consists of // `kv_heads` pairs of (k, v) vectors. This safely handles wraparound @@ -279,8 +286,8 @@ class GemmaAttention { kv_offset); } kv_rows.AttachRowPtrs(&activations_.env->storage.OutRow(0)); - MatMulStatic(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2, - /*add=*/nullptr, *activations_.env, kv_rows); + CallMatMul(activations_.pre_att_rms_out, layer_weights_.qkv_einsum_w2, + /*add=*/nullptr, *activations_.env, kv_rows); // Apply positional encodings for K (and copy KV to cache if MHA). pool_.Run(0, kv_heads * num_interleaved, @@ -299,8 +306,11 @@ class GemmaAttention { // Apply further processing to K. if (layer_weights_.key_norm_scale.HasPtr()) { - RMSNormInplace(layer_weights_.key_norm_scale.PackedScale1(), - 0, kv, qkv_dim); + CallUpcasted(&layer_weights_.key_norm_scale, + [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), 0, + kv, qkv_dim); + }); } PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f); }); @@ -367,8 +377,10 @@ class GemmaAttention { // Apply rope and scaling to Q. if (layer_weights_.query_norm_scale.HasPtr()) { - RMSNormInplace(layer_weights_.query_norm_scale.PackedScale1(), 0, q, - qkv_dim); + CallUpcasted(&layer_weights_.query_norm_scale, + [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), 0, q, qkv_dim); + }); } PositionalEncodingQK(q, pos, layer_, query_scale); @@ -456,8 +468,8 @@ class GemmaAttention { layer_weights_.layer_config.softmax_attn_output_biases ? 
layer_weights_.attention_output_biases.PackedScale1() : nullptr; - MatMulStatic(activations_.att_out, layer_weights_.att_weights, add, - *activations_.env, activations_.att_sums); + CallMatMul(activations_.att_out, layer_weights_.att_weights, add, + *activations_.env, activations_.att_sums); } public: @@ -467,7 +479,7 @@ class GemmaAttention { GemmaAttention(const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, size_t num_tokens, size_t layer, Activations& activations, - const LayerWeightsPtrs* layer_weights, + const LayerWeightsPtrs* layer_weights, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer, activations, layer_weights, div_seq_len, kv_caches, @@ -475,7 +487,7 @@ class GemmaAttention { // Constructor with default initialization to 0 for queries_prefix_end. GemmaAttention(const QueriesPos& queries_pos, size_t num_tokens, size_t layer, Activations& activations, - const LayerWeightsPtrs* layer_weights, + const LayerWeightsPtrs* layer_weights, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) : GemmaAttention(queries_pos, nullptr, num_tokens, layer, activations, layer_weights, div_seq_len, kv_caches, @@ -487,7 +499,7 @@ class GemmaAttention { GemmaAttention(const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, size_t num_tokens, size_t layer, Activations& activations, - const LayerWeightsPtrs* layer_weights, + const LayerWeightsPtrs* layer_weights, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, ThreadingContext& ctx) : GemmaAttention(queries_pos, &queries_prefix_end, num_tokens, layer, @@ -507,7 +519,7 @@ class GemmaAttention { GemmaAttention(const QueriesPos& queries_pos, const QueriesPos* queries_prefix_end, size_t num_tokens, size_t layer, Activations& activations, - const LayerWeightsPtrs* layer_weights, + const LayerWeightsPtrs* layer_weights, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, ThreadingContext& ctx) : queries_pos_(queries_pos), @@ -546,21 +558,20 @@ class GemmaAttention { const size_t cache_pos_size_ = 0; Activations& activations_; - const LayerWeightsPtrs& layer_weights_; + const LayerWeightsPtrs& layer_weights_; const hwy::Divisor& div_seq_len_; const KVCaches& kv_caches_; hwy::ThreadPool& pool_; }; -template -HWY_NOINLINE void Attention( +static HWY_NOINLINE void Attention( LayerAttentionType type, const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, size_t num_tokens, size_t layer, - Activations& activations, const LayerWeightsPtrs* layer_weights, + Activations& activations, const LayerWeightsPtrs* layer_weights, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) { if (type == LayerAttentionType::kGemma) { - GemmaAttention(queries_pos, queries_prefix_end, num_tokens, layer, - activations, layer_weights, div_seq_len, kv_caches)(); + GemmaAttention(queries_pos, queries_prefix_end, num_tokens, layer, + activations, layer_weights, div_seq_len, kv_caches)(); } else { HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock); // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer, @@ -581,7 +592,6 @@ HWY_NOINLINE void Attention( // This results in a much simpler implementation. However, to avoid duplicating // code, we should still consider merging the two classes. // TODO(keysers): Refactor to share code with GemmaAttention. -template class VitAttention { // Computes Q, K, V for all heads, stored in activations_.q. 
HWY_NOINLINE void ComputeQKV() { @@ -589,9 +599,9 @@ class VitAttention { auto& qkv = activations_.q; HWY_ASSERT(qkv.Rows() == num_tokens_); HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); - MatMulStatic(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w, - layer_weights_.vit.qkv_einsum_b.PackedScale1(), - *activations_.env, qkv); + CallMatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w, + layer_weights_.vit.qkv_einsum_b.PackedScale1(), + *activations_.env, qkv); } // TODO(philculliton): transition fully to MatMul. @@ -631,7 +641,7 @@ class VitAttention { }); // this produces C, a (num_tokens_, seq_len) matrix of dot products - MatMulStatic(Q, K, nullptr, *activations_.env, C); + CallMatMul(Q, K, nullptr, *activations_.env, C); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { float* HWY_RESTRICT c = C.Row(task); @@ -697,15 +707,14 @@ class VitAttention { // att_weights and att_out are concatenated heads, each of length // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // matmul output is the sum over heads. - MatMulStatic(activations_.att_out, layer_weights_.vit.attn_out_w, bias, - *activations_.env, activations_.att_sums); + CallMatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias, + *activations_.env, activations_.att_sums); } public: VitAttention(size_t num_tokens, size_t layer, Activations& activations, - const LayerWeightsPtrs* layer_weights) + const LayerWeightsPtrs* layer_weights) : num_tokens_(num_tokens), - layer_(layer), activations_(activations), layer_weights_(*layer_weights), layer_config_(layer_weights->layer_config), @@ -723,9 +732,8 @@ class VitAttention { private: const size_t num_tokens_; - const size_t layer_; Activations& activations_; - const LayerWeightsPtrs& layer_weights_; + const LayerWeightsPtrs& layer_weights_; const LayerConfig& layer_config_; hwy::ThreadPool& pool_; }; @@ -775,9 +783,8 @@ void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2) { } } -template -HWY_NOINLINE void FFWNoVit(Activations& activations, - const LayerWeightsPtrs* layer_weights) { +static HWY_NOINLINE void FFWNoVit(Activations& activations, + const LayerWeightsPtrs* layer_weights) { PROFILER_ZONE("Gen.FFW"); const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim; @@ -789,25 +796,24 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr; // Compute the hidden layer activations. - MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, - bias1, *activations.env, activations.C1); - MatMulStatic(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, - bias2, *activations.env, activations.C2); + CallMatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w1, + bias1, *activations.env, activations.C1); + CallMatMul(activations.pre_ffw_rms_out, layer_weights->gating_einsum_w2, + bias2, *activations.env, activations.C2); // Activation (Gelu) and maybe multiply by gate. Store activations in act. ActivationBatched(layer_weights->layer_config.activation, activations.C1, &activations.C2); // Hidden layer -> output layer. - MatMulStatic(activations.C1, layer_weights->linear_w, output_bias, - *activations.env, activations.ffw_out); + CallMatMul(activations.C1, layer_weights->linear_w, output_bias, + *activations.env, activations.ffw_out); } // Same as FFWNoVit, but with different layer_weights members and no second // gating matrix. 
-template -HWY_NOINLINE void FFWVit(Activations& activations, - const LayerWeightsPtrs* layer_weights) { +static HWY_NOINLINE void FFWVit(Activations& activations, + const LayerWeightsPtrs* layer_weights) { PROFILER_ZONE("Gen.FFW.ViT"); const bool add_bias = layer_weights->layer_config.ff_biases; @@ -817,15 +823,15 @@ HWY_NOINLINE void FFWVit(Activations& activations, add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr; // Compute the hidden layer activations. - MatMulStatic(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, - bias1, *activations.env, activations.C1); + CallMatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1, + *activations.env, activations.C1); // Activation (Gelu), store in C1. ActivationBatched(layer_weights->layer_config.activation, activations.C1); // Hidden layer -> output layer. - MatMulStatic(activations.C1, layer_weights->vit.linear_1_w, output_bias, - *activations.env, activations.ffw_out); + CallMatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias, + *activations.env, activations.ffw_out); } // `batch_idx` indicates which row of `x` to write to. @@ -836,67 +842,52 @@ HWY_NOINLINE void FFWVit(Activations& activations, // spec) until we run out of image tokens. This allows for a multi-image prompt // if -2 locations with appropriate begin/end image tokens are created by the // calling application. -template -HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos, - size_t pos_in_prompt, - const ModelWeightsPtrs& weights, - MatStorageT& x, - const ImageTokens* image_tokens, - size_t& image_token_position) { +// Returns new image_token_position. +static HWY_NOINLINE size_t +EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt, + const ModelConfig& model_config, const ModelWeightsPtrs& weights, + MatStorageT& x, const ImageTokens* image_tokens = nullptr, + size_t image_token_position = 0) { // Image tokens just need to be copied. - if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM && + if (model_config.wrapping == PromptWrapping::GEMMA_VLM && image_tokens != nullptr && token == -2 && image_token_position < image_tokens->Rows()) { hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx), x.Cols() * x.ElementBytes()); - image_token_position++; - return; + return image_token_position + 1; } - if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA && + if (model_config.wrapping == PromptWrapping::PALIGEMMA && image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) { hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx), x.Cols() * x.ElementBytes()); - return; + return image_token_position; } - const size_t model_dim = weights.weights_config.model_dim; + const size_t model_dim = model_config.model_dim; const float emb_scaling = EmbeddingScaling(model_dim); HWY_DASSERT(token >= 0); - HWY_DASSERT(token < static_cast(weights.weights_config.vocab_size)); + HWY_DASSERT(token < static_cast(model_config.vocab_size)); - const hn::ScalableTag df; - // Using `Stride` to compute the offset works for both NUQ (because we use an - // offset and NUQ is never padded) and padded, because non-NUQ types are - // seekable, hence the offset can also skip any padding. 
- const size_t embedding_ofs = - token * weights.embedder_input_embedding.Stride(); - HWY_ASSERT(weights.embedder_input_embedding.Cols() == model_dim); - const auto embedding_span = MakeSpan(weights.embedder_input_embedding.Row(0), - embedding_ofs + model_dim); - DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx), - model_dim); - MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(), - x.Row(batch_idx), model_dim); - if (weights.weights_config.absolute_pe) { + CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) { + // Using `Stride` to compute the offset works for both NUQ (because we use + // an offset and NUQ is never padded) and padded, because non-NUQ types are + // seekable, hence the offset can also skip any padding. + const size_t embedding_ofs = token * weights_t->Stride(); + HWY_ASSERT(weights_t->Cols() == model_dim); + const auto embedding_span = + MakeSpan(weights_t->Row(0), embedding_ofs + model_dim); + const hn::ScalableTag df; + DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx), + model_dim); + MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim); + }); + + if (model_config.absolute_pe) { AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos); } -} - -// `batch_idx` indicates which row of `x` to write to. -// `pos` is the *token*'s position, not the start of the batch, because this is -// called for batches of tokens in prefill, but batches of queries in decode. -// This version of the function doesn't track internal image token position. -template -HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, - size_t pos_in_prompt, - const ModelWeightsPtrs& weights, - MatStorageT& x, - const ImageTokens* image_tokens) { - size_t image_token_position = 0; - EmbedMMToken(token, batch_idx, pos, pos_in_prompt, weights, x, - image_tokens, image_token_position); + return image_token_position; } template @@ -908,8 +899,8 @@ HWY_NOINLINE void ResidualConnection(const MatPtrT& other, AddFromBatched(other, x); } -template -void PostNorm(PostNormType post_norm, const MatPtrT& weights, +template +void PostNorm(PostNormType post_norm, const MatPtr& weights, MatPtrT& inout) { HWY_DASSERT(weights.Rows() == 1); if (post_norm == PostNormType::Scale) { @@ -917,14 +908,11 @@ void PostNorm(PostNormType post_norm, const MatPtrT& weights, } } -template -HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos, - const QueriesPos& queries_prefix_end, - size_t num_tokens, size_t cache_layer_idx, - const LayerWeightsPtrs* layer_weights, - Activations& activations, - const hwy::Divisor& div_seq_len, - const KVCaches& kv_caches) { +static HWY_NOINLINE void TransformerLayer( + const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, + size_t num_tokens, size_t cache_layer_idx, + const LayerWeightsPtrs* layer_weights, Activations& activations, + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) { auto type = layer_weights->layer_config.type; RMSNormBatched(activations.x, layer_weights->pre_attention_norm_scale, @@ -960,10 +948,9 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos, // github.com/google-research/big_vision/blob/main/big_vision/models/vit.py // TODO(keysers): consider adding a wrapper for both LayerNorm with RMSNorm and // try merging this with TransformerLayer. 
-template -HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer, - const LayerWeightsPtrs* layer_weights, - Activations& activations) { +static HWY_NOINLINE void VitTransformerLayer( + size_t num_tokens, size_t layer, const LayerWeightsPtrs* layer_weights, + Activations& activations) { const size_t model_dim = activations.weights_config.model_dim; auto type = layer_weights->layer_config.type; HWY_DASSERT(type == LayerAttentionType::kVit); @@ -982,7 +969,7 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer, // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) // y ~ att_sums - VitAttention(num_tokens, layer, activations, layer_weights)(); + VitAttention(num_tokens, layer, activations, layer_weights)(); // x = out["+sa"] = x + y AddFromBatched(activations.att_sums, x); @@ -1005,13 +992,13 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer, using QueriesMutablePos = hwy::Span; // Populates KV cache for batches of tokens from one query at a time. -template -HWY_NOINLINE void Prefill( +static HWY_NOINLINE void Prefill( const QueriesPromptTokens& queries_prompt, const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end, - const size_t query_idx_start, const ModelWeightsPtrs& weights, - Activations& activations, const RuntimeConfig& runtime_config, - const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) { + const size_t query_idx_start, const ModelConfig& config, + const ModelWeightsPtrs& weights, Activations& activations, + const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len, + const KVCaches& kv_caches) { PROFILER_ZONE("Gen.Prefill"); const size_t num_queries = queries_prompt.size(); HWY_DASSERT(queries_pos.size() == num_queries); @@ -1072,14 +1059,14 @@ HWY_NOINLINE void Prefill( const size_t pos = queries_pos[qi] + ti; const size_t pos_in_prompt = tbatch_start + ti; const int token = queries_prompt[qi][pos_in_prompt]; - EmbedMMToken(token, ti, pos, pos_in_prompt, weights, activations.x, - runtime_config.image_tokens, image_token_position); + image_token_position = EmbedMMToken( + token, ti, pos, pos_in_prompt, config, weights, activations.x, + runtime_config.image_tokens, image_token_position); } // Transformer with one batch of tokens from a single query. - for (size_t layer = 0; - layer < weights.weights_config.layer_configs.size(); ++layer) { - const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer); + for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) { + const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer); TransformerLayer(single_query_pos, single_query_prefix_end, tbatch_size, layer, layer_weights, activations, div_seq_len, single_kv_cache); @@ -1115,13 +1102,13 @@ HWY_NOINLINE void Prefill( // Gets the patches of the image and embeds them with the image embedding // kernel. The result is stored in activations.x. 
-template -HWY_NOINLINE void EmbedImagePatches(const Image& image, - const ModelWeightsPtrs& weights, - Activations& activations) { - const size_t model_dim = weights.weights_config.vit_config.model_dim; - const size_t patch_width = weights.weights_config.vit_config.patch_width; - const size_t seq_len = weights.weights_config.vit_config.seq_len; +static HWY_NOINLINE void EmbedImagePatches(const Image& image, + const ModelConfig& model_config, + const ModelWeightsPtrs& weights, + Activations& activations) { + const size_t model_dim = model_config.vit_config.model_dim; + const size_t patch_width = model_config.vit_config.patch_width; + const size_t seq_len = model_config.vit_config.seq_len; const size_t patch_size = patch_width * patch_width * 3; HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); @@ -1138,7 +1125,7 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, // MatStorageT image_patches("patches", Extents2D(kSeqLen, // kPatchSize), MatPadding::kPacked); // [Get patches] - // MatMulStatic( + // CallMatMul( // image_patches, // weights.vit_img_embedding_kernel, // weights.vit_img_embedding_bias.PackedScale1(), *activations.env, @@ -1147,61 +1134,64 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, // A.cols % (2 * hn::Lanes(hn::ScalableTag())) == 0 // which is not the case here. We should relax that requirement on MatMul and // then use the above. For now, we rely on MatVecAdd instead. - for (size_t i = 0; i < seq_len; ++i) { - MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size, - image_patches[i].get(), - weights.vit_img_embedding_bias.PackedScale1(), - activations.x.Row(i), activations.env->ctx.pools.Pool(0)); - } + CallUpcasted(&weights.vit_img_embedding_kernel, [&](const auto* embedding_t) { + for (size_t i = 0; i < seq_len; ++i) { + MatVecAdd(*embedding_t, 0, model_dim, patch_size, image_patches[i].get(), + weights.vit_img_embedding_bias.PackedScale1(), + activations.x.Row(i), activations.env->ctx.pools.Pool(0)); + } + }); // Add position embeddings. AddFromBatched(weights.vit_img_pos_embedding, activations.x); } // Prefills the image tokens with the ViT encoder. -template -HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const Image& image, ImageTokens& image_tokens, - Activations& activations) { +static HWY_NOINLINE void PrefillVit(const ModelConfig& model_config, + const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, + const Image& image, + ImageTokens& image_tokens, + Activations& activations) { PROFILER_ZONE("Gen.PrefillVit"); - const size_t num_tokens = weights.weights_config.vit_config.seq_len; - const size_t vit_model_dim = weights.weights_config.vit_config.model_dim; + const size_t num_tokens = model_config.vit_config.seq_len; + const size_t vit_model_dim = model_config.vit_config.model_dim; HWY_ASSERT(num_tokens == activations.x.Rows()); // Embed the image patches. - EmbedImagePatches(image, weights, activations); + EmbedImagePatches(image, model_config, weights, activations); // Go through all layers. 
- for (size_t layer = 0; - layer < weights.weights_config.vit_config.layer_configs.size(); + for (size_t layer = 0; layer < model_config.vit_config.layer_configs.size(); ++layer) { - const LayerWeightsPtrs* layer_weights = weights.VitLayer(layer); + const LayerWeightsPtrs* layer_weights = weights.VitLayer(layer); VitTransformerLayer(num_tokens, layer, layer_weights, activations); } // Final Layernorm. LayerNormBatched(activations.x, weights.vit_encoder_norm_scale, weights.vit_encoder_norm_bias, activations.x); - if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) { + if (model_config.wrapping == PromptWrapping::GEMMA_VLM) { activations.x = AvgPool4x4(activations.x); // Apply soft embedding norm before input projection. - RMSNormInplace(weights.mm_embed_norm.PackedScale1(), 0, - activations.x.Row(0), vit_model_dim); + CallUpcasted(&weights.mm_embed_norm, [&](const auto* weights_t) { + RMSNormInplace(weights_t->PackedScale1(), 0, activations.x.Row(0), + vit_model_dim); + }); } // Apply head embedding into image_tokens of size of the LLM kModelDim. - MatMulStatic(activations.x, weights.vit_img_head_kernel, - weights.vit_img_head_bias.PackedScale1(), *activations.env, - image_tokens); + CallMatMul(activations.x, weights.vit_img_head_kernel, + weights.vit_img_head_bias.PackedScale1(), *activations.env, + image_tokens); } // Generates one token for each query. `queries_token` is the previous token // from each query, and `queries_pos` are their position in the sequence. -template -HWY_NOINLINE void Transformer( +static HWY_NOINLINE void Transformer( const QueriesToken& queries_token, const QueriesMutablePos& queries_pos, - const QueriesPos& queries_prefix_end, const ModelWeightsPtrs& weights, - Activations& activations, const hwy::Divisor& div_seq_len, - const KVCaches& kv_caches, const LayersOutputFunc& layers_output, + const QueriesPos& queries_prefix_end, const ModelConfig& config, + const ModelWeightsPtrs& weights, Activations& activations, + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, + const LayersOutputFunc& layers_output, const ActivationsObserverFunc& activations_observer) { const size_t num_queries = queries_token.size(); HWY_DASSERT(queries_pos.size() == num_queries); @@ -1215,15 +1205,13 @@ HWY_NOINLINE void Transformer( } } - size_t image_token_position = 0; for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx], - /*pos_in_prompt=*/0, weights, activations.x, - /*image_tokens=*/nullptr, image_token_position); + /*pos_in_prompt=*/0, config, weights, activations.x); } for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) { - const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer); + const LayerWeightsPtrs* layer_weights = weights.GetLayer(layer); TransformerLayer(queries_pos, queries_prefix_end, /*num_tokens=*/1, layer, layer_weights, activations, div_seq_len, kv_caches); @@ -1302,25 +1290,25 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) { }; } -template // Runs one decode step for all the queries in the batch. Returns true if all // queries are at EOS.
-bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const size_t query_idx_start, const KVCaches& kv_caches, - const QueriesPos& queries_prefix_end, - const hwy::Divisor div_seq_len, const size_t vocab_size, - const SampleFunc& sample_token, Activations& activations, - TokenStreamer& token_streamer, std::vector& gen_tokens, - TimingInfo& timing_info, - const QueriesMutablePos& queries_mutable_pos) { +static bool DecodeStepT(const ModelConfig& config, + const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, + const QueriesPromptTokens& queries_prompt, + const size_t query_idx_start, const KVCaches& kv_caches, + const QueriesPos& queries_prefix_end, + const hwy::Divisor div_seq_len, const size_t vocab_size, + const SampleFunc& sample_token, + Activations& activations, TokenStreamer& token_streamer, + std::vector& gen_tokens, TimingInfo& timing_info, + const QueriesMutablePos& queries_mutable_pos) { const size_t num_queries = queries_prompt.size(); // Decode generates one token per query and increments // queries_mutable_pos. Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos, - queries_prefix_end, weights, activations, div_seq_len, kv_caches, - runtime_config.layers_output, + queries_prefix_end, config, weights, activations, div_seq_len, + kv_caches, runtime_config.layers_output, runtime_config.activations_observer); // queries_pos are incremented by Transformer. @@ -1329,13 +1317,13 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs& weights, { PROFILER_ZONE("Gen.EmbeddingMatmul"); // Compute logits from last layer activations. - MatMulStatic(activations.x, weights.embedder_input_embedding, - /*add=*/nullptr, *activations.env, activations.logits); + CallMatMul(activations.x, weights.embedder_input_embedding, + /*add=*/nullptr, *activations.env, activations.logits); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { float* HWY_RESTRICT logits = activations.logits.Row(query_idx); - MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size); + MaybeLogitsSoftCap(config.final_cap, logits, vocab_size); const TokenAndProb tp = sample_token(logits, vocab_size); timing_info.NotifyGenerated(); @@ -1359,15 +1347,16 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs& weights, // `StreamFunc` gets the global query index, not relative to the batch. // // `kv_caches` is for the batch, size must match `queries_prompt`. -template -void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, - Activations& activations, const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const QueriesPos& queries_pos_in, - const QueriesPos& queries_prefix_end, - const size_t query_idx_start, const KVCaches& kv_caches, - TimingInfo& timing_info) { +static void GenerateT(const ModelConfig& config, + const ModelWeightsPtrs& weights, Activations& activations, + const RuntimeConfig& runtime_config, + const QueriesPromptTokens& queries_prompt, + const QueriesPos& queries_pos_in, + const QueriesPos& queries_prefix_end, + const size_t query_idx_start, const KVCaches& kv_caches, + TimingInfo& timing_info) { HWY_ASSERT(queries_pos_in.size() == kv_caches.size()); + // Griffin assumes that the recurrent block cache is zero-initialized. 
for (size_t i = 0; i < kv_caches.size(); ++i) { if (queries_pos_in[i] == 0) { @@ -1384,7 +1373,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, // Sanity check: prompts should not be empty, nor start with EOS. for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) { const PromptTokens& prompt = queries_prompt[query_idx]; - HWY_ASSERT(prompt.size() != 0 && prompt[0] != model.Config().eos_id); + HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id); } const size_t num_queries = queries_prompt.size(); @@ -1395,7 +1384,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, const hwy::Divisor div_seq_len(static_cast(kv_caches[0].seq_len)); size_t max_prompt_size = MaxQueryLength(queries_prompt); size_t max_generated_tokens = runtime_config.max_generated_tokens; - RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size); + RangeChecks(config, max_generated_tokens, max_prompt_size); const SampleFunc sample_token = ChooseSampleFunc(runtime_config); // Prefill stops before min_prompt_size - 1 because the last prompt @@ -1403,8 +1392,8 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, timing_info.prefill_start = hwy::platform::Now(); // Note that Prefill calls activations.SetBatchSize, so we reset it below. Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, - query_idx_start, weights, activations, runtime_config, div_seq_len, - kv_caches); + query_idx_start, config, weights, activations, runtime_config, + div_seq_len, kv_caches); // Compute the number of tokens that were prefilled and notify timing_info. size_t prefilled_tokens = 0; for (size_t qi = 0; qi < num_queries; ++qi) { @@ -1419,7 +1408,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, std::vector gen_tokens(num_queries); // Stream the last prompt token from each query and fill gen_tokens. 
- TokenStreamer token_streamer(runtime_config, model.Config()); + TokenStreamer token_streamer(runtime_config, config); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { size_t last_token_pos_in_prompt = queries_mutable_pos[query_idx] - queries_pos_in[query_idx]; @@ -1430,53 +1419,49 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, } { - const size_t vocab_size = model.Config().vocab_size; + const size_t vocab_size = config.vocab_size; timing_info.generate_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_generated_tokens; ++gen) { - bool all_queries_eos = DecodeStepT( - model.Config(), weights, runtime_config, queries_prompt, - query_idx_start, kv_caches, queries_prefix_end, div_seq_len, - vocab_size, sample_token, activations, token_streamer, gen_tokens, - timing_info, queries_mutable_pos); + bool all_queries_eos = DecodeStepT( + config, weights, runtime_config, queries_prompt, query_idx_start, + kv_caches, queries_prefix_end, div_seq_len, vocab_size, sample_token, + activations, token_streamer, gen_tokens, timing_info, + queries_mutable_pos); if (all_queries_eos) break; } // foreach token to generate timing_info.NotifyGenerateDone(); } } -template -void GenerateSingleT(const ModelStore& model, - const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const PromptTokens& prompt, size_t pos, size_t prefix_end, - KVCache& kv_cache, MatMulEnv* env, - TimingInfo& timing_info) { +static HWY_MAYBE_UNUSED void GenerateSingleT( + const ModelConfig& config, const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, + size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, + TimingInfo& timing_info) { constexpr size_t kNumQueries = 1; const size_t qbatch_start = 0; const size_t max_batch_size = HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size); // TODO: move into Gemma? 
- Activations activations(model.Config(), max_batch_size, env); + Activations activations(config, max_batch_size, env); const QueriesPromptTokens queries_prompt(&prompt, kNumQueries); QueriesPos queries_pos(&pos, kNumQueries); const QueriesPos queries_prefix_end(&prefix_end, kNumQueries); const KVCaches kv_caches{&kv_cache, kNumQueries}; - GenerateT(model, weights, activations, runtime_config, queries_prompt, - queries_pos, queries_prefix_end, qbatch_start, kv_caches, - timing_info); + GenerateT(config, weights, activations, runtime_config, queries_prompt, + queries_pos, queries_prefix_end, qbatch_start, kv_caches, + timing_info); } -template -void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, - const QueriesPos& queries_pos, - const QueriesPos& queries_prefix_end, - const KVCaches& kv_caches, MatMulEnv* env, - TimingInfo& timing_info) { +static HWY_MAYBE_UNUSED void GenerateBatchT( + const ModelConfig& config, const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, + const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, + const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, + MatMulEnv* env, TimingInfo& timing_info) { const size_t num_queries = queries_prompt.size(); HWY_ASSERT(queries_pos.size() == num_queries); HWY_ASSERT(kv_caches.size() >= num_queries); @@ -1484,7 +1469,7 @@ void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs& weights, const size_t max_batch_size = HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size); - Activations activations(model.Config(), max_batch_size, env); + Activations activations(config, max_batch_size, env); for (size_t qbatch_start = 0; qbatch_start < num_queries; qbatch_start += max_qbatch_size) { @@ -1497,68 +1482,31 @@ void GenerateBatchT(const ModelStore& model, const ModelWeightsPtrs& weights, const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start], qbatch_size); const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size); - GenerateT(model, weights, activations, runtime_config, qbatch_prompts, - qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv, - timing_info); + GenerateT(config, weights, activations, runtime_config, qbatch_prompts, + qbatch_pos, qbatch_prefix_end, qbatch_start, qbatch_kv, + timing_info); } } -template -void GenerateImageTokensT(const ModelStore& model, - const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const Image& image, ImageTokens& image_tokens, - MatMulEnv* env) { - if (model.Config().vit_config.layer_configs.empty()) { +static HWY_MAYBE_UNUSED void GenerateImageTokensT( + const ModelConfig& config, const ModelWeightsPtrs& weights, + const RuntimeConfig& runtime_config, const Image& image, + ImageTokens& image_tokens, MatMulEnv* env) { + if (config.vit_config.layer_configs.empty()) { HWY_ABORT("Model does not support generating image tokens."); } RuntimeConfig prefill_runtime_config = runtime_config; - ModelConfig vit_config = GetVitConfig(model.Config()); + ModelConfig vit_config = GetVitConfig(config); prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim); Activations prefill_activations(vit_config, vit_config.seq_len, env); // Weights are for the full PaliGemma model, not just the ViT part. 
- PrefillVit(weights, prefill_runtime_config, image, image_tokens, + PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, prefill_activations); } +// NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE - -#if HWY_ONCE - -// These are extern functions defined by instantiations/*.cc, which include this -// 'header' after defining `GEMMA_TYPE`. -void GenerateSingle( // NOLINT(misc-definitions-in-headers) - const ModelStore& model, const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, - size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, - TimingInfo& timing_info) { - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT) - (model, weights, runtime_config, prompt, pos, prefix_end, kv_cache, env, - timing_info); -} - -void GenerateBatch( // NOLINT(misc-definitions-in-headers) - const ModelStore& model, const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, - const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos, - const QueriesPos& queries_prefix_end, const KVCaches& kv_caches, - MatMulEnv* env, TimingInfo& timing_info) { - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT) - (model, weights, runtime_config, queries_prompt, queries_pos, - queries_prefix_end, kv_caches, env, timing_info); -} - -void GenerateImageTokens( // NOLINT(misc-definitions-in-headers) - const ModelStore& model, const ModelWeightsPtrs& weights, - const RuntimeConfig& runtime_config, const Image& image, - ImageTokens& image_tokens, MatMulEnv* env) { - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT) - (model, weights, runtime_config, image, image_tokens, env); -} - -#endif // HWY_ONCE - } // namespace gcpp HWY_AFTER_NAMESPACE(); diff --git a/gemma/gemma.cc b/gemma/gemma.cc index cea1559..36f97a9 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -13,21 +13,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Defines Gemma member functions; the actual implementations are in -// gemma-inl.h, included from instantiations/*.cc. +// Defines Gemma member functions which dynamic-dispatch into the SIMD +// implementations in gemma-inl.h. #include "gemma/gemma.h" +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "gemma/gemma-inl.h" + +#ifndef GEMMA_CC_ONCE +#define GEMMA_CC_ONCE + #include #include #include #include -#include #include // Placeholder for internal header, do not modify. -#include "compression/types.h" #include "gemma/configs.h" #include "gemma/model_store.h" #include "gemma/tokenizer.h" @@ -39,7 +51,13 @@ #include "util/threading_context.h" #include "hwy/base.h" +#endif // GEMMA_CC_ONCE + +#if HWY_ONCE namespace gcpp { +HWY_EXPORT(GenerateSingleT); +HWY_EXPORT(GenerateBatchT); +HWY_EXPORT(GenerateImageTokensT); // Internal init must run before I/O. This helper function takes care of that, // plus calling `SetArgs`. 
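With the per-type translation units gone, gemma.cc itself now follows the standard Highway multi-target pattern: the file re-includes itself via foreach_target.h so that gemma-inl.h is compiled once per SIMD target, the `*T` functions are registered with `HWY_EXPORT`, and the `Gemma` member functions select an implementation at runtime with `HWY_DYNAMIC_DISPATCH`. A minimal sketch of that pattern in a hypothetical standalone file (only the HWY_* macros and hn:: ops are real Highway API; file and function names are placeholders):

// mylib.cc — hypothetical standalone example of the multi-target pattern.
#include <cstddef>

// Compiles this file once per enabled SIMD target via "foreach_target.h".
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "mylib.cc"
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace mylib {
namespace HWY_NAMESPACE {  // expands to N_AVX2, N_AVX3, ... per target

// Per-target SIMD implementation (the role GenerateSingleT etc. play above).
float SumT(const float* HWY_RESTRICT v, size_t n) {
  namespace hn = hwy::HWY_NAMESPACE;
  const hn::ScalableTag<float> d;
  auto acc = hn::Zero(d);
  size_t i = 0;
  for (; i + hn::Lanes(d) <= n; i += hn::Lanes(d)) {
    acc = hn::Add(acc, hn::LoadU(d, v + i));
  }
  float sum = hn::GetLane(hn::SumOfLanes(d, acc));
  for (; i < n; ++i) sum += v[i];  // scalar remainder
  return sum;
}

}  // namespace HWY_NAMESPACE
}  // namespace mylib
HWY_AFTER_NAMESPACE();

#if HWY_ONCE  // compiled a single time, after all targets
namespace mylib {
HWY_EXPORT(SumT);  // per-target function-pointer table

float Sum(const float* v, size_t n) {
  // Selects the best available target for the current CPU at runtime.
  return HWY_DYNAMIC_DISPATCH(SumT)(v, n);
}
}  // namespace mylib
#endif  // HWY_ONCE

The deleted instantiations/*.cc files used the same foreach_target mechanism, but multiplied across the four weight types; after this change only gemma.cc carries that boilerplate.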
@@ -52,12 +70,13 @@ MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) { Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env) : env_(env), - reader_(new BlobReader(loader.weights)), - model_(*reader_, loader.tokenizer, loader.wrapping), - weights_(model_.Config().weight), + reader_(loader.weights), + model_(reader_, loader.tokenizer, loader.wrapping), + weights_(model_.Config()), chat_template_(model_.Tokenizer(), model_.Config().model) { - weights_.ReadFromBlobs(model_, *reader_, loader.map, env_.ctx.pools.Pool()); - reader_.reset(); + weights_.ReadFromBlobs(model_, reader_, loader.map, mat_owners_, + env.ctx.pools.Pool()); + reader_.CloseFile(); } Gemma::~Gemma() = default; @@ -70,42 +89,14 @@ void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const { writer, env_.ctx.pools.Pool(), weights_path); } -// There are >=3 types of the inference code. To reduce compile time, -// we shard them across multiple translation units in instantiations/*.cc. -// This declares the functions defined there. We use overloading because -// explicit instantiations are still too slow to compile. -// TODO: we want to move toward type-erasing, where we check the tensor type at -// each usage. Then we would have a single function, passing `WeightsOwner` -// instead of `WeightsPtrs`. -#define GEMMA_DECLARE(WEIGHT_TYPE) \ - extern void GenerateSingle( \ - const ModelStore& model, const ModelWeightsPtrs& weights, \ - const RuntimeConfig& runtime_config, const PromptTokens& prompt, \ - size_t pos, size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, \ - TimingInfo& timing_info); \ - extern void GenerateBatch( \ - const ModelStore& model, const ModelWeightsPtrs& weights, \ - const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \ - const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \ - const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \ - extern void GenerateImageTokens( \ - const ModelStore& model, const ModelWeightsPtrs& weights, \ - const RuntimeConfig& runtime_config, const Image& image, \ - ImageTokens& image_tokens, MatMulEnv* env); -GEMMA_DECLARE(float) -GEMMA_DECLARE(BF16) -GEMMA_DECLARE(NuqStream) -GEMMA_DECLARE(SfpStream) - void Gemma::Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, KVCache& kv_cache, TimingInfo& timing_info) const { env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); - weights_.CallT([&](auto& weights) { - GenerateSingle(model_, *weights, runtime_config, prompt, pos, prefix_end, - kv_cache, &env_, timing_info); - }); + HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_, + runtime_config, prompt, pos, prefix_end, + kv_cache, &env_, timing_info); env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } @@ -127,11 +118,9 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); - weights_.CallT([&](auto& weights) { - gcpp::GenerateBatch(model_, *weights, runtime_config, queries_prompt, - queries_pos, mutable_queries_prefix_end, kv_caches, - &env_, timing_info); - }); + HWY_DYNAMIC_DISPATCH(GenerateBatchT)( + model_.Config(), weights_, runtime_config, queries_prompt, queries_pos, + mutable_queries_prefix_end, kv_caches, &env_, timing_info); env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } @@ -141,12 +130,11 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config, ImageTokens& image_tokens) const { 
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); - weights_.CallT([&](auto& weights) { - gcpp::GenerateImageTokens(model_, *weights, runtime_config, image, - image_tokens, &env_); - }); + HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)( + model_.Config(), weights_, runtime_config, image, image_tokens, &env_); env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } } // namespace gcpp +#endif // HWY_ONCE diff --git a/gemma/gemma.h b/gemma/gemma.h index 908346f..66a8531 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -18,7 +18,7 @@ #include -#include +#include // IWYU pragma: begin_exports #include "gemma/activations.h" @@ -113,11 +113,9 @@ class Gemma { // TODO: rename to Config() const ModelConfig& GetModelConfig() const { return model_.Config(); } const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); } - const WeightsOwner& Weights() const { return weights_; } + const ModelWeightsPtrs& Weights() const { return weights_; } const GemmaChatTemplate& ChatTemplate() const { return chat_template_; } - // For tests. - WeightsOwner& MutableWeights() { return weights_; } void Save(const Path& weights_path, hwy::ThreadPool& pool) const; // `pos` is the position in the KV cache. Users are responsible for @@ -154,9 +152,10 @@ class Gemma { private: MatMulEnv& env_; - std::unique_ptr reader_; // null for second ctor + BlobReader reader_; ModelStore model_; - WeightsOwner weights_; + std::vector mat_owners_; + ModelWeightsPtrs weights_; GemmaChatTemplate chat_template_; }; diff --git a/gemma/instantiations/bf16.cc b/gemma/instantiations/bf16.cc deleted file mode 100644 index 2f001fb..0000000 --- a/gemma/instantiations/bf16.cc +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "gemma/instantiations/bf16.cc" - -#include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_TYPE hwy::bfloat16_t -#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/f32.cc b/gemma/instantiations/f32.cc deleted file mode 100644 index 6b5496d..0000000 --- a/gemma/instantiations/f32.cc +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "gemma/instantiations/f32.cc" -#include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_TYPE float -#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/nuq.cc b/gemma/instantiations/nuq.cc deleted file mode 100644 index 5e3ff4d..0000000 --- a/gemma/instantiations/nuq.cc +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "gemma/instantiations/nuq.cc" -#include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_TYPE NuqStream -#include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/sfp.cc b/gemma/instantiations/sfp.cc deleted file mode 100644 index 563d034..0000000 --- a/gemma/instantiations/sfp.cc +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "gemma/instantiations/sfp.cc" -#include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_TYPE SfpStream -#include "gemma/gemma-inl.h" diff --git a/gemma/model_store.cc b/gemma/model_store.cc index 27a3ef3..40d7c17 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -372,6 +372,11 @@ bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const { // `Compress()` output is always packed because it assumes a 1D array. HWY_ASSERT(mat.IsPacked()); // Update fields. Name already matched, otherwise we would not find it. + // For MatPtr tensors, the type will be `kUnknown`. If it was a `MatPtrT`, + // ensure the type set via code matches the file. 
+ HWY_ASSERT_M( + mat.GetType() == Type::kUnknown || mat.GetType() == file_mat->GetType(), + mat.Name()); mat.SetType(file_mat->GetType()); if (scales_.empty()) { // `file_mat->Scale()` is either read from file, or we have pre-2025 format diff --git a/gemma/tensor_info.cc b/gemma/tensor_info.cc index 4006a29..de93cf9 100644 --- a/gemma/tensor_info.cc +++ b/gemma/tensor_info.cc @@ -1,6 +1,7 @@ #include "gemma/tensor_info.h" #include +#include #include diff --git a/gemma/tensor_info_test.cc b/gemma/tensor_info_test.cc index b336a94..427ecb4 100644 --- a/gemma/tensor_info_test.cc +++ b/gemma/tensor_info_test.cc @@ -21,13 +21,14 @@ TEST(TensorInfoRegistryTest, Find) { config.Specifier().c_str()); const TensorInfoRegistry tensors(config); // Each tensor in the model should be known/found. - ModelWeightsPtrs weights(config); + ModelWeightsPtrs weights(config); weights.ForEachTensor(nullptr, nullptr, [&tensors](const TensorArgs& t) { const TensorInfo* info = tensors.Find(t.mat.Name()); HWY_ASSERT_M(info, t.mat.Name()); // Test that the `MatPtr` can be constructed from the TensorInfo, // and that the dimensions match. - MatPtrT mat_ptr(t.mat.Name(), tensors); + const MatPtr mat_ptr(t.mat.Name(), Type::kUnknown, + ExtentsFromInfo(tensors.Find(t.mat.Name()))); EXPECT_STREQ(t.mat.Name(), mat_ptr.Name()) << t.mat.Name(); EXPECT_EQ(t.mat.Rows(), mat_ptr.Rows()) << t.mat.Name(); EXPECT_EQ(t.mat.Cols(), mat_ptr.Cols()) << t.mat.Name(); diff --git a/gemma/weights.cc b/gemma/weights.cc index 9e926a0..ab2c679 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -20,7 +20,7 @@ #include #include -#include +#include // NOLINT #include #include @@ -42,15 +42,127 @@ namespace gcpp { -static void InitAttWeightsNUQ(const LayerConfig& layer_config, - MatPtrT& attn_vec_einsum_w, - MatPtrT& att_weights, - std::vector& mat_owners) { +// Copies att_weights from `attn_vec_einsum_w`. +void LayerWeightsPtrs::InitAttWeights(std::vector& mat_owners) { + // We only use this tensor for Gemma layers. + if (layer_config.type != LayerAttentionType::kGemma) return; + + // Files must have one or the other. + HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr()); + // Done if we already read the transposed tensor. + if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return; + + // NUQ is handled by a specialization in weights.cc. + HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ); + + const size_t model_dim = layer_config.model_dim; + const size_t heads = layer_config.heads; + const size_t qkv_dim = layer_config.qkv_dim; + + // Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim]. + att_weights.SetType(attn_vec_einsum_w.GetType()); + HWY_ASSERT(att_weights.Rows() == model_dim); + HWY_ASSERT(att_weights.Cols() == heads * qkv_dim); + HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim); + HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim); + + { + static std::mutex m; + std::lock_guard lock(m); + mat_owners.push_back(MatOwner()); + mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd); + } + + const size_t T_bytes = att_weights.ElementBytes(); + for (size_t m = 0; m < model_dim; ++m) { + uint8_t* HWY_RESTRICT out_row = att_weights.RowBytes(m); + for (size_t h = 0; h < heads; ++h) { + hwy::CopyBytes(attn_vec_einsum_w.RowBytes(h * model_dim + m), + out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes); + } + } + att_weights.SetScale(attn_vec_einsum_w.Scale()); +} + +// For FFN. Fast, only updates pointers. 
+void LayerWeightsPtrs::SplitW1() { + // Used for Gemma and Griffin layers; FFWVit uses different tensors. + if (layer_config.type == LayerAttentionType::kVit) return; + + // Files have both or neither of w1 and w2. + HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr()); + // w is mutually exclusive with w1 and w2 in the file. + HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr()); + // Done if we already read split tensors. Note that they are not + // necessarily the same type. + if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return; + + const size_t ff_hidden_dim = layer_config.ff_hidden_dim; + HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim); + HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim); + HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim); + // Cols are the model_dim but we don't have ModelConfig here. + HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols()); + HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols()); + + const size_t stride = gating_einsum_w.Stride(); + gating_einsum_w1.SetPtr(gating_einsum_w.RowBytes(0), stride); + gating_einsum_w2.SetPtr(gating_einsum_w.RowBytes(ff_hidden_dim), stride); + gating_einsum_w1.SetType(gating_einsum_w.GetType()); + gating_einsum_w2.SetType(gating_einsum_w.GetType()); + gating_einsum_w1.SetScale(gating_einsum_w.Scale()); + gating_einsum_w2.SetScale(gating_einsum_w.Scale()); + gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols()); +} + +// For attention, which might not have a w2. Fast, only updates pointers. +void LayerWeightsPtrs::SplitAttW1() { + // We only use this tensor for Gemma layers. + if (layer_config.type != LayerAttentionType::kGemma) return; + + // w is mutually exclusive with w1 in the file. + HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr()); + // Done if we already read split tensors. Note that w2 does not exist for + // MHA, and otherwise might not be the same type. + if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return; + + const size_t w1_rows = layer_config.heads * layer_config.qkv_dim; + const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim; + HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows); + HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows); + HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows); + // Cols are the model_dim but we don't have ModelConfig here. + HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols()); + HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols()); + + const size_t stride = qkv_einsum_w.Stride(); + qkv_einsum_w1.SetPtr(qkv_einsum_w.RowBytes(0), stride); + qkv_einsum_w2.SetPtr(qkv_einsum_w.RowBytes(w1_rows), stride); + qkv_einsum_w1.SetType(qkv_einsum_w.GetType()); + qkv_einsum_w2.SetType(qkv_einsum_w.GetType()); + qkv_einsum_w1.SetScale(qkv_einsum_w.Scale()); + qkv_einsum_w2.SetScale(qkv_einsum_w.Scale()); + qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols()); +} + +// Must be called after reading weights via `ForEachTensor`. +// TODO: exporters should bake this into the weights already. +// WARNING: called from multiple threads; `mat_owners` requires a lock. 
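
`SplitW1` and `SplitAttW1` above never copy data: the stacked tensor's storage is shared, and the split halves are just row-offset views with the same stride. A minimal standalone sketch of that pointer-only split, with a simplified `MatView` standing in for `MatPtr`:

#include <cstddef>

// Simplified row-major matrix view: pointer, logical size and row stride.
struct MatView {
  float* ptr = nullptr;
  size_t rows = 0, cols = 0;
  size_t stride = 0;  // elements between row starts; >= cols when padded
  float* Row(size_t r) const { return ptr + r * stride; }
};

// w1 becomes the first `rows1` rows of `w`, w2 the remainder. Both share
// w's storage and stride, exactly like the SetPtr calls in the diff.
void SplitStacked(const MatView& w, size_t rows1, MatView& w1, MatView& w2) {
  w1 = MatView{w.Row(0), rows1, w.cols, w.stride};
  w2 = MatView{w.Row(rows1), w.rows - rows1, w.cols, w.stride};
}
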
+void LayerWeightsPtrs::Fixup(std::vector& mat_owners) { + // TODO(janwas): handle NUQ + InitAttWeights(mat_owners); + SplitW1(); + SplitAttW1(); +} + +static void HWY_MAYBE_UNUSED InitAttWeightsNUQ( + const LayerConfig& layer_config, MatPtrT& attn_vec_einsum_w, + MatPtrT& att_weights, std::vector& mat_owners) { if (!attn_vec_einsum_w.HasPtr()) return; HWY_ASSERT(attn_vec_einsum_w.GetType() == Type::kNUQ); HWY_ASSERT(att_weights.HasPtr()); - HWY_ASSERT(att_weights.GetType() == Type::kNUQ); + att_weights.SetType(Type::kNUQ); const size_t model_dim = layer_config.model_dim; const size_t heads = layer_config.heads; @@ -85,14 +197,53 @@ static void InitAttWeightsNUQ(const LayerConfig& layer_config, att_weights.SetScale(attn_vec_einsum_w.Scale()); } -static void SplitW1NUQ(const LayerConfig& layer_config) { +static void HWY_MAYBE_UNUSED SplitW1NUQ(const LayerConfig& layer_config) { // TODO(janwas): implement. } -template <> -void LayerWeightsPtrs::Fixup(std::vector& mat_owners) { - InitAttWeightsNUQ(layer_config, attn_vec_einsum_w, att_weights, mat_owners); - SplitW1NUQ(layer_config); +// Zero-initializes only the allocated tensors in `*this`. +void ModelWeightsPtrs::ZeroInit() { + ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) { + if (!t.mat.HasPtr()) return; + gcpp::ZeroInit(t.mat); + }); +} + +// Copies only the allocated tensors in `*this` from tensors in `other`. +void ModelWeightsPtrs::CopyFrom(const ModelWeightsPtrs& other) { + ForEachTensor(const_cast(&other), nullptr, + [](const TensorArgs& t) { + if (!t.mat.HasPtr()) return; + HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr()); + CopyMat(*t.other_mat1, t.mat); + }); +} + +// For reshaping file tensors to the shape expected by the code. This would +// ideally already happen in the importer. Called by WeightsOwner::Fixup. +void ModelWeightsPtrs::Fixup(std::vector& mat_owners, + hwy::ThreadPool& pool) { + pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { + GetLayer(layer)->Fixup(mat_owners); + }); + + pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) { + VitLayer(layer)->Fixup(mat_owners); + }); +} + +std::vector ModelWeightsPtrs::AddTensorDataToWriter( + BlobWriter& writer) const { + std::vector serialized_mat_ptrs; + // ForEachTensor is non-const but the lambda does not modify *this. + const_cast(this)->ForEachTensor( + nullptr, nullptr, [&](const TensorArgs& t) { + if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return; + HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name()); + writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes()); + t.mat.AppendTo(serialized_mat_ptrs); + }); + return serialized_mat_ptrs; } struct TensorToRead { @@ -144,7 +295,7 @@ static Mode ChooseMode(uint64_t file_bytes, Tristate map) { return Mode::kRead; } -MapPtr MapFileOrNull(File& file, uint64_t file_bytes) { +static MapPtr MapFileOrNull(File& file, uint64_t file_bytes) { const Allocator& allocator = ThreadingContext::Get().allocator; if (file_bytes % allocator.BasePageBytes() == 0) { MapPtr mapped = file.Map(); @@ -175,8 +326,8 @@ static void MapAll(const std::vector& tensors, } } -std::vector MakeBatches(const std::vector& tensors, - const uint64_t file_bytes) { +static std::vector MakeBatches( + const std::vector& tensors, const uint64_t file_bytes) { PROFILER_ZONE("Startup.Weights.MakeBatches"); // Batches must be contiguous but blobs are padded, hence at least one // batch per tensor, and more when tensor rows exceed the batch size. 
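
The comment above gives the batching rule for reads: batches never span blobs (padding makes them non-contiguous), so there is at least one batch per tensor, and oversized tensors are split further. A standalone sketch over plain byte ranges; the struct names and size cap are illustrative, and the real `MakeBatches` also splits on tensor row boundaries:

#include <algorithm>
#include <cstdint>
#include <vector>

struct ByteRange {
  uint64_t offset;
  uint64_t bytes;
};

// One batch per blob, split further so no batch exceeds `max_batch_bytes`.
std::vector<ByteRange> MakeReadBatches(const std::vector<ByteRange>& blobs,
                                       uint64_t max_batch_bytes) {
  std::vector<ByteRange> batches;
  for (const ByteRange& blob : blobs) {
    uint64_t done = 0;
    while (done < blob.bytes) {
      const uint64_t chunk = std::min(max_batch_bytes, blob.bytes - done);
      batches.push_back(ByteRange{blob.offset + done, chunk});
      done += chunk;
    }
  }
  return batches;
}
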
@@ -254,72 +405,34 @@ static void MapOrReadAll(const std::vector& tensors, ReadBatches(reader, batches, pool); } -void WeightsOwner::ReadFromBlobs(const ModelStore& model, BlobReader& reader, - Tristate map, hwy::ThreadPool& pool) { +void ModelWeightsPtrs::ReadFromBlobs(const ModelStore& model, + BlobReader& reader, Tristate map, + std::vector& mat_owners, + hwy::ThreadPool& pool) { // List of tensors to read/map, and where from. std::vector tensors; - AllocatePointer(model.Config()); - // Enumerate all weights (negligible cost). - CallT([&](const auto& weights) { - weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { - const MatPadding padding = (t.flags & TensorArgs::kPacked) - ? MatPadding::kPacked - : MatPadding::kOdd; - size_t key_idx; - if (model.FindAndUpdateMatPtr(t.mat, key_idx)) { - tensors.push_back({.mat = &t.mat, - .range = reader.Range(key_idx), - .padding = padding}); - return; - } - if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found. - HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name()); - }); + ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { + const MatPadding padding = (t.flags & TensorArgs::kPacked) + ? MatPadding::kPacked + : MatPadding::kOdd; + size_t key_idx; + if (model.FindAndUpdateMatPtr(t.mat, key_idx)) { + tensors.push_back( + {.mat = &t.mat, .range = reader.Range(key_idx), .padding = padding}); + return; + } + if (t.flags & TensorArgs::kMaybeRead) return; // optional and not found. + HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name()); }); - MapOrReadAll(tensors, reader, map, mat_owners_, pool); + MapOrReadAll(tensors, reader, map, mat_owners, pool); - Fixup(pool); -} - -// Allocates `*_weights_`, but not yet the tensors inside. This is split out -// of `CallT` because that is const, hence it would pass a const& of the -// `unique_ptr` to its lambda, but we want to reset the pointer. 
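
The rewritten `ReadFromBlobs` above walks every tensor once: tensors found in the file are queued for reading or mapping, optional ones (`kMaybeRead`) may be skipped, and a missing required tensor is fatal. A standalone sketch of that collection pass, with a plain map standing in for the model store:

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct PendingRead {
  std::string name;
  size_t blob_index;  // where in the file this tensor's data lives
};

// `wanted` pairs a tensor name with an "optional" flag (kMaybeRead).
// `in_file` maps names present in the file to their blob index.
std::vector<PendingRead> CollectTensors(
    const std::vector<std::pair<std::string, bool>>& wanted,
    const std::unordered_map<std::string, size_t>& in_file) {
  std::vector<PendingRead> to_read;
  for (const auto& [name, optional] : wanted) {
    const auto it = in_file.find(name);
    if (it != in_file.end()) {
      to_read.push_back(PendingRead{name, it->second});
      continue;
    }
    if (optional) continue;  // optional and not found: fine.
    std::fprintf(stderr, "Tensor %s is required but not found in file.\n",
                 name.c_str());
    std::abort();
  }
  return to_read;
}
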
-void WeightsOwner::AllocatePointer(const ModelConfig& config) { - switch (weight_type_) { - case Type::kSFP: - sfp_weights_.reset(new ModelWeightsPtrs(config)); - break; - case Type::kNUQ: - nuq_weights_.reset(new ModelWeightsPtrs(config)); - break; - case Type::kBF16: - bf16_weights_.reset(new ModelWeightsPtrs(config)); - break; - default: - HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_)); + { + PROFILER_ZONE("Startup.Fixup"); + Fixup(mat_owners, pool); } } -void WeightsOwner::Fixup(hwy::ThreadPool& pool) { - PROFILER_ZONE("Startup.Fixup"); - CallT([&](const auto& weights) { weights->Fixup(mat_owners_, pool); }); -} - -std::vector WeightsOwner::AddTensorDataToWriter( - BlobWriter& writer) const { - std::vector serialized_mat_ptrs; - CallT([&](const auto& weights) { - weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { - if (t.flags & TensorArgs::kMaybeRead && !t.mat.HasPtr()) return; - HWY_ASSERT_M(t.mat.HasPtr(), t.mat.Name()); - writer.Add(t.mat.Name(), t.mat.Packed(), t.mat.PackedBytes()); - t.mat.AppendTo(serialized_mat_ptrs); - }); - }); - return serialized_mat_ptrs; -} - } // namespace gcpp diff --git a/gemma/weights.h b/gemma/weights.h index 3173cb2..ba20627 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -19,18 +19,14 @@ #include #include -#include -#include -#include // NOLINT #include #include -#include "compression/types.h" // IsF32 +#include "compression/types.h" #include "gemma/configs.h" // ModelConfig #include "gemma/model_store.h" // ModelStore #include "gemma/tensor_info.h" // TensorInfoRegistry #include "io/blob_store.h" // BlobWriter -#include "ops/matmul.h" // MatMulEnv #include "util/mat.h" // MatPtr #include "hwy/contrib/thread_pool/thread_pool.h" @@ -73,142 +69,141 @@ struct TensorArgs { TensorArgs(mat, other1 ? &other1->mat : nullptr, \ other2 ? &other2->mat : nullptr, TensorArgs::flag) -// Per-layer weight metadata and pointers. The tensor data is owned by -// `WeightsOwner`. Note that this class could be type-erased: member functions -// do not actually use the `Weight` template argument. See `WeightsPtrs`. -// `TensorInfoRegistry` (constructed from `ModelConfig`) is the source of truth -// for all tensor shapes. -template -struct LayerWeightsPtrs { - static inline std::string Concat(const char* base_name, - const std::string& suffix) { - return std::string(base_name) + suffix; +// Finds tensors by name in `TensorInfoRegistry` (constructed from +// `ModelConfig`) and constructs `MatPtr` metadata with those shapes. +class MatFinder { + public: + MatFinder(const std::string& suffix, const TensorInfoRegistry& tensors) + : suffix_(suffix), tensors_(tensors) {} + + // Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`. + MatPtr operator()(const std::string& base_name) const { + const std::string name = std::string(base_name) + suffix_; + return MatPtr(name.c_str(), Type::kUnknown, + ExtentsFromInfo(tensors_.Find(name))); } + private: + const std::string suffix_; + const TensorInfoRegistry& tensors_; +}; + +// Per-layer weight metadata and pointers. The tensor data is owned by +// `WeightsOwner`. +struct LayerWeightsPtrs { // Initializes tensor metadata without allocating. 
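
`MatFinder` above is just a name-composition helper: append the layer suffix, look the full name up in the registry, and return metadata with the registered shape. A standalone sketch of the same pattern, with a plain map standing in for `TensorInfoRegistry`:

#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>

struct Shape {
  size_t rows = 0, cols = 0;  // empty when the tensor is unknown
};

class ShapeFinder {
 public:
  ShapeFinder(std::string suffix,
              const std::unordered_map<std::string, Shape>& shapes)
      : suffix_(std::move(suffix)), shapes_(shapes) {}

  // "qkv_ein" + "_3" -> shape of the layer-3 tensor, or an empty shape.
  Shape operator()(const std::string& base_name) const {
    const auto it = shapes_.find(base_name + suffix_);
    return it == shapes_.end() ? Shape{} : it->second;
  }

 private:
  const std::string suffix_;
  const std::unordered_map<std::string, Shape>& shapes_;
};

// Usage: ShapeFinder finder("_3", shapes); Shape s = finder("qkv_ein");
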
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config, const TensorInfoRegistry& tensors) - : suffix_(LayerSuffix(layer_idx)), - qkv_einsum_w(Concat("qkv_ein", suffix_), tensors), - qkv_einsum_w1(Concat("qkv1_w", suffix_), tensors), - qkv_einsum_w2(Concat("qkv2_w", suffix_), tensors), - attention_output_biases(Concat("attn_ob", suffix_), tensors), - griffin( - {.linear_x_w = {Concat("gr_lin_x_w", suffix_), tensors}, - .linear_x_biases = {Concat("gr_lin_x_b", suffix_), tensors}, - .linear_y_w = {Concat("gr_lin_y_w", suffix_), tensors}, - .linear_y_biases = {Concat("gr_lin_y_b", suffix_), tensors}, - .linear_out_w = {Concat("gr_lin_out_w", suffix_), tensors}, - .linear_out_biases = {Concat("gr_lin_out_b", suffix_), tensors}, - .conv_w = {Concat("gr_conv_w", suffix_), tensors}, - .conv_biases = {Concat("gr_conv_b", suffix_), tensors}, - .gate_w = {Concat("gr_gate_w", suffix_), tensors}, - .gate_biases = {Concat("gr_gate_b", suffix_), tensors}, - .a = {Concat("gr_a", suffix_), tensors}}), + : finder_(LayerSuffix(layer_idx), tensors), + qkv_einsum_w(finder_("qkv_ein")), + qkv_einsum_w1(finder_("qkv1_w")), + qkv_einsum_w2(finder_("qkv2_w")), + attention_output_biases(finder_("attn_ob")), + griffin({.linear_x_w = finder_("gr_lin_x_w"), + .linear_x_biases = finder_("gr_lin_x_b"), + .linear_y_w = finder_("gr_lin_y_w"), + .linear_y_biases = finder_("gr_lin_y_b"), + .linear_out_w = finder_("gr_lin_out_w"), + .linear_out_biases = finder_("gr_lin_out_b"), + .conv_w = finder_("gr_conv_w"), + .conv_biases = finder_("gr_conv_b"), + .gate_w = finder_("gr_gate_w"), + .gate_biases = finder_("gr_gate_b"), + .a = finder_("gr_a")}), // MultiHeadDotProductAttention. - vit({.attn_out_w = {Concat("attn_out_w", suffix_), tensors}, - .attn_out_b = {Concat("attn_out_b", suffix_), tensors}, - .qkv_einsum_w = {Concat("qkv_ein_w", suffix_), tensors}, - .qkv_einsum_b = {Concat("qkv_ein_b", suffix_), tensors}, - .linear_0_w = {Concat("linear_0_w", suffix_), tensors}, - .linear_0_b = {Concat("linear_0_b", suffix_), tensors}, - .linear_1_w = {Concat("linear_1_w", suffix_), tensors}, - .linear_1_b = {Concat("linear_1_b", suffix_), tensors}, - .layer_norm_0_bias = {Concat("ln_0_bias", suffix_), tensors}, - .layer_norm_0_scale = {Concat("ln_0_scale", suffix_), tensors}, - .layer_norm_1_bias = {Concat("ln_1_bias", suffix_), tensors}, - .layer_norm_1_scale = {Concat("ln_1_scale", suffix_), tensors}}), - gating_einsum_w(Concat("gating_ein", suffix_), tensors), - gating_einsum_w1(Concat("gating1_w", suffix_), tensors), - gating_einsum_w2(Concat("gating2_w", suffix_), tensors), - linear_w(Concat("linear_w", suffix_), tensors), - pre_attention_norm_scale(Concat("pre_att_ns", suffix_), tensors), - pre_ffw_norm_scale(Concat("pre_ff_ns", suffix_), tensors), - post_attention_norm_scale(Concat("post_att_ns", suffix_), tensors), - post_ffw_norm_scale(Concat("post_ff_ns", suffix_), tensors), - ffw_gating_biases(Concat("ffw_gat_b", suffix_), tensors), - ffw_output_biases(Concat("ffw_out_b", suffix_), tensors), + vit({.attn_out_w = finder_("attn_out_w"), + .attn_out_b = finder_("attn_out_b"), + .qkv_einsum_w = finder_("qkv_ein_w"), + .qkv_einsum_b = finder_("qkv_ein_b"), + .linear_0_w = finder_("linear_0_w"), + .linear_0_b = finder_("linear_0_b"), + .linear_1_w = finder_("linear_1_w"), + .linear_1_b = finder_("linear_1_b"), + .layer_norm_0_bias = finder_("ln_0_bias"), + .layer_norm_0_scale = finder_("ln_0_scale"), + .layer_norm_1_bias = finder_("ln_1_bias"), + .layer_norm_1_scale = finder_("ln_1_scale")}), + 
gating_einsum_w(finder_("gating_ein")), + gating_einsum_w1(finder_("gating1_w")), + gating_einsum_w2(finder_("gating2_w")), + linear_w(finder_("linear_w")), + pre_attention_norm_scale(finder_("pre_att_ns")), + pre_ffw_norm_scale(finder_("pre_ff_ns")), + post_attention_norm_scale(finder_("post_att_ns")), + post_ffw_norm_scale(finder_("post_ff_ns")), + ffw_gating_biases(finder_("ffw_gat_b")), + ffw_output_biases(finder_("ffw_out_b")), - attn_vec_einsum_w(Concat("att_ein", suffix_), tensors), - att_weights(Concat("att_w", suffix_), tensors), + attn_vec_einsum_w(finder_("att_ein")), + att_weights(finder_("att_w")), - key_norm_scale(Concat("key_norm", suffix_), tensors), - query_norm_scale(Concat("query_norm", suffix_), tensors), + key_norm_scale(finder_("key_norm")), + query_norm_scale(finder_("query_norm")), layer_config(config) { } ~LayerWeightsPtrs() = default; - const std::string suffix_; - - // If weights are f32, also f32; otherwise at least bf16. Useful for ops that - // do not yet support smaller compressed types, or require at least bf16. When - // weights are f32, we also want such tensors to be f32. - // If weights are complex, this is also complex. - using WeightF32OrBF16 = - hwy::If>(), std::complex, - hwy::If(), double, - hwy::If(), float, BF16>>>; + const MatFinder finder_; // Files either have qkv_einsum_w with 2 stacked matrices or separate // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h. - MatPtrT qkv_einsum_w; - MatPtrT qkv_einsum_w1; - MatPtrT qkv_einsum_w2; + MatPtr qkv_einsum_w; + MatPtr qkv_einsum_w1; + MatPtr qkv_einsum_w2; MatPtrT attention_output_biases; struct { - MatPtrT linear_x_w; + MatPtr linear_x_w; MatPtrT linear_x_biases; - MatPtrT linear_y_w; + MatPtr linear_y_w; MatPtrT linear_y_biases; - MatPtrT linear_out_w; + MatPtr linear_out_w; MatPtrT linear_out_biases; MatPtrT conv_w; MatPtrT conv_biases; - MatPtrT gate_w; + MatPtr gate_w; MatPtrT gate_biases; MatPtrT a; } griffin; struct { // MultiHeadDotProductAttention. - MatPtrT attn_out_w; + MatPtr attn_out_w; // at least BF16. MatPtrT attn_out_b; - MatPtrT qkv_einsum_w; + MatPtr qkv_einsum_w; // at least BF16. MatPtrT qkv_einsum_b; // MlpBlock. - MatPtrT linear_0_w; + MatPtr linear_0_w; // at least BF16. MatPtrT linear_0_b; - MatPtrT linear_1_w; + MatPtr linear_1_w; // at least BF16. MatPtrT linear_1_b; // LayerNorm. - MatPtrT layer_norm_0_bias; - MatPtrT layer_norm_0_scale; - MatPtrT layer_norm_1_bias; - MatPtrT layer_norm_1_scale; + MatPtr layer_norm_0_bias; // at least BF16. + MatPtr layer_norm_0_scale; // at least BF16. + MatPtr layer_norm_1_bias; // at least BF16. + MatPtr layer_norm_1_scale; // at least BF16. } vit; // Files either have gating_einsum_w with 2 stacked matrices or separate - // w1/w2 tensors. Fixup ensures w1/w2 are ready for use by gemma-inl.h. - MatPtrT gating_einsum_w; - MatPtrT gating_einsum_w1; - MatPtrT gating_einsum_w2; - MatPtrT linear_w; - // > W8 is likely helpful. - MatPtrT pre_attention_norm_scale; - MatPtrT pre_ffw_norm_scale; - MatPtrT post_attention_norm_scale; - MatPtrT post_ffw_norm_scale; + // w1/w2 tensors. `Fixup` ensures w1/w2 are ready for use by gemma-inl.h. + MatPtr gating_einsum_w; + MatPtr gating_einsum_w1; + MatPtr gating_einsum_w2; + MatPtr linear_w; + MatPtr pre_attention_norm_scale; // at least BF16. + MatPtr pre_ffw_norm_scale; // at least BF16. + MatPtr post_attention_norm_scale; // at least BF16. + MatPtr post_ffw_norm_scale; // at least BF16. 
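
The recurring `// at least BF16.` annotations mean callers must be prepared for either F32 or BF16 storage at runtime. A standalone sketch of reading one element generically; the widening below is the standard BF16-to-F32 bit shift, not gemma.cpp's Highway implementation:

#include <cstddef>
#include <cstdint>
#include <cstring>

enum class ScalarType { kF32, kBF16 };

// Returns element `index` as float, regardless of the storage type.
float LoadAsFloat(const void* data, size_t index, ScalarType type) {
  float f;
  if (type == ScalarType::kF32) {
    std::memcpy(&f, static_cast<const float*>(data) + index, sizeof(f));
    return f;
  }
  // BF16 is the upper 16 bits of an IEEE-754 binary32; widen by shifting.
  uint16_t bits16;
  std::memcpy(&bits16, static_cast<const uint16_t*>(data) + index,
              sizeof(bits16));
  const uint32_t bits32 = static_cast<uint32_t>(bits16) << 16;
  std::memcpy(&f, &bits32, sizeof(f));
  return f;
}
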
MatPtrT ffw_gating_biases; MatPtrT ffw_output_biases; - MatPtrT attn_vec_einsum_w; // Use att_weights instead of this. - MatPtrT att_weights; // Use this instead of attn_vec_einsum_w. + MatPtr attn_vec_einsum_w; // Use att_weights instead of this. + MatPtr att_weights; // Use this instead of attn_vec_einsum_w. - MatPtrT key_norm_scale; - MatPtrT query_norm_scale; + MatPtr key_norm_scale; // at least BF16. + MatPtr query_norm_scale; // at least BF16. const LayerConfig& layer_config; @@ -217,8 +212,8 @@ struct LayerWeightsPtrs { // can also iterate over pairs or triples of tensors for `AdamUpdateMV`. // Public because also called by `WeightsPtrs`. template - void ForEachTensor(LayerWeightsPtrs* other1, - LayerWeightsPtrs* other2, Func func) { + void ForEachTensor(LayerWeightsPtrs* other1, LayerWeightsPtrs* other2, + Func func) { if (layer_config.type == LayerAttentionType::kVit) { // MHA. func(TENSOR_ARGS(vit.attn_out_w, kMustRead)); @@ -301,208 +296,98 @@ struct LayerWeightsPtrs { // Must be called after reading weights via `ForEachTensor`. // TODO: exporters should bake this into the weights already. // WARNING: called from multiple threads; `mat_owners` requires a lock. - void Fixup(std::vector& mat_owners) { - InitAttWeights(mat_owners); - SplitW1(); - SplitAttW1(); - } + void Fixup(std::vector& mat_owners); private: // Copies att_weights from `attn_vec_einsum_w`. - void InitAttWeights(std::vector& mat_owners) { - // We only use this tensor for Gemma layers. - if (layer_config.type != LayerAttentionType::kGemma) return; - - // Files must have one or the other. - HWY_ASSERT(attn_vec_einsum_w.HasPtr() ^ att_weights.HasPtr()); - // Done if we already read the transposed tensor. - if (att_weights.HasPtr() && !attn_vec_einsum_w.HasPtr()) return; - - // NUQ is handled by a specialization in weights.cc. - HWY_ASSERT(attn_vec_einsum_w.GetType() != Type::kNUQ); - - const size_t model_dim = layer_config.model_dim; - const size_t heads = layer_config.heads; - const size_t qkv_dim = layer_config.qkv_dim; - - // Reshape [heads, model_dim, qkv_dim] to [model_dim, heads * qkv_dim]. - HWY_ASSERT(att_weights.GetType() == attn_vec_einsum_w.GetType()); - HWY_ASSERT(att_weights.Rows() == model_dim); - HWY_ASSERT(att_weights.Cols() == heads * qkv_dim); - HWY_ASSERT(attn_vec_einsum_w.Rows() == heads * model_dim); - HWY_ASSERT(attn_vec_einsum_w.Cols() == qkv_dim); - - { - static std::mutex m; - std::lock_guard lock(m); - mat_owners.push_back(MatOwner()); - mat_owners.back().AllocateFor(att_weights, MatPadding::kOdd); - } - - const size_t T_bytes = att_weights.ElementBytes(); - for (size_t m = 0; m < model_dim; ++m) { - uint8_t* HWY_RESTRICT out_row = - reinterpret_cast(att_weights.Row(m)); - for (size_t h = 0; h < heads; ++h) { - hwy::CopyBytes(attn_vec_einsum_w.Row(h * model_dim + m), - out_row + h * qkv_dim * T_bytes, qkv_dim * T_bytes); - } - } - att_weights.SetScale(attn_vec_einsum_w.Scale()); - } + void InitAttWeights(std::vector& mat_owners); // For FFN. Fast, only updates pointers. - void SplitW1() { - // Used for Gemma and Griffin layers; FFWVit uses different tensors. - if (layer_config.type == LayerAttentionType::kVit) return; - - // Files have both or neither of w1 and w2. - HWY_ASSERT(gating_einsum_w1.HasPtr() == gating_einsum_w2.HasPtr()); - // w is mutually exclusive with w1 and w2 in the file. - HWY_ASSERT(gating_einsum_w.HasPtr() ^ gating_einsum_w1.HasPtr()); - // Done if we already read split tensors. Note that they are not - // necessarily the same type. 
- if (gating_einsum_w1.HasPtr() && !gating_einsum_w.HasPtr()) return; - - const size_t ff_hidden_dim = layer_config.ff_hidden_dim; - HWY_ASSERT(gating_einsum_w.Rows() == 2 * ff_hidden_dim); - HWY_ASSERT(gating_einsum_w1.Rows() == ff_hidden_dim); - HWY_ASSERT(gating_einsum_w2.Rows() == ff_hidden_dim); - // Cols are the model_dim but we don't have ModelConfig here. - HWY_ASSERT(gating_einsum_w1.Cols() == gating_einsum_w.Cols()); - HWY_ASSERT(gating_einsum_w2.Cols() == gating_einsum_w.Cols()); - - const size_t stride = gating_einsum_w.Stride(); - gating_einsum_w1.SetPtr(gating_einsum_w.Row(0), stride); - gating_einsum_w2.SetPtr(gating_einsum_w.Row(ff_hidden_dim), stride); - gating_einsum_w1.SetType(gating_einsum_w.GetType()); - gating_einsum_w2.SetType(gating_einsum_w.GetType()); - gating_einsum_w1.SetScale(gating_einsum_w.Scale()); - gating_einsum_w2.SetScale(gating_einsum_w.Scale()); - gating_einsum_w.SetPtr(nullptr, gating_einsum_w.Cols()); - } + void SplitW1(); // For attention, which might not have a w2. Fast, only updates pointers. - void SplitAttW1() { - // We only use this tensor for Gemma layers. - if (layer_config.type != LayerAttentionType::kGemma) return; - - // w is mutually exclusive with w1 in the file. - HWY_ASSERT(qkv_einsum_w.HasPtr() ^ qkv_einsum_w1.HasPtr()); - // Done if we already read split tensors. Note that w2 does not exist for - // MHA, and otherwise might not be the same type. - if (qkv_einsum_w1.HasPtr() && !qkv_einsum_w.HasPtr()) return; - - const size_t w1_rows = layer_config.heads * layer_config.qkv_dim; - const size_t w2_rows = layer_config.kv_heads * 2 * layer_config.qkv_dim; - - HWY_ASSERT(qkv_einsum_w.Rows() == w1_rows + w2_rows); - HWY_ASSERT(qkv_einsum_w1.Rows() == w1_rows); - HWY_ASSERT(qkv_einsum_w2.Rows() == w2_rows); - // Cols are the model_dim but we don't have ModelConfig here. - HWY_ASSERT(qkv_einsum_w1.Cols() == qkv_einsum_w.Cols()); - HWY_ASSERT(qkv_einsum_w2.Cols() == qkv_einsum_w.Cols()); - - const size_t stride = qkv_einsum_w.Stride(); - qkv_einsum_w1.SetPtr(qkv_einsum_w.Row(0), stride); - qkv_einsum_w2.SetPtr(qkv_einsum_w.Row(w1_rows), stride); - qkv_einsum_w1.SetType(qkv_einsum_w.GetType()); - qkv_einsum_w2.SetType(qkv_einsum_w.GetType()); - qkv_einsum_w1.SetScale(qkv_einsum_w.Scale()); - qkv_einsum_w2.SetScale(qkv_einsum_w.Scale()); - qkv_einsum_w.SetPtr(nullptr, qkv_einsum_w.Cols()); - } + void SplitAttW1(); }; // Holds layer-independent weight metadata and pointers plus per-layer -// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`. As with -// `LayerWeightsPtrs`, this class could be type-erased: member functions do not -// actually use the `Weight` template argument. The template does allow user -// code to dispatch only once. However, most tensors are large enough that -// dispatch at each usage would be feasible. +// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`. // TODO: move `gemma-inl.h` toward dispatch at each usage. // TODO: rename to WeightsPtrs. -template struct ModelWeightsPtrs { - using WeightT = Weight; - explicit ModelWeightsPtrs(const ModelConfig& config) - : tensors_(config), - // No suffix, these are per-model. 
- embedder_input_embedding("c_embedding", tensors_), - final_norm_scale("c_final_norm", tensors_), - vit_encoder_norm_bias("enc_norm_bias", tensors_), - vit_encoder_norm_scale("enc_norm_scale", tensors_), - vit_img_embedding_bias("img_emb_bias", tensors_), - vit_img_embedding_kernel("img_emb_kernel", tensors_), - vit_img_pos_embedding("img_pos_emb", tensors_), - vit_img_head_bias("img_head_bias", tensors_), - vit_img_head_kernel("img_head_kernel", tensors_), - mm_embed_norm("mm_embed_norm", tensors_), - weights_config(config) { - c_layers.reserve(config.layer_configs.size()); - for (size_t idx = 0; idx < config.layer_configs.size(); ++idx) { - const LayerConfig& layer_config = config.layer_configs[idx]; + : config_(config), + tensors_(config_), + finder_("", tensors_), // no suffix because these are per-model. + embedder_input_embedding(finder_("c_embedding")), + final_norm_scale(finder_("c_final_norm")), + vit_encoder_norm_bias(finder_("enc_norm_bias")), + vit_encoder_norm_scale(finder_("enc_norm_scale")), + vit_img_embedding_bias(finder_("img_emb_bias")), + vit_img_embedding_kernel(finder_("img_emb_kernel")), + vit_img_pos_embedding(finder_("img_pos_emb")), + vit_img_head_bias(finder_("img_head_bias")), + vit_img_head_kernel(finder_("img_head_kernel")), + mm_embed_norm(finder_("mm_embed_norm")), + c_layers() { + c_layers.reserve(config_.layer_configs.size()); + for (size_t idx = 0; idx < config_.layer_configs.size(); ++idx) { + const LayerConfig& layer_config = config_.layer_configs[idx]; c_layers.emplace_back(idx, layer_config, tensors_); } - for (size_t idx = 0; idx < config.vit_config.layer_configs.size(); ++idx) { - const LayerConfig& layer_config = config.vit_config.layer_configs[idx]; + for (size_t idx = 0; idx < config_.vit_config.layer_configs.size(); ++idx) { + const LayerConfig& layer_config = config_.vit_config.layer_configs[idx]; vit_layers.emplace_back(idx, layer_config, tensors_); } } ~ModelWeightsPtrs() = default; - // = F32 if weights are F32, else BF16. - using WeightF32OrBF16 = typename LayerWeightsPtrs::WeightF32OrBF16; - // Passed to all `MatPtrT` initializers, hence must be initialized first. + const ModelConfig& config_; + // Passed to finder_, hence must be initialized first. const TensorInfoRegistry tensors_; + const MatFinder finder_; // TODO: switch to SFP? - MatPtrT embedder_input_embedding; - MatPtrT final_norm_scale; + MatPtr embedder_input_embedding; + MatPtr final_norm_scale; // at least BF16. // Vit parts. - MatPtrT vit_encoder_norm_bias; - MatPtrT vit_encoder_norm_scale; + MatPtr vit_encoder_norm_bias; // at least BF16. + MatPtr vit_encoder_norm_scale; // at least BF16. MatPtrT vit_img_embedding_bias; - MatPtrT vit_img_embedding_kernel; - MatPtrT vit_img_pos_embedding; + MatPtr vit_img_embedding_kernel; // at least BF16. + MatPtr vit_img_pos_embedding; // F32? // The head maps from VitConfig::model_dim (Vit final layer) to // model_dim (LLM input). MatPtrT vit_img_head_bias; - MatPtrT vit_img_head_kernel; + MatPtr vit_img_head_kernel; // at least BF16. - MatPtrT mm_embed_norm; + MatPtr mm_embed_norm; // at least BF16. 
- const ModelConfig& weights_config; + std::vector c_layers; + std::vector vit_layers; - std::vector> c_layers; - std::vector> vit_layers; - - const LayerWeightsPtrs* GetLayer(size_t layer) const { + const LayerWeightsPtrs* GetLayer(size_t layer) const { return &c_layers[layer]; } - LayerWeightsPtrs* GetLayer(size_t layer) { return &c_layers[layer]; } - const LayerWeightsPtrs* VitLayer(size_t layer) const { - return &vit_layers[layer]; - } - LayerWeightsPtrs* VitLayer(size_t layer) { + LayerWeightsPtrs* GetLayer(size_t layer) { return &c_layers[layer]; } + const LayerWeightsPtrs* VitLayer(size_t layer) const { return &vit_layers[layer]; } + LayerWeightsPtrs* VitLayer(size_t layer) { return &vit_layers[layer]; } // Called via `CallT`. `other1` and `other2` are usually null, but can be // used to copy from another set of weights. Public because called by tests // and `WeightsOwner`. template - void ForEachTensor(ModelWeightsPtrs* other1, - ModelWeightsPtrs* other2, Func func) { - LayerWeightsPtrs* other_layer1 = nullptr; - LayerWeightsPtrs* other_layer2 = nullptr; + void ForEachTensor(ModelWeightsPtrs* other1, ModelWeightsPtrs* other2, + Func func) { + LayerWeightsPtrs* other_layer1 = nullptr; + LayerWeightsPtrs* other_layer2 = nullptr; func(TENSOR_ARGS(embedder_input_embedding, kMustRead)); func(TENSOR_ARGS(final_norm_scale, kMustRead)); - if (!weights_config.vit_config.layer_configs.empty()) { // Vit parts. + if (!config_.vit_config.layer_configs.empty()) { // Vit parts. func(TENSOR_ARGS(vit_encoder_norm_bias, kMustRead)); func(TENSOR_ARGS(vit_encoder_norm_scale, kMustRead)); func(TENSOR_ARGS(vit_img_embedding_bias, kMustRead)); @@ -511,7 +396,7 @@ struct ModelWeightsPtrs { func(TENSOR_ARGS(vit_img_head_bias, kMustRead)); func(TENSOR_ARGS(vit_img_head_kernel, kMustRead)); - if (weights_config.wrapping == PromptWrapping::GEMMA_VLM) { + if (config_.wrapping == PromptWrapping::GEMMA_VLM) { func(TENSOR_ARGS(mm_embed_norm, kMustRead)); } } @@ -522,8 +407,7 @@ struct ModelWeightsPtrs { GetLayer(layer_idx)->ForEachTensor(other_layer1, other_layer2, func); } - HWY_ASSERT(weights_config.vit_config.layer_configs.empty() == - vit_layers.empty()); + HWY_ASSERT(config_.vit_config.layer_configs.empty() == vit_layers.empty()); for (size_t layer_idx = 0; layer_idx < vit_layers.size(); ++layer_idx) { HWY_ASSERT(vit_layers[layer_idx].layer_config.type == LayerAttentionType::kVit); @@ -534,87 +418,24 @@ struct ModelWeightsPtrs { } // `ForEachTensor` // Zero-initializes only the allocated tensors in `*this`. - void ZeroInit() { - ForEachTensor(nullptr, nullptr, [](const TensorArgs& t) { - if (!t.mat.HasPtr()) return; - gcpp::ZeroInit(t.mat); - }); - } - + void ZeroInit(); // Copies only the allocated tensors in `*this` from tensors in `other`. - void CopyFrom(const ModelWeightsPtrs& other) { - ForEachTensor(const_cast*>(&other), nullptr, - [](const TensorArgs& t) { - if (!t.mat.HasPtr()) return; - HWY_ASSERT(t.other_mat1 && t.other_mat1->HasPtr()); - CopyMat(*t.other_mat1, t.mat); - }); - } - - // For reshaping file tensors to the shape expected by the code. This would - // ideally already happen in the importer. Must be called after reading and - // updating the attention weights. 
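
`ZeroInit`, `CopyFrom` and the weight readers above all funnel through the single `ForEachTensor` visitor, optionally pairing each tensor with the corresponding one in another weights object. A standalone sketch of that idiom with simplified types:

#include <cstddef>
#include <string>
#include <vector>

struct Tensor {
  std::string name;
  std::vector<float> data;
};

struct Weights {
  std::vector<Tensor> tensors;

  // Visits every tensor; if `other` is given, also passes the tensor at the
  // same position (the role of `other_mat1` in TensorArgs).
  template <class Func>
  void ForEachTensor(Weights* other, const Func& func) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      func(tensors[i], other ? &other->tensors[i] : nullptr);
    }
  }

  void ZeroInit() {
    ForEachTensor(nullptr, [](Tensor& t, Tensor*) {
      t.data.assign(t.data.size(), 0.0f);
    });
  }

  void CopyFrom(Weights& other) {
    ForEachTensor(&other, [](Tensor& t, Tensor* o) {
      if (o != nullptr) t.data = o->data;
    });
  }
};
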
- void Fixup(std::vector& mat_owners, hwy::ThreadPool& pool) { - pool.Run(0, c_layers.size(), [&](uint64_t layer, size_t /*thread*/) { - GetLayer(layer)->Fixup(mat_owners); - }); - - pool.Run(0, vit_layers.size(), [&](uint64_t layer, size_t /*thread*/) { - VitLayer(layer)->Fixup(mat_owners); - }); - } -}; // `WeightsPtrs` -#undef TENSOR_ARGS - -// Type-erased facade for `WeightsPtrs`, stored inside the non-template -// `Gemma`. Also owns the underlying memory. -class WeightsOwner { - public: - // `weight_type` is obtained from `ModelConfig` in `ModelStore`. - WeightsOwner(Type weight_type) : weight_type_(weight_type) {} + void CopyFrom(const ModelWeightsPtrs& other); // Reads tensor data from `BlobStore` or aborts on error. `map` is a user // override for whether to map blobs or read them. void ReadFromBlobs(const ModelStore& model, BlobReader& reader, Tristate map, - hwy::ThreadPool& pool); - - // Calls `func(std::unique_ptr>&, args)`. `func` typically - // calls `ForEachTensor`. - template - decltype(auto) CallT(const Func& func, TArgs&&... args) const { - if (HWY_LIKELY(weight_type_ == Type::kSFP)) { - return func(sfp_weights_, std::forward(args)...); - } else if (weight_type_ == Type::kNUQ) { - return func(nuq_weights_, std::forward(args)...); - } else if (weight_type_ == Type::kBF16) { - return func(bf16_weights_, std::forward(args)...); - } - return HWY_ABORT("Unsupported weight type %s.", TypeName(weight_type_)); - } - - // For writers: + std::vector& mat_owners, hwy::ThreadPool& pool); // Adds one blob for each tensor's data and returns all serialized MatPtr. std::vector AddTensorDataToWriter(BlobWriter& writer) const; private: - Type weight_type_; - - // Allocates `*_weights_`, but not yet the tensors inside. This is split out - // of `CallT` so that can be const. - void AllocatePointer(const ModelConfig& config); - - // Called by `ReadFromBlobs`. - void Fixup(hwy::ThreadPool& pool); - - // Only one is non-null, determined by `weight_type_`. - std::unique_ptr> bf16_weights_; - std::unique_ptr> sfp_weights_; - std::unique_ptr> nuq_weights_; - - // Owns the memory referenced by all `MatPtr`. - std::vector mat_owners_; -}; + // For reshaping file tensors to the shape expected by the code. This would + // ideally already happen in the importer. Called by ReadFromBlobs. + void Fixup(std::vector& mat_owners, hwy::ThreadPool& pool); +}; // `ModelWeightsPtrs` +#undef TENSOR_ARGS } // namespace gcpp diff --git a/io/blob_store.h b/io/blob_store.h index 19c6639..f7a4a6d 100644 --- a/io/blob_store.h +++ b/io/blob_store.h @@ -59,6 +59,8 @@ class BlobReader { const File& file() const { return *file_; } uint64_t file_bytes() const { return file_bytes_; } + void CloseFile() { file_.reset(); } + const std::vector& Keys() const { return keys_; } const BlobRange& Range(size_t key_idx) const { @@ -99,7 +101,7 @@ class BlobReader { } private: - const std::unique_ptr file_; + std::unique_ptr file_; const uint64_t file_bytes_; std::vector keys_; diff --git a/util/mat.h b/util/mat.h index 5b15df9..f350ce7 100644 --- a/util/mat.h +++ b/util/mat.h @@ -131,6 +131,13 @@ class MatPtr : public IFields { Type GetType() const { return type_; } void SetType(Type type) { type_ = type; + if (type == Type::kUnknown) { + // Temporary invalid state. Happens during weights.h construction, before + // the ForEachTensor that loads them and sets the type. 
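
With `WeightsOwner`'s per-model `CallT` removed above, dispatch now happens per tensor at the call site: switch once on the stored `Type` and hand a concretely typed view to a generic callable. A standalone sketch of that shape of dispatch, with stand-in types rather than `CallUpcasted` itself:

#include <cstdint>
#include <cstdio>
#include <cstdlib>

enum class Type { kF32, kBF16, kSFP };
struct BF16 { uint16_t bits; };
struct SfpStream { uint8_t byte; };

struct ErasedMat {
  Type type;
  const void* data;
};

// Invokes `func` with a pointer of the concrete element type.
template <class Func>
decltype(auto) DispatchByType(const ErasedMat& mat, const Func& func) {
  switch (mat.type) {
    case Type::kF32:  return func(static_cast<const float*>(mat.data));
    case Type::kBF16: return func(static_cast<const BF16*>(mat.data));
    case Type::kSFP:  return func(static_cast<const SfpStream*>(mat.data));
  }
  std::abort();  // unreachable
}

int main() {
  const float values[2] = {1.0f, 2.0f};
  const ErasedMat mat{Type::kF32, values};
  DispatchByType(mat, [](const auto* typed) {
    std::printf("element size: %zu bytes\n", sizeof(*typed));
  });
  return 0;
}
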
+ element_bytes_ = 0; + num_elements_ = 0; + return; + } element_bytes_ = static_cast(hwy::DivCeil(TypeBits(type), 8)); num_elements_ = static_cast(ComputeNumElements(type, Extents())); HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16); @@ -244,15 +251,17 @@ class MatPtrT : public MatPtr { // Called by `MatStorageT`. MatPtrT(const char* name, Extents2D extents) : MatPtr(name, TypeEnum(), extents) {} - // Retrieves shape by name via `TensorInfo` from `TensorInfoRegistry`. This is - // not a factory function because `weights.h` initializes members of type - // `MatPtrT`, and `T` cannot be inferred at compile time from arguments. - MatPtrT(const std::string& name, const TensorInfoRegistry& info) - : MatPtrT(name.c_str(), ExtentsFromInfo(info.Find(name))) {} // Copying allowed because the metadata is small. MatPtrT(const MatPtr& other) : MatPtr(other) { - HWY_ASSERT(other.GetType() == TypeEnum()); + // Happens in weights.h when constructing via MatFinder, which does not + // know the type. Setting the type here avoids having to keep the + // initializer list and member type in sync. + if (GetType() == Type::kUnknown) { + SetType(TypeEnum()); + } else { + HWY_ASSERT(other.GetType() == TypeEnum()); + } } MatPtrT& operator=(const MatPtr& other) { MatPtr::operator=(other); @@ -289,27 +298,27 @@ class MatPtrT : public MatPtr { } }; -// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT`, plus the -// optional `args`. This supports all types used as weights. +// Calls `func` with `MatPtrT*` plus the optional `args`. This supports all +// types used as weights. template decltype(auto) CallUpcasted(const MatPtr* base, const Func& func, Args&&... args) { #if GEMMA_ENABLE_NUQ if (base->GetType() == Type::kNUQ) { - return func(dynamic_cast*>(base), - std::forward(args)...); + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); } #endif // GEMMA_ENABLE_NUQ if (base->GetType() == Type::kF32) { - return func(dynamic_cast*>(base), - std::forward(args)...); + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); } else if (base->GetType() == Type::kBF16) { - return func(dynamic_cast*>(base), - std::forward(args)...); + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); } else if (base->GetType() == Type::kSFP) { - return func(dynamic_cast*>(base), - std::forward(args)...); + const MatPtrT mat(*base); + return func(&mat, std::forward(args)...); } else { HWY_ABORT("Unhandled type %s.", TypeName(base->GetType())); } @@ -323,24 +332,24 @@ decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2, #if GEMMA_ENABLE_NUQ if (base1->GetType() == Type::kNUQ) { - return func(dynamic_cast*>(base1), - dynamic_cast*>(base2), - std::forward(args)...); + const MatPtrT mat1(*base1); + const MatPtrT mat2(*base2); + return func(&mat1, &mat2, std::forward(args)...); } #endif // GEMMA_ENABLE_NUQ if (base1->GetType() == Type::kF32) { - return func(dynamic_cast*>(base1), - dynamic_cast*>(base2), - std::forward(args)...); + const MatPtrT mat1(*base1); + const MatPtrT mat2(*base2); + return func(&mat1, &mat2, std::forward(args)...); } else if (base1->GetType() == Type::kBF16) { - return func(dynamic_cast*>(base1), - dynamic_cast*>(base2), - std::forward(args)...); + const MatPtrT mat1(*base1); + const MatPtrT mat2(*base2); + return func(&mat1, &mat2, std::forward(args)...); } else if (base1->GetType() == Type::kSFP) { - return func(dynamic_cast*>(base1), - dynamic_cast*>(base2), - std::forward(args)...); + const MatPtrT mat1(*base1); + const MatPtrT 
mat2(*base2);
+    return func(&mat1, &mat2, std::forward<Args>(args)...);
   } else {
     HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
   }
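
A usage-flavored sketch of the `CallUpcastedSame` contract shown in this hunk: both type-erased operands must store the same element type, and the callable then sees both with that concrete type. Stand-in types only; the real templates live in util/mat.h:

#include <cassert>
#include <cstdint>
#include <cstdio>

enum class Type { kF32, kBF16 };
struct BF16 { uint16_t bits; };

struct ErasedMat {
  Type type;
  const void* data;
};

template <class Func>
decltype(auto) DispatchSame(const ErasedMat& a, const ErasedMat& b,
                            const Func& func) {
  assert(a.type == b.type);  // the "Same" precondition
  if (a.type == Type::kF32) {
    return func(static_cast<const float*>(a.data),
                static_cast<const float*>(b.data));
  }
  return func(static_cast<const BF16*>(a.data),
              static_cast<const BF16*>(b.data));
}

int main() {
  const float x[1] = {3.0f}, y[1] = {4.0f};
  DispatchSame(ErasedMat{Type::kF32, x}, ErasedMat{Type::kF32, y},
               [](const auto* a, const auto* b) {
                 std::printf("element bytes: %zu and %zu\n", sizeof(*a),
                             sizeof(*b));
               });
  return 0;
}
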