From 71ead04afb6a7a8983761fdcf7ba8fae09525539 Mon Sep 17 00:00:00 2001 From: Zoltan Szabadka Date: Thu, 4 Apr 2024 07:50:03 +0000 Subject: [PATCH] Fix off-by-one errors in generation code and token streaming callback. In the generation code we were feeding the last token of the prompt twice through the transformer. The new version fixes that and also works in the case where Prefill is completely disabled. --- gemma.cc | 17 +++++++++++++---- run.cc | 3 ++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/gemma.cc b/gemma.cc index edc5dfd..1751ad5 100644 --- a/gemma.cc +++ b/gemma.cc @@ -666,12 +666,16 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens, size_t pos_gen_start = pos_offset; int token = prompt.at(pos_offset); + stream_token(token, 0); for (size_t generate_pos = 0; pos < max_tokens && generate_pos < max_generated_tokens; ++pos, ++pos_offset, ++generate_pos) { Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool); float* final_activation = activations.x.data(); - if (pos_offset >= prompt_size) { + // The condition below is always true if we are doing Prefill above. + // We keep it here for clarity so that the code is correct even if Prefill + // is disabled. + if (pos_offset >= prompt_size - 1) { PROFILER_ZONE("Gen.Embedding"); // Generation phase MatVec( @@ -681,9 +685,14 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens, Softmax(activations.logits.data(), kVocabSize); token = SampleTopK(activations.logits.data(), kVocabSize, gen, temperature, accept_token); - } - if (!stream_token(token, activations.logits[token])) { - token = EOS_ID; + if (!stream_token(token, activations.logits[token])) { + token = EOS_ID; + } + } else { + // We would take this branch if we were not doing Prefill but would + // process the tokens of the prompt one at a time. + token = prompt.at(pos_offset + 1); + stream_token(token, 0); } if (token == EOS_ID) { if (verbosity >= 2) { diff --git a/run.cc b/run.cc index 3f38031..46ac1ba 100644 --- a/run.cc +++ b/run.cc @@ -116,7 +116,8 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, verbosity](int token, float) { ++abs_pos; ++current_pos; - if (current_pos < prompt_size) { + // <= since position is incremented before + if (current_pos <= prompt_size) { std::cerr << "." << std::flush; } else if (token == gcpp::EOS_ID) { if (!args.multiturn) {