From 2bad79f11015e639b12352e370a44e2d7897a969 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Fri, 21 Mar 2025 19:26:59 +0800 Subject: [PATCH] Fix the EOS checking The secondary eos is usually `<end_of_turn>`, which can appear in the prompt, so we can only check for it outside the prompt. --- gemma/gemma-inl.h | 4 ++-- gemma/run.cc | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index ab25d53..ccb34f0 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -1427,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, // Sanity check: prompts should not be empty, nor start with EOS. for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) { const PromptTokens& prompt = queries_prompt[query_idx]; - HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0])); + HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id); } const size_t num_queries = queries_prompt.size(); @@ -1615,4 +1615,4 @@ void GenerateImageTokens( // NOLINT(misc-definitions-in-headers) } // namespace gcpp HWY_AFTER_NAMESPACE(); -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ \ No newline at end of file +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ diff --git a/gemma/run.cc b/gemma/run.cc index 8c23c15..dab48a1 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -118,12 +118,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, // callback function invoked for each generated token. auto stream_token = [&](int token, float) { ++abs_pos; - if (model.GetModelConfig().IsEOS(token)) { - if (app.verbosity >= 2) { - std::cout << "\n[ End ]\n"; - } - return true; - } const bool in_prompt = tokens_generated_this_turn < prompt_size; const bool first_response_token = tokens_generated_this_turn == prompt_size; ++tokens_generated_this_turn; @@ -132,6 +126,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, std::cerr << "." 
<< std::flush; } return true; + } else if (model.GetModelConfig().IsEOS(token)) { + if (app.verbosity >= 2) { + std::cout << "\n[ End ]\n"; + } + return true; } std::string token_text; HWY_ASSERT(model.Tokenizer().Decode(std::vector{token}, &token_text));