diff --git a/gemma.cc b/gemma.cc index edc5dfd..1751ad5 100644 --- a/gemma.cc +++ b/gemma.cc @@ -666,12 +666,16 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens, size_t pos_gen_start = pos_offset; int token = prompt.at(pos_offset); + stream_token(token, 0); for (size_t generate_pos = 0; pos < max_tokens && generate_pos < max_generated_tokens; ++pos, ++pos_offset, ++generate_pos) { Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool); float* final_activation = activations.x.data(); - if (pos_offset >= prompt_size) { + // The condition below is always true if we are doing Prefill above. + // We keep it here for clarity so that the code is correct even if Prefill + // is disabled. + if (pos_offset >= prompt_size - 1) { PROFILER_ZONE("Gen.Embedding"); // Generation phase MatVec( @@ -681,9 +685,14 @@ void GenerateImpl(GemmaImpl& gemma, size_t max_tokens, Softmax(activations.logits.data(), kVocabSize); token = SampleTopK(activations.logits.data(), kVocabSize, gen, temperature, accept_token); - } - if (!stream_token(token, activations.logits[token])) { - token = EOS_ID; + if (!stream_token(token, activations.logits[token])) { + token = EOS_ID; + } + } else { + // We would take this branch if we were not doing Prefill but would + // process the tokens of the prompt one at a time. + token = prompt.at(pos_offset + 1); + stream_token(token, 0); } if (token == EOS_ID) { if (verbosity >= 2) { diff --git a/run.cc b/run.cc index 3f38031..46ac1ba 100644 --- a/run.cc +++ b/run.cc @@ -116,7 +116,8 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, verbosity](int token, float) { ++abs_pos; ++current_pos; - if (current_pos < prompt_size) { + // <= since position is incremented before + if (current_pos <= prompt_size) { std::cerr << "." << std::flush; } else if (token == gcpp::EOS_ID) { if (!args.multiturn) {