diff --git a/gemma/configs.cc b/gemma/configs.cc
index d980b3b..276c8f9 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -35,6 +35,8 @@ static ModelConfig ConfigBaseGemmaV2() {
   ModelConfig config = ConfigNoSSM();
   config.att_cap = 50.0f;
   config.final_cap = 30.0f;
+  config.eos_id = 1;
+  config.secondary_eos_id = 107;
   return config;
 }
 
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index ab25d53..ccb34f0 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -1427,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   // Sanity check: prompts should not be empty, nor start with EOS.
   for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
     const PromptTokens& prompt = queries_prompt[query_idx];
-    HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
+    HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
   }
 
   const size_t num_queries = queries_prompt.size();
@@ -1615,4 +1615,4 @@ void GenerateImageTokens(  // NOLINT(misc-definitions-in-headers)
 }  // namespace gcpp
 HWY_AFTER_NAMESPACE();
 
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
\ No newline at end of file
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
diff --git a/gemma/run.cc b/gemma/run.cc
index 8c23c15..254d13f 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -85,7 +85,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
   size_t prompt_size = 0;
-  bool end_of_turn_seen = false;
 
   std::mt19937 gen;
   InitGenerator(args, gen);
@@ -118,12 +117,6 @@
   // callback function invoked for each generated token.
   auto stream_token = [&](int token, float) {
     ++abs_pos;
-    if (model.GetModelConfig().IsEOS(token)) {
-      if (app.verbosity >= 2) {
-        std::cout << "\n[ End ]\n";
-      }
-      return true;
-    }
     const bool in_prompt = tokens_generated_this_turn < prompt_size;
     const bool first_response_token = tokens_generated_this_turn == prompt_size;
     ++tokens_generated_this_turn;
@@ -132,6 +125,11 @@
         std::cerr << "." << std::flush;
       }
       return true;
+    } else if (model.GetModelConfig().IsEOS(token)) {
+      if (app.verbosity >= 2) {
+        std::cout << "\n[ End ]\n";
+      }
+      return true;
     }
     std::string token_text;
     HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
@@ -141,13 +139,6 @@
         std::cout << "\n\n";
       }
     }
-    if (token_text == "<end_of_turn>") {
-      // We don't want to show the token to the user.
-      // We also need to remember that we've seen it, so that we can rewind
-      // abs_pos appropriately. We expect EOS as the next token.
-      end_of_turn_seen = true;
-      return true;
-    }
     std::cout << token_text << std::flush;
     return true;
   };
@@ -233,13 +224,6 @@
       HWY_ASSERT(abs_pos > 0);
       abs_pos--;
     }
-    if (end_of_turn_seen && abs_pos > 0) {
-      // If we have seen an end_of_turn token, we need to rewind abs_pos by one
-      // more, because we will prepend it again to the prompt in
-      // WrapAndTokenize.
-      abs_pos--;
-    }
-    end_of_turn_seen = false;
   }
 }
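
Note on the EOS changes above: setting `eos_id = 1` (`<eos>`) and `secondary_eos_id = 107` (`<end_of_turn>` in the Gemma-2 vocabulary) lets `IsEOS()` recognize both stop tokens, which is why run.cc can drop the `end_of_turn_seen` bookkeeping and its extra `abs_pos` rewind. It also explains moving the `IsEOS()` check behind the `in_prompt` check: a multi-turn prompt itself contains `<end_of_turn>`, so testing EOS during prefill would end the turn early. A minimal sketch, assuming `IsEOS()` matches either id (the struct and token stream below are illustrative, not gemma.cpp's actual code):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the relevant ModelConfig fields. In the Gemma
// vocabulary, 1 is <eos>; in Gemma-2, 107 is <end_of_turn>.
struct ConfigSketch {
  int eos_id = 1;
  int secondary_eos_id = 107;
  // Assumed semantics: either stop token counts as EOS.
  bool IsEOS(int token) const {
    return token == eos_id || token == secondary_eos_id;
  }
};

int main() {
  ConfigSketch config;
  // Tokens in stream_token order: the prompt first (prefill), then the
  // generated reply. A multi-turn prompt legitimately contains 107.
  const std::vector<int> streamed = {2, 106, 10, 11, 107, 42, 43, 107};
  const size_t prompt_size = 5;

  size_t tokens_this_turn = 0;
  for (int token : streamed) {
    const bool in_prompt = tokens_this_turn < prompt_size;
    ++tokens_this_turn;
    if (in_prompt) {
      continue;  // prefill: checking IsEOS here would cut the prompt short
    } else if (config.IsEOS(token)) {
      std::printf("[ End ] at token %d\n", token);
      break;
    }
    std::printf("emit %d\n", token);
  }
  return 0;
}
```

Run as-is, this emits tokens 42 and 43 and stops at the generated 107, while the 107 inside the prompt is ignored, mirroring the reordered checks in `stream_token`.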