From f8131339a7fc2e58c546bc5c7562899160de052b Mon Sep 17 00:00:00 2001 From: Charles Zhao Date: Thu, 6 Nov 2025 14:19:38 -0800 Subject: [PATCH] Refactor for continous batching. This cl does not change the current behavior of the code. It only extract two functions that will later be called for adding continuous batching. PiperOrigin-RevId: 829104661 --- gemma/gemma.cc | 60 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index ffe7c47..7ed7f50 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -486,12 +486,11 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config, }; } -// Decode: generates one continuation token for each query in `qbatch`. -static void GenerateT(const ModelConfig& config, - const RuntimeConfig& runtime_config, - const AesCtrEngine& engine, const WeightsPtrs& weights, - Activations& activations, QBatch& qbatch, MatMulEnv& env, - TimingInfo& timing_info) { +static size_t PrefillTBatchOrQBatch(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const WeightsPtrs& weights, + Activations& activations, QBatch& qbatch, + MatMulEnv& env, TimingInfo& timing_info) { size_t max_prompt_size = 0; bool all_prefix_end_are_zero = true; size_t total_prefill_tokens = 0; // only for throughput stats. @@ -519,8 +518,6 @@ static void GenerateT(const ModelConfig& config, } HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len); - // Lacks a constructor to bulk-set, hence initialized by Prefill* which have - // qi loops anyway. hwy::BitSet4096<> non_eos; // indexed by qi timing_info.prefill_start = hwy::platform::Now(); @@ -538,18 +535,6 @@ static void GenerateT(const ModelConfig& config, timing_info.NotifyPrefill(total_prefill_tokens); // queries_pos have been incremented by Prefill. - // Stream the last prompt token from each query, fill activations.gen_tokens. - for (size_t qi = 0; qi < qbatch.Size(); ++qi) { - const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); - - const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct. - // In autoregressive mode, we have not prefilled the last token, so do - // not advance. - const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); - StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, - config, runtime_config, qbatch, update_pos, non_eos); - } - size_t max_gen_steps = runtime_config.max_generated_tokens; if (max_prompt_size + max_gen_steps > seq_len) { HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.", @@ -557,6 +542,41 @@ static void GenerateT(const ModelConfig& config, max_gen_steps = seq_len - max_prompt_size; } + return max_gen_steps; +} + +static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config, + const RuntimeConfig& runtime_config, + QBatch& qbatch, + hwy::BitSet4096<>& non_eos, + size_t qi) { + const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); + + const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct. + // In autoregressive mode, we have not prefilled the last token, so do + // not advance. + const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); + StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, + config, runtime_config, qbatch, update_pos, non_eos); +} + +// Decode: generates one continuation token for each query in `qbatch`. +static void GenerateT(const ModelConfig& config, + const RuntimeConfig& runtime_config, + const AesCtrEngine& engine, const WeightsPtrs& weights, + Activations& activations, QBatch& qbatch, MatMulEnv& env, + TimingInfo& timing_info) { + const size_t max_gen_steps = PrefillTBatchOrQBatch( + config, runtime_config, weights, activations, qbatch, env, timing_info); + + hwy::BitSet4096<> non_eos; // indexed by qi + + // Stream the last prompt token from each query, fill activations.gen_tokens. + for (size_t qi = 0; qi < qbatch.Size(); ++qi) { + non_eos.Set(qi); + StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos, qi); + } + const SampleFunc sample_token = ChooseSampleFunc(runtime_config, engine, env.ctx);