// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Defines Gemma member functions which dynamic-dispatch into the SIMD // implementations in gemma-inl.h. #include "gemma/gemma.h" #include "compression/types.h" // GEMMA_DISABLED_TARGETS #include "util/zones.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. // clang-format off #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" // After highway.h #include "gemma/attention.h" // includes highway.h #include "gemma/gemma-inl.h" #include "gemma/vit.h" // includes highway.h #ifndef GEMMA_CC_ONCE #define GEMMA_CC_ONCE #include // sqrtf #include #include #include #include #include #include "gemma/configs.h" #include "gemma/model_store.h" #include "gemma/weights.h" #include "io/blob_store.h" #include "io/io.h" // Path #include "ops/matmul.h" #include "paligemma/image.h" #include "util/basics.h" #include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" #include "hwy/timer.h" // Require opt-in to debug/introspection functions to eliminate their overhead. HWY_INLINE_VAR constexpr bool kObserver = false; #endif // GEMMA_CC_ONCE HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { void Attention(LayerAttentionType type, const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, QBatch& qbatch, MatMulEnv& env) { if (type == LayerAttentionType::kGemma) { // TODO: remove flag to enable FlashAttention. GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, env, HWY_NATIVE_DOT_BF16 ? kAttentionUseOld : 0); } } static HWY_NOINLINE void TransformerLayer(const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, QBatch& qbatch, MatMulEnv& env) { const LayerConfig& layer_config = layer.layer_config; RMSNormBatched(activations.x, layer.pre_attention_norm_scale, activations.attention.pre_att_rms_out, env.ctx); Attention(layer_config.type, num_tokens, layer_idx, layer, activations, qbatch, env); PostNorm(layer_config.post_norm, layer.post_attention_norm_scale, activations.attention.att_sums, env.ctx); ResidualConnection(activations.attention.att_sums, activations.x, layer, /*is_attention=*/true, env.ctx); RMSNormBatched(activations.x, layer.pre_ffw_norm_scale, activations.pre_ffw_rms_out, env.ctx); if (layer_config.type == LayerAttentionType::kVit) { FFWVit(layer, activations, env); } else { FFWNoVit(layer, activations, env); } PostNorm(layer_config.post_norm, layer.post_ffw_norm_scale, activations.ffw_out, env.ctx); ResidualConnection(activations.ffw_out, activations.x, layer, /*is_attention=*/false, env.ctx); } // Returns the scale value to use for the embedding (basically sqrt model_dim). static float EmbeddingScaling(size_t model_dim) { // Round to bf16 to match Gemma's Embedder, which casts before mul. return hwy::ConvertScalarTo( hwy::ConvertScalarTo(sqrtf(static_cast(model_dim)))); } // `x_row` indicates which row of `x` to write to. // `pos` is the *token*'s position for `AddAbsolutePositionalEmbeddings`, not // the start of the batch, because this is called for batches of tokens in // prefill, but batches of queries in decode. // // For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3 // spec) until we run out of image tokens. This allows for a multi-image prompt // if -2 locations with appropriate begin/end image tokens are created by the // calling application. // Returns new image_token_position. static HWY_NOINLINE size_t EmbedMMToken(int token, size_t x_row, size_t pos, size_t pos_in_prompt, const ModelConfig& model_config, const WeightsPtrs& weights, MatStorageT& x, ThreadingContext& ctx, const ImageTokens* image_tokens = nullptr, size_t image_token_position = 0) { GCPP_ZONE(ctx, hwy::Profiler::GlobalIdx(), Zones::kGenEmbed); // Image tokens just need to be copied. if (model_config.wrapping == PromptWrapping::GEMMA_VLM && image_tokens != nullptr && token == -2 && image_token_position < image_tokens->Rows()) { hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(x_row), x.Cols() * x.ElementBytes()); return image_token_position + 1; } if (model_config.wrapping == PromptWrapping::PALIGEMMA && image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) { hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(x_row), x.Cols() * x.ElementBytes()); return image_token_position; } const size_t model_dim = model_config.model_dim; const float emb_scaling = EmbeddingScaling(model_dim); HWY_DASSERT(token >= 0); HWY_DASSERT(token < static_cast(model_config.vocab_size)); CallUpcasted(&weights.embedder_input_embedding, [&](const auto* weights_t) { // Using `Stride` to compute the offset works for both NUQ (because we use // an offset and NUQ is never padded) and padded, because non-NUQ types are // seekable, hence the offset can also skip any padding. const size_t embedding_ofs = token * weights_t->Stride(); HWY_ASSERT(weights_t->Cols() == model_dim); const auto embedding_span = MakeSpan(weights_t->Row(0), embedding_ofs + model_dim); const hn::ScalableTag df; DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(x_row), model_dim); MulByConst(emb_scaling * weights_t->Scale(), x.Row(x_row), model_dim); }); if (model_config.absolute_pe) { AddAbsolutePositionalEmbeddings(x.Row(x_row), model_dim, pos); } return image_token_position; } // Populates KV cache for batches of tokens from one query at a time. This is // called if prompts are longer than the query batch size, and also in // prefix-LM mode (end > 0), which must see all tokens in one batch. static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config, const RuntimeConfig& runtime_config, const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env, hwy::BitSet4096<>& non_eos) { PROFILER_ZONE("Gen.PrefillT"); // Batches are important for amortizing loading weights over multiple tokens. // This is possible in prefill because we know all tokens beforehand, whereas // decode depends on the previous output token. However, each prefill batch of // a query requires that preceding batches already wrote to the KV cache, // hence we sequentially loop over token batches. We can reduce the number of // iterations by increasing the batch size, but this also increases arithmetic // intensity, and so we are eventually compute-limited. TransformerLayer uses // all available threads, so we do not also parallelize over queries, but note // that PrefillQBatch uses queries as the batch dimension. const size_t max_tbatch_size = runtime_config.prefill_tbatch_size; // For each query. `qi` is within the batch, not the global query index. for (size_t qi = 0; qi < qbatch.Size(); ++qi) { non_eos.Set(qi); // One query at a time, batching will be the query's prompt tokens. QBatch qbatch_1 = qbatch.Single(qi); const size_t prompt_size = qbatch_1.Prompt(0).size(); // In autoregressive mode, we don't need to prefill the last token, so - 1. size_t prefill_this_query = prompt_size - 1; const size_t prefix_end_this_query = qbatch_1.PrefixEnd(0); // We can't attend beyond the prompt_size. HWY_ASSERT(prefix_end_this_query <= prompt_size); // Special case: if the prefix includes the last token, we need to prefill // the last token, too. However, we need to rewind this for the generation // of the first token. So we need to keep track of this. // TODO: consider implementing masking instead of this logic? const bool attend_to_last_token = (prefill_this_query < prefix_end_this_query); if (attend_to_last_token) { // The difference can be at most 1. prefill_this_query += 1; HWY_ASSERT(prefill_this_query == prefix_end_this_query); } // In prefix-LM mode, we need to look at all the tokens for the prefix in // one iteration through the layers, so we need a large enough batch size. HWY_ASSERT(prefix_end_this_query == 0 || max_tbatch_size >= prefill_this_query); // For each batch of tokens in the query: for (size_t tbatch_start = 0; tbatch_start < prefill_this_query; tbatch_start += max_tbatch_size) { const size_t tbatch_size = HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start); activations.SetBatchSize(tbatch_size); // Fill activations.x (much faster than TransformerLayer). size_t image_token_position = 0; for (size_t ti = 0; ti < tbatch_size; ++ti) { const size_t pos = qbatch_1.Pos(0) + ti; const size_t pos_in_prompt = tbatch_start + ti; HWY_DASSERT(pos_in_prompt < prompt_size); const int token = qbatch_1.Prompt(0)[pos_in_prompt]; image_token_position = EmbedMMToken( token, ti, pos, pos_in_prompt, config, weights, activations.x, env.ctx, runtime_config.image_tokens, image_token_position); // NOTE: we unconditionally call StreamToken, even if EOS. if (pos_in_prompt < prompt_size - 1) { runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f); } else { // The last token will be streamed later and we should only get here // if we need to attend to the last token because it is in the prefix. HWY_ASSERT(attend_to_last_token); } } // Transformer with one batch of tokens from a single query. No need to // set `PrevToken` because we already did the embedding above. for (size_t layer_idx = 0; layer_idx < config.layer_configs.size(); ++layer_idx) { TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx), activations, qbatch_1, env); } qbatch_1.MutablePos(0) += tbatch_size; } // for tbatch_start if (attend_to_last_token) { // We need to rewind the position for the last token that we only // attended to to make sure the prefix LM sees everything. // This means we duplicate work on the last prompt token in autoregressive // decoding. Alternatives: (1) real masking; (2) always prefill the last // token and only generate the next one from the already prefilled // activations. qbatch_1.MutablePos(0) -= 1; } } } static void MaybeObserve(const RuntimeConfig& runtime_config, Activations& activations, QBatch& qbatch, int layer_idx) { if constexpr (kObserver) { if (HWY_UNLIKELY(runtime_config.activations_observer)) { runtime_config.activations_observer( QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx, activations); } } } // Embeds PrevToken (one from each query) and calls each TransformerLayer. // Called by query-batched `PrefillQBatch` and `GenerateT`, but not the // token-batched `PrefillTBatch`, which supports image embedding. static HWY_NOINLINE void Transformer(const ModelConfig& config, const RuntimeConfig& runtime_config, const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env) { if constexpr (kObserver) { if (HWY_UNLIKELY(runtime_config.layers_output)) { for (size_t qi = 0; qi < qbatch.Size(); ++qi) { const float token_f = qbatch.PrevToken(qi); runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi), "tokens", -1, &token_f, 1); } } } // TODO: parallelize? for (size_t qi = 0; qi < qbatch.Size(); ++qi) { EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi), /*pos_in_prompt=*/0, config, weights, activations.x, env.ctx); } for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) { TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx), activations, qbatch, env); MaybeObserve(runtime_config, activations, qbatch, layer_idx); } } // Populates KV cache for the batch queries, one token at a time. static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size, const ModelConfig& config, const RuntimeConfig& runtime_config, const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env, hwy::BitSet4096<>& non_eos) { PROFILER_ZONE("Gen.PrefillQ"); for (size_t qi = 0; qi < qbatch.Size(); ++qi) { non_eos.Set(qi); // Should only be called for autoregressive (non-prefix-LM) prefill. HWY_DASSERT(qbatch.PrefixEnd(qi) == 0); } // In autoregressive mode, we don't prefill the last token, hence - 1. for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1; ++pos_in_prompt) { for (size_t qi = 0; qi < qbatch.Size(); ++qi) { int token = config.eos_id; if (pos_in_prompt < qbatch.Prompt(qi).size() - 1) { token = qbatch.Prompt(qi)[pos_in_prompt]; // Ignore StreamToken return value because requesting to stop does not // make sense during prefill. (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), pos_in_prompt, token, 0.0f); qbatch.MutablePos(qi) = pos_in_prompt; } qbatch.PrevToken(qi) = token; } // The input (PrevToken) is one token from each query in the batch. // Do not call `SampleAndStream` because it computes logits for token // probabilities, which are not required for the prompt tokens. Transformer(config, runtime_config, weights, activations, qbatch, env); } for (size_t qi = 0; qi < qbatch.Size(); ++qi) { qbatch.MutablePos(qi) = qbatch.Prompt(qi).size() - 1; } } // Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent // `Transformer`, and increments `MutablePos`. Also updates `non_eos` if the // query is at the end of its sequence. static void StreamAndUpdateEOS(const size_t qi, size_t pos, int token, const float prob, const ModelConfig& config, const RuntimeConfig& runtime_config, QBatch& qbatch, bool update_pos, hwy::BitSet4096<>& non_eos) { HWY_DASSERT(non_eos.Get(qi)); // otherwise, should not be called. if (HWY_UNLIKELY( !runtime_config.StreamToken(qbatch.QueryIdx(qi), pos, token, prob))) { // User decided to stop: set token to primary EOS to trigger IsEOS below. token = config.eos_id; HWY_DASSERT(config.IsEOS(token)); } qbatch.PrevToken(qi) = token; qbatch.MutablePos(qi) += update_pos ? 1 : 0; // Primary or secondary EOS: mark query as EOS, but still increment (for // multi-turn, we should still keep the prior EOS). if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi); } // Must be called after Transformer: either after prefill, or during decode. // Computes logits, samples and streams the token. static void SampleAndStream(const ModelConfig& config, const RuntimeConfig& runtime_config, const WeightsPtrs& weights, const SampleFunc& sample_token, Activations& activations, QBatch& qbatch, MatMulEnv& env, hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) { HWY_DASSERT(qbatch.Size() == activations.x.Rows()); RMSNormBatched(activations.x, weights.final_norm_scale, activations.x_bf, env.ctx); MaybeObserve(runtime_config, activations, qbatch, -1); { GCPP_ZONE(env.ctx, /*worker=*/0, Zones::kGenEmbeddingMatmul); // Compute logits from last layer activations. CallMatMul(activations.x_bf, weights.embedder_input_embedding, /*add=*/nullptr, env, activations.logits); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); MaybeLogitsSoftCapBatched(config.final_cap, activations.logits, non_eos, env.ctx); timing_info.NotifyGenerated(non_eos.Count()); ParallelFor( ParallelismStrategy::kFlat, qbatch.Size(), env.ctx, /*cluster_idx=*/0, Callers::kSampleAndStream, [&](size_t qi, size_t worker) { if (!non_eos.Get(qi)) return; // We streamed all prefill tokens, but pos is still one behind // because we started generation at pos = prompt.size() - 1. // We want the pos argument to match the number of calls to // `StreamToken`, as expected by the caller. const size_t pos = qbatch.Pos(qi) + 1; const TokenAndProb tp = sample_token(qi, pos, activations.logits.RowSpan(qi), worker); // `sampled` is padded, which prevents false sharing. activations.sampled.Row(qi)[0] = static_cast(pos); activations.sampled.Row(qi)[1] = static_cast(tp.token); activations.sampled.Row(qi)[2] = hwy::BitCastScalar(tp.prob); }); // Sequentially, because `StreamToken` is not yet thread-safe. non_eos.Foreach([&](size_t qi) { const size_t pos = activations.sampled.Row(qi)[0]; const int token = static_cast(activations.sampled.Row(qi)[1]); const float prob = hwy::BitCastScalar(activations.sampled.Row(qi)[2]); StreamAndUpdateEOS(qi, pos, token, prob, config, runtime_config, qbatch, /*update_pos=*/true, non_eos); }); } static HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config, const AesCtrEngine& engine, ThreadingContext& ctx) { // If user provided a sample_func, use it. if (runtime_config.sample_func) return runtime_config.sample_func; // Fast path for top-1 with no accept_token. if (runtime_config.top_k == 1 && !runtime_config.accept_token) { return [&](size_t /*qi*/, size_t /*pos*/, Logits logits, size_t worker) HWY_ATTR -> TokenAndProb { GCPP_ZONE(ctx, worker, Zones::kGenSampleTop1); return Top1OfSoftmax(logits); }; } // General case: Softmax with top-k sampling. return [&](size_t qi, size_t pos, Logits logits, size_t worker) HWY_ATTR -> TokenAndProb { GCPP_ZONE(ctx, worker, Zones::kGenSampleTopK); // We want a different sequence for each batch element and position. const uint64_t stream = (static_cast(qi) << 32) | pos; RngStream gen(engine, stream); return FusedSoftmaxAndSampleTopK(logits, runtime_config.top_k, gen, runtime_config.temperature, runtime_config.accept_token, ctx, worker); }; } // Decode: generates one continuation token for each query in `qbatch`. static void GenerateT(const ModelConfig& config, const RuntimeConfig& runtime_config, const AesCtrEngine& engine, const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) { size_t max_prompt_size = 0; bool all_prefix_end_are_zero = true; size_t total_prefill_tokens = 0; // only for throughput stats. const size_t seq_len = qbatch.KV(0).SeqLen(); for (size_t qi = 0; qi < qbatch.Size(); ++qi) { const PromptTokens& prompt = qbatch.Prompt(qi); // Sanity check: prompts should not be empty. Note that multi-turn prompts // start with . HWY_ASSERT(prompt.size() != 0); max_prompt_size = HWY_MAX(max_prompt_size, prompt.size()); // Prefill stops before size - 1 because the last prompt token is the // first input token for generation. total_prefill_tokens += prompt.size() - 1; all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0; // We use a single divisor, so all sequence lengths must be the same. HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len); } if (max_prompt_size >= seq_len) { HWY_ABORT("max_prompt_size = %zu, increase --seq_len to at least that.", max_prompt_size); } HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len); // Lacks a constructor to bulk-set, hence initialized by Prefill* which have // qi loops anyway. hwy::BitSet4096<> non_eos; // indexed by qi timing_info.prefill_start = hwy::platform::Now(); // Batch over the larger of prompt length, or queries. if ((qbatch.Size() > max_prompt_size) && all_prefix_end_are_zero) { activations.SetBatchSize(qbatch.Size()); // required before PrefillQBatch PrefillQBatch(max_prompt_size, config, runtime_config, weights, activations, qbatch, env, non_eos); } else { PrefillTBatch(config, runtime_config, weights, activations, qbatch, env, non_eos); activations.SetBatchSize(qbatch.Size()); // Restore after PrefillTBatch. } HWY_DASSERT(non_eos.Count() == qbatch.Size()); timing_info.NotifyPrefill(total_prefill_tokens); // queries_pos have been incremented by Prefill. // Stream the last prompt token from each query, fill activations.gen_tokens. for (size_t qi = 0; qi < qbatch.Size(); ++qi) { const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct. // In autoregressive mode, we have not prefilled the last token, so do // not advance. const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config, runtime_config, qbatch, update_pos, non_eos); } size_t max_gen_steps = runtime_config.max_generated_tokens; if (max_prompt_size + max_gen_steps > seq_len) { HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.", max_prompt_size, max_gen_steps, seq_len); max_gen_steps = seq_len - max_prompt_size; } const SampleFunc sample_token = ChooseSampleFunc(runtime_config, engine, env.ctx); timing_info.generate_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) { Transformer(config, runtime_config, weights, activations, qbatch, env); SampleAndStream(config, runtime_config, weights, sample_token, activations, qbatch, env, non_eos, timing_info); } timing_info.NotifyGenerateDone(); } void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end, const ModelConfig& config, const RuntimeConfig& runtime_config, const AesCtrEngine& engine, const WeightsPtrs& weights, KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) { Activations activations(config, runtime_config.prefill_tbatch_size, kv_cache.SeqLen(), env.ctx, env.row_ptrs); AllQueries all_queries(prompt, pos, prefix_end, hwy::Span(&kv_cache, 1)); QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries); GenerateT(config, runtime_config, engine, weights, activations, qbatch, env, timing_info); } // Splits the input into batches of at most `runtime_config.decode_qbatch_size` // queries, and calls `GenerateT` on each batch. void GenerateBatchT(const ModelConfig& config, const RuntimeConfig& runtime_config, const AesCtrEngine& engine, const WeightsPtrs& weights, AllQueries& all_queries, MatMulEnv& env, TimingInfo& timing_info) { const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size, runtime_config.prefill_tbatch_size); Activations activations(config, max_batch_size, all_queries[0].kv_cache.SeqLen(), env.ctx, env.row_ptrs); for (size_t start = 0; start < all_queries.NumQueries(); start += runtime_config.decode_qbatch_size) { QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries); // Generate a batch of one token for each of `qbatch.Size()` queries. GenerateT(config, runtime_config, engine, weights, activations, qbatch, env, timing_info); } } void GenerateImageTokensT(const ModelConfig& config, const RuntimeConfig& runtime_config, size_t seq_len, const WeightsPtrs& weights, const Image& image, ImageTokens& image_tokens, MatMulEnv& env) { if (config.vit_config.layer_configs.empty()) { HWY_ABORT("Model does not support generating image tokens."); } RuntimeConfig prefill_runtime_config = runtime_config; const ModelConfig vit_config = GetVitConfig(config); const size_t num_tokens = vit_config.max_seq_len; prefill_runtime_config.prefill_tbatch_size = num_tokens / (vit_config.pool_dim * vit_config.pool_dim); Activations prefill_activations(vit_config, num_tokens, num_tokens, env.ctx, env.row_ptrs); // Weights are for the full PaliGemma model, not just the ViT part. PrefillVit(config, weights, prefill_runtime_config, image, image_tokens, prefill_activations, env); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace gcpp { HWY_EXPORT(GenerateSingleT); HWY_EXPORT(GenerateBatchT); HWY_EXPORT(GenerateImageTokensT); Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, ThreadingContext& ctx) : reader_(loader.weights), model_(reader_, loader.tokenizer, loader.wrapping), weights_(model_.Config()), chat_template_(model_.Tokenizer(), model_.Config().model), inference_(inference), aes_ctr_engine_(inference.deterministic) { // Negligible CPU time in the ctor body (except ReadFromBlobs). weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_, ctx); // Read everything into memory, or `weights_.mapped_` keeps the mapping alive. reader_.CloseFile(); } Gemma::~Gemma() = default; void Gemma::Save(const Path& weights_path, ThreadingContext& ctx) const { BlobWriter writer(weights_path, ctx); const std::vector serialized_mat_ptrs = weights_.AddTensorDataToWriter(writer); WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs, writer); } void Gemma::Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) const { env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); HWY_DYNAMIC_DISPATCH(GenerateSingleT)( prompt, pos, prefix_end, model_.Config(), runtime_config, aes_ctr_engine_, weights_, kv_cache, env, timing_info); env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, AllQueries& all_queries, MatMulEnv& env, TimingInfo& timing_info) const { env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config, aes_ctr_engine_, weights_, all_queries, env, timing_info); env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len, const Image& image, ImageTokens& image_tokens, MatMulEnv& env) const { env.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config, seq_len, weights_, image, image_tokens, env); env.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } } // namespace gcpp #endif // HWY_ONCE