Refactor for continuous batching. This CL does not change the current behavior of the code. It only extracts two functions that will later be called when adding continuous batching.

PiperOrigin-RevId: 829104661
This commit is contained in:
Charles Zhao 2025-11-06 14:19:38 -08:00 committed by Copybara-Service
parent 35e9f9f05f
commit f8131339a7
1 changed file with 40 additions and 20 deletions

View File

@ -486,12 +486,11 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
};
}
// Decode: generates one continuation token for each query in `qbatch`.
static void GenerateT(const ModelConfig& config,
static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const AesCtrEngine& engine, const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch, MatMulEnv& env,
TimingInfo& timing_info) {
const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch,
MatMulEnv& env, TimingInfo& timing_info) {
size_t max_prompt_size = 0;
bool all_prefix_end_are_zero = true;
size_t total_prefill_tokens = 0; // only for throughput stats.
@ -519,8 +518,6 @@ static void GenerateT(const ModelConfig& config,
}
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
// qi loops anyway.
hwy::BitSet4096<> non_eos; // indexed by qi
timing_info.prefill_start = hwy::platform::Now();
@ -538,8 +535,21 @@ static void GenerateT(const ModelConfig& config,
timing_info.NotifyPrefill(total_prefill_tokens);
// queries_pos have been incremented by Prefill.
// Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
size_t max_gen_steps = runtime_config.max_generated_tokens;
if (max_prompt_size + max_gen_steps > seq_len) {
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
max_prompt_size, max_gen_steps, seq_len);
max_gen_steps = seq_len - max_prompt_size;
}
return max_gen_steps;
}
static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
const RuntimeConfig& runtime_config,
QBatch& qbatch,
hwy::BitSet4096<>& non_eos,
size_t qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
@ -548,13 +558,23 @@ static void GenerateT(const ModelConfig& config,
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
config, runtime_config, qbatch, update_pos, non_eos);
}
}
size_t max_gen_steps = runtime_config.max_generated_tokens;
if (max_prompt_size + max_gen_steps > seq_len) {
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
max_prompt_size, max_gen_steps, seq_len);
max_gen_steps = seq_len - max_prompt_size;
// Decode: generates one continuation token for each query in `qbatch`.
static void GenerateT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const AesCtrEngine& engine, const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch, MatMulEnv& env,
TimingInfo& timing_info) {
const size_t max_gen_steps = PrefillTBatchOrQBatch(
config, runtime_config, weights, activations, qbatch, env, timing_info);
hwy::BitSet4096<> non_eos; // indexed by qi
// Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
non_eos.Set(qi);
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos, qi);
}
const SampleFunc sample_token =