diff --git a/gemma/gemma.cc b/gemma/gemma.cc index a7b8423..95f52b8 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -520,6 +520,12 @@ static void GenerateT(const ModelConfig& config, const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi); StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config, runtime_config, qbatch, non_eos); + // StreamAndUpdateEOS() sets the stream position one token too far in + // autoregressive mode. + const bool attend_to_last_token = (qbatch.Pos(qi) < qbatch.PrefixEnd(qi)); + if (!attend_to_last_token) { + qbatch.MutablePos(qi) -= 1; + } } size_t max_gen_steps = runtime_config.max_generated_tokens;