mirror of https://github.com/google/gemma.cpp.git
Refactor for continuous batching. This CL does not change the current behavior of the code. It only extracts two functions that will later be called when adding continuous batching.
PiperOrigin-RevId: 829104661
This commit is contained in:
parent
35e9f9f05f
commit
f8131339a7
|
|
@ -486,12 +486,11 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode: generates one continuation token for each query in `qbatch`.
|
static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
|
||||||
static void GenerateT(const ModelConfig& config,
|
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
const AesCtrEngine& engine, const WeightsPtrs& weights,
|
const WeightsPtrs& weights,
|
||||||
Activations& activations, QBatch& qbatch, MatMulEnv& env,
|
Activations& activations, QBatch& qbatch,
|
||||||
TimingInfo& timing_info) {
|
MatMulEnv& env, TimingInfo& timing_info) {
|
||||||
size_t max_prompt_size = 0;
|
size_t max_prompt_size = 0;
|
||||||
bool all_prefix_end_are_zero = true;
|
bool all_prefix_end_are_zero = true;
|
||||||
size_t total_prefill_tokens = 0; // only for throughput stats.
|
size_t total_prefill_tokens = 0; // only for throughput stats.
|
||||||
|
|
@ -519,8 +518,6 @@ static void GenerateT(const ModelConfig& config,
|
||||||
}
|
}
|
||||||
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
|
HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len);
|
||||||
|
|
||||||
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
|
||||||
// qi loops anyway.
|
|
||||||
hwy::BitSet4096<> non_eos; // indexed by qi
|
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||||
|
|
||||||
timing_info.prefill_start = hwy::platform::Now();
|
timing_info.prefill_start = hwy::platform::Now();
|
||||||
|
|
@ -538,8 +535,21 @@ static void GenerateT(const ModelConfig& config,
|
||||||
timing_info.NotifyPrefill(total_prefill_tokens);
|
timing_info.NotifyPrefill(total_prefill_tokens);
|
||||||
// queries_pos have been incremented by Prefill.
|
// queries_pos have been incremented by Prefill.
|
||||||
|
|
||||||
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||||
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
if (max_prompt_size + max_gen_steps > seq_len) {
|
||||||
|
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
|
||||||
|
max_prompt_size, max_gen_steps, seq_len);
|
||||||
|
max_gen_steps = seq_len - max_prompt_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
return max_gen_steps;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
|
||||||
|
const RuntimeConfig& runtime_config,
|
||||||
|
QBatch& qbatch,
|
||||||
|
hwy::BitSet4096<>& non_eos,
|
||||||
|
size_t qi) {
|
||||||
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
||||||
|
|
||||||
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
|
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
|
||||||
|
|
@ -548,13 +558,23 @@ static void GenerateT(const ModelConfig& config,
|
||||||
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
|
const bool update_pos = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
|
||||||
StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
|
StreamAndUpdateEOS(qi, pos, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f,
|
||||||
config, runtime_config, qbatch, update_pos, non_eos);
|
config, runtime_config, qbatch, update_pos, non_eos);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
// Decode: generates one continuation token for each query in `qbatch`.
|
||||||
if (max_prompt_size + max_gen_steps > seq_len) {
|
static void GenerateT(const ModelConfig& config,
|
||||||
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
|
const RuntimeConfig& runtime_config,
|
||||||
max_prompt_size, max_gen_steps, seq_len);
|
const AesCtrEngine& engine, const WeightsPtrs& weights,
|
||||||
max_gen_steps = seq_len - max_prompt_size;
|
Activations& activations, QBatch& qbatch, MatMulEnv& env,
|
||||||
|
TimingInfo& timing_info) {
|
||||||
|
const size_t max_gen_steps = PrefillTBatchOrQBatch(
|
||||||
|
config, runtime_config, weights, activations, qbatch, env, timing_info);
|
||||||
|
|
||||||
|
hwy::BitSet4096<> non_eos; // indexed by qi
|
||||||
|
|
||||||
|
// Stream the last prompt token from each query, fill activations.gen_tokens.
|
||||||
|
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
|
||||||
|
non_eos.Set(qi);
|
||||||
|
StreamAndUpdateEOSAfterPrefill(config, runtime_config, qbatch, non_eos, qi);
|
||||||
}
|
}
|
||||||
|
|
||||||
const SampleFunc sample_token =
|
const SampleFunc sample_token =
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue