diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 01b3930..4f8c6e4 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -27,6 +27,7 @@
 #include "gemma/common.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
+#include "gemma/kv_cache.h"
 #include "gemma/weights.h"
 #include "paligemma/image.h"
 #include "util/allocator.h"
@@ -1217,6 +1218,54 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   };
 }
 
+template <typename T>
+// Runs one decode step for all the queries in the batch. Returns true if all
+// queries are at <end_of_sentence>.
+bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
+                 const RuntimeConfig& runtime_config,
+                 const QueriesPromptTokens& queries_prompt,
+                 const size_t query_idx_start, const KVCaches& kv_caches,
+                 const QueriesPos& queries_prefix_end,
+                 const hwy::Divisor div_seq_len, const size_t vocab_size,
+                 const SampleFunc& sample_token, double prefill_start,
+                 double gen_start, Activations& activations,
+                 TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
+                 TimingInfo& timing_info,
+                 const QueriesMutablePos& queries_mutable_pos) {
+  const size_t num_queries = queries_prompt.size();
+  // Decode generates one token per query and increments
+  // queries_mutable_pos.
+  Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
+              queries_prefix_end, weights, activations, div_seq_len, kv_caches,
+              runtime_config.layers_output,
+              runtime_config.activations_observer);
+  // queries_pos are incremented by Transformer.
+
+  bool all_queries_eos = true;
+  {
+    PROFILER_ZONE("Gen.EmbeddingMatmul");
+    // Compute logits from last layer activations.
+    MatMul(ConstMatFromBatch(num_queries, activations.x),
+           ConstMatFromWeights(weights.embedder_input_embedding),
+           /*add=*/nullptr, *activations.env,
+           RowPtrFromBatch(activations.logits));
+  }
+  PROFILER_ZONE("Gen.Softcap+Sample+Stream");
+  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
+    float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
+    MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
+    const TokenAndProb tp = sample_token(logits, vocab_size);
+    timing_info.NotifyGenerated(prefill_start, gen_start);
+
+    const bool is_eos =
+        token_streamer(query_idx_start + query_idx,
+                       queries_mutable_pos[query_idx], tp.token, tp.prob);
+    all_queries_eos &= is_eos;
+    gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
+  }
+  return all_queries_eos;
+}
+
 // Generates one continuation for each query in `queries_prompt`, which is one
 // qbatch whose size is at most the `batch_size` passed to
 // `activations.Allocate`.
@@ -1310,37 +1359,11 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   const size_t vocab_size = model.Config().vocab_size;
   const double gen_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-    // Decode generates one token per query and increments
-    // queries_mutable_pos.
-    Transformer(QueriesToken(gen_tokens.data(), num_queries),
-                queries_mutable_pos, queries_prefix_end, weights, activations,
-                div_seq_len, kv_caches, runtime_config.layers_output,
-                runtime_config.activations_observer);
-    // queries_pos are incremented by Transformer.
-
-    bool all_queries_eos = true;
-    {
-      PROFILER_ZONE("Gen.EmbeddingMatmul");
-      // Compute logits from last layer activations.
-      MatMul(ConstMatFromBatch(num_queries, activations.x),
-             ConstMatFromWeights(weights.embedder_input_embedding),
-             /*add=*/nullptr, *activations.env,
-             RowPtrFromBatch(activations.logits));
-    }
-    PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
-      MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
-                         vocab_size);
-      const TokenAndProb tp = sample_token(logits, vocab_size);
-      timing_info.NotifyGenerated(prefill_start, gen_start);
-
-      const bool is_eos =
-          token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], tp.token, tp.prob);
-      all_queries_eos &= is_eos;
-      gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
-    }
+    bool all_queries_eos = DecodeStepT<T>(
+        weights, runtime_config, queries_prompt, query_idx_start, kv_caches,
+        queries_prefix_end, div_seq_len, vocab_size, sample_token,
+        prefill_start, gen_start, activations, token_streamer, gen_tokens,
+        timing_info, queries_mutable_pos);
     if (all_queries_eos) break;
   }  // foreach token to generate
   timing_info.NotifyGenerateDone(gen_start);