mirror of https://github.com/google/gemma.cpp.git
Merge pull request #527 from ufownl:feature/gemma2_secondary_eos
PiperOrigin-RevId: 740327973
commit 4a924f1794
@@ -35,6 +35,8 @@ static ModelConfig ConfigBaseGemmaV2() {
   ModelConfig config = ConfigNoSSM();
   config.att_cap = 50.0f;
   config.final_cap = 30.0f;
+  config.eos_id = 1;
+  config.secondary_eos_id = 107;
   return config;
 }
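Token id 1 is the Gemma tokenizer's <eos>, and 107 is <end_of_turn>, the same token the old run.cc code below matched as a string. The IsEOS() helper that consumes these two fields is not shown in this diff; the following is a minimal sketch of its presumed semantics, with the struct name invented for illustration:

    // Hypothetical stand-in for the relevant slice of ModelConfig: just the
    // two fields set above plus the IsEOS() helper called elsewhere in this
    // commit.
    struct ModelConfigSketch {
      int eos_id = 1;              // <eos>
      int secondary_eos_id = 107;  // <end_of_turn> (Gemma 2)

      // Presumed semantics: a token terminates generation if it matches
      // either the primary or the secondary EOS id.
      bool IsEOS(int id) const { return id == eos_id || id == secondary_eos_id; }
    };

Under that reading, IsEOS(107) is what lets run.cc (below) drop its string comparison against "<end_of_turn>".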
@@ -1427,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   // Sanity check: prompts should not be empty, nor start with EOS.
   for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
     const PromptTokens& prompt = queries_prompt[query_idx];
-    HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
+    HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
   }

   const size_t num_queries = queries_prompt.size();
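The assert loosens from IsEOS() to the primary eos_id only, presumably because 107 now counts as EOS while a multiturn prompt can legitimately begin with <end_of_turn> (the run.cc comment removed below notes that WrapAndTokenize prepends it). A self-contained illustration; token 106 is assumed to be <start_of_turn>:

    #include <cassert>
    #include <vector>

    int main() {
      // Multiturn continuation: the wrapped prompt starts with <end_of_turn>.
      const std::vector<int> prompt = {107, 106};
      const int eos_id = 1;  // primary EOS, as in the new assert
      // New check passes: only the primary id is rejected at prompt[0].
      assert(!prompt.empty() && prompt[0] != eos_id);
      // The old `!IsEOS(prompt[0])` would now fail here, since this commit
      // makes IsEOS(107) true via secondary_eos_id.
      return 0;
    }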
gemma/run.cc (26 changed lines)
@@ -85,7 +85,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
   size_t prompt_size = 0;
-  bool end_of_turn_seen = false;

   std::mt19937 gen;
   InitGenerator(args, gen);
@@ -118,12 +117,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   // callback function invoked for each generated token.
   auto stream_token = [&](int token, float) {
     ++abs_pos;
-    if (model.GetModelConfig().IsEOS(token)) {
-      if (app.verbosity >= 2) {
-        std::cout << "\n[ End ]\n";
-      }
-      return true;
-    }
     const bool in_prompt = tokens_generated_this_turn < prompt_size;
     const bool first_response_token = tokens_generated_this_turn == prompt_size;
     ++tokens_generated_this_turn;
@@ -132,6 +125,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
         std::cerr << "." << std::flush;
       }
       return true;
+    } else if (model.GetModelConfig().IsEOS(token)) {
+      if (app.verbosity >= 2) {
+        std::cout << "\n[ End ]\n";
+      }
+      return true;
     }
     std::string token_text;
     HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
@@ -141,13 +139,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
         std::cout << "\n\n";
       }
     }
-    if (token_text == "<end_of_turn>") {
-      // We don't want to show the <end_of_turn> token to the user.
-      // We also need to remember that we've seen it, so that we can rewind
-      // abs_pos appropriately. We expect EOS as the next token.
-      end_of_turn_seen = true;
-      return true;
-    }
     std::cout << token_text << std::flush;
     return true;
   };
@@ -233,13 +224,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
       HWY_ASSERT(abs_pos > 0);
       abs_pos--;
     }
-    if (end_of_turn_seen && abs_pos > 0) {
-      // If we have seen an end_of_turn token, we need to rewind abs_pos by one
-      // more, because we will prepend it again to the prompt in
-      // WrapAndTokenize.
-      abs_pos--;
-    }
-    end_of_turn_seen = false;
   }
 }
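Net effect on the REPL: turn termination is now decided purely by token id, so the end_of_turn_seen flag, the string match, and the extra abs_pos rewind all disappear. Below is a condensed, self-contained model of the new stream_token control flow; IsEOS() and DecodeToText() are stand-ins for model.GetModelConfig().IsEOS() and the tokenizer, and the verbosity gates are dropped:

    #include <cstddef>
    #include <initializer_list>
    #include <iostream>

    static bool IsEOS(int token) { return token == 1 || token == 107; }
    static const char* DecodeToText(int token) { return token == 5 ? "hi" : "?"; }

    int main() {
      std::size_t abs_pos = 0, tokens_this_turn = 0, prompt_size = 2;
      auto stream_token = [&](int token, float /*prob*/) {
        ++abs_pos;
        const bool in_prompt = tokens_this_turn < prompt_size;
        ++tokens_this_turn;
        if (in_prompt) {
          std::cerr << ".";  // prefill progress, as in run.cc
          return true;
        } else if (IsEOS(token)) {
          std::cout << "\n[ End ]\n";  // <eos> and <end_of_turn> both land here
          return true;
        }
        std::cout << DecodeToText(token) << std::flush;
        return true;
      };
      // Two prompt tokens, one reply token, then <end_of_turn> ends the turn.
      for (int token : {9, 9, 5, 107}) stream_token(token, 0.0f);
      return 0;
    }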