Fix an off-by-one error after StreamAndUpdateEOS() to remove the MSAN warning about reading an uninitialized variable in the kv_cache.

The logic for choosing whether or not to attend to the last token during prefill wasn't completely consistent with StreamAndUpdateEOS(), causing an off-by-one error that prevented the kv_cache from being fully populated.

PiperOrigin-RevId: 797614310
This commit is contained in:
Rhett Stucki 2025-08-20 22:59:24 -07:00 committed by Copybara-Service
parent 41a86d41a9
commit 73f1140dca
1 changed file with 6 additions and 0 deletions

View File

@@ -520,6 +520,12 @@ static void GenerateT(const ModelConfig& config,
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config, StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
runtime_config, qbatch, non_eos); runtime_config, qbatch, non_eos);
// StreamAndUpdateEOS() sets the stream position one token too far in
// autoregressive mode.
const bool attend_to_last_token = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi));
if (!attend_to_last_token) {
qbatch.MutablePos(qi) -= 1;
}
} }
size_t max_gen_steps = runtime_config.max_generated_tokens; size_t max_gen_steps = runtime_config.max_generated_tokens;