mirror of https://github.com/google/gemma.cpp.git
Fix the EOS checking
The secondary eos is usually `<end_of_turn>`, which can appear in the prompt, so we can only check it not in the prompt.
This commit is contained in:
parent
6300c123ee
commit
2bad79f110
|
|
@ -1427,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
// Sanity check: prompts should not be empty, nor start with EOS.
|
||||
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
|
||||
const PromptTokens& prompt = queries_prompt[query_idx];
|
||||
HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
|
||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
|
||||
}
|
||||
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
|
|
@ -1615,4 +1615,4 @@ void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
|
|||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
|
||||
|
|
|
|||
11
gemma/run.cc
11
gemma/run.cc
|
|
@ -118,12 +118,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
// callback function invoked for each generated token.
|
||||
auto stream_token = [&](int token, float) {
|
||||
++abs_pos;
|
||||
if (model.GetModelConfig().IsEOS(token)) {
|
||||
if (app.verbosity >= 2) {
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
const bool in_prompt = tokens_generated_this_turn < prompt_size;
|
||||
const bool first_response_token = tokens_generated_this_turn == prompt_size;
|
||||
++tokens_generated_this_turn;
|
||||
|
|
@ -132,6 +126,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
std::cerr << "." << std::flush;
|
||||
}
|
||||
return true;
|
||||
} else if (model.GetModelConfig().IsEOS(token)) {
|
||||
if (app.verbosity >= 2) {
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
std::string token_text;
|
||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
|
|
|
|||
Loading…
Reference in New Issue