mirror of https://github.com/google/gemma.cpp.git
Make prompt wrapping more consistent and fix duplicated tokens for multi-turn.
Do not echo <end_of_turn> tokens to the user. Have verbosity=0 only show the dialog.

PiperOrigin-RevId: 705021391
parent e69bc3bc1c
commit aed17396be
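
For orientation, a minimal sketch of the two-turn flow this change targets, mirroring the Multiturn test updated below. The surrounding objects (model, runtime_config, timing_info, s_env and its KV cache) come from the test fixture, and the first prompt string is illustrative only, not taken from the commit.

  // Sketch only: two turns sharing one KV cache and one running position.
  size_t abs_pos = 0;  // absolute token position, carried across turns
  std::string mutable_prompt = "My car is turquoise.";  // hypothetical first turn
  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
                                            abs_pos, mutable_prompt);
  model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                  timing_info);  // stream_token advances abs_pos per token

  // Second turn: wrap at the current abs_pos so the chat template is applied
  // consistently and no turn markers are duplicated.
  mutable_prompt = "Please repeat all prior statements.";
  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
                           mutable_prompt);
  model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                  timing_info);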
@@ -157,13 +157,14 @@ TEST_F(GemmaTest, Multiturn) {
   Gemma* model = s_env->GetModel();
   ASSERT_NE(model, nullptr);
   size_t abs_pos = 0;
-  std::string dialog;
+  std::string response;
   auto stream_token = [&](int token, float) {
+    if (token == EOS_ID) return true;
     ++abs_pos;
     std::string token_text;
     EXPECT_TRUE(
         model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    dialog += token_text;
+    response += token_text;
     return true;
   };
   RuntimeConfig runtime_config{
@@ -180,18 +181,21 @@ TEST_F(GemmaTest, Multiturn) {
                            abs_pos, mutable_prompt);
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
+  // Note: we do not rewind any <end_of_turn> tokens here. If the model
+  // produced one and WrapAndTokenize() inserts another one, it will just be
+  // duplicated.
   mutable_prompt = "Please repeat all prior statements.";
   tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
                            mutable_prompt);
-  // Reset the `dialog` string here, then check that the model actually has
+  // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
-  dialog.clear();
+  response.clear();
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
-  fprintf(stderr, "decoded: %s\n", dialog.c_str());
+  fprintf(stderr, "decoded: %s\n", response.c_str());
   bool remembered_turquoise =
-      dialog.find("turquoise") != std::string::npos;  // NOLINT
-  bool remembered_car = dialog.find("car") != std::string::npos;  // NOLINT
+      response.find("turquoise") != std::string::npos;  // NOLINT
+  bool remembered_car = response.find("car") != std::string::npos;  // NOLINT
   EXPECT_TRUE(remembered_turquoise || remembered_car);
 }

@@ -1249,16 +1249,9 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   // Copy so we can increment without requiring users to pass in a mutable span.
   std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
                                        queries_pos_in.cend());
-  QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
+  const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
                                               queries_pos_copy.size());
-  // For the first turn, qpos remains 0. Otherwise, rewind the previous EOS.
-  // Background: for multiturn, Gemma 2 expects only <end_of_turn>, not EOS. The
-  // previous `Generate` called `StreamToken` for the last token (EOS), hence
-  // our caller's qpos is 1 too high. This must be corrected because we didn't
-  // write to the KV cache at that position, so MSAN would complain.
-  for (size_t& qpos : queries_mutable_pos) {
-    qpos = qpos == 0 ? 0 : qpos - 1;
-  }
   // Sanity check: prompts should not be empty, nor start with EOS.
   for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
     const PromptTokens& prompt = queries_prompt[query_idx];

gemma/run.cc (80 lines changed)

@@ -85,6 +85,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   size_t abs_pos = 0;                     // across turns
   size_t tokens_generated_this_turn = 0;  // differentiates prefill from reply
   size_t prompt_size = 0;
+  bool end_of_turn_seen = false;

   std::mt19937 gen;
   InitGenerator(args, gen);

@@ -114,37 +115,44 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   // callback function invoked for each generated token.
   auto stream_token = [&](int token, float) {
     ++abs_pos;
-    ++tokens_generated_this_turn;
-    // <= since position is incremented before
-    if (tokens_generated_this_turn <= prompt_size) {
-      std::cerr << "." << std::flush;
-    } else if (token == EOS_ID) {
-      if (!args.multiturn) {
-        abs_pos = 0;
-        InitGenerator(args, gen);
-      }
+    if (token == EOS_ID) {
       if (app.verbosity >= 2) {
         std::cout << "\n[ End ]\n";
       }
-    } else {
+      return true;
+    }
+    const bool in_prompt = tokens_generated_this_turn < prompt_size;
+    const bool first_response_token = tokens_generated_this_turn == prompt_size;
+    ++tokens_generated_this_turn;
+    if (in_prompt) {
+      if (app.verbosity >= 1) {
+        std::cerr << "." << std::flush;
+      }
+      return true;
+    }
     std::string token_text;
-    HWY_ASSERT(
-        model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
-    // +1 since position is incremented above
-    if (tokens_generated_this_turn == prompt_size + 1) {
-      // first token of response
+    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
+    if (first_response_token) {
       token_text.erase(0, token_text.find_first_not_of(" \t\n"));
       if (app.verbosity >= 1) {
         std::cout << "\n\n";
       }
     }
-    std::cout << token_text << std::flush;
+    if (token_text == "<end_of_turn>") {
+      // We don't want to show the <end_of_turn> token to the user.
+      // We also need to remember that we've seen it, so that we can rewind
+      // abs_pos appropriately. We expect EOS as the next token.
+      end_of_turn_seen = true;
+      return true;
     }
+    std::cout << token_text << std::flush;
     return true;
   };

   while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;
+
+    // Read prompt and handle special commands.
     std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
     if (!std::cin) return;
     // If !eot_line.empty(), we append \n, so only look at the first 2 chars.

@@ -155,23 +163,22 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
         continue;
       }
     }
-    if (have_image && abs_pos != 0) {
-      // This occurs when we have hit max_generated.
-      abs_pos = 0;
+    if (prompt_string.empty()) {
+      std::cout << "Use '%q' to quit.\n";
+      continue;
     }

+    // Wrap, tokenize and maybe log prompt tokens.
     std::vector<int> prompt = WrapAndTokenize(
         model.Tokenizer(), model.Info(), abs_pos, prompt_string);
     prompt_size = prompt.size();
-    std::cerr << "\n"
-              << "[ Reading prompt ] " << std::flush;
     if constexpr (kVerboseLogTokens) {
       for (int i = 0; i < prompt_size; ++i) {
         fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
       }
     }

+    // Set up runtime config.
     TimingInfo timing_info = {.verbosity = app.verbosity};
     RuntimeConfig runtime_config = {.gen = &gen,
                                     .verbosity = app.verbosity,

@@ -190,9 +197,38 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
       // We need to look at all the tokens for the prefix.
       runtime_config.prefill_tbatch_size = prompt_size;
     }

+    // Generate until EOS or max_generated_tokens.
+    if (app.verbosity >= 1) {
+      std::cerr << "\n[ Reading prompt ] " << std::flush;
+    }
     model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
                    timing_info);
     std::cout << "\n\n";

+    // Prepare for the next turn.
+    if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) {
+      abs_pos = 0;  // Start a new turn at position 0.
+      InitGenerator(args, gen);
+    } else {
+      // The last token was either EOS, then it should be ignored because it is
+      // never part of the dialog, see Table 5 in the Gemma-2 paper:
+      // https://arxiv.org/pdf/2408.00118
+      // Or we have hit max_generated_tokens, then the last token will be lost.
+      // (We could store it in stream_token, and then prepend to the next turn,
+      // but it's not worth the complexity, as multi-turn with max_generated is
+      // not a common use case.)
+      // In either case, we need to rewind abs_pos by one.
+      HWY_ASSERT(abs_pos > 0);
+      abs_pos--;
+    }
+    if (end_of_turn_seen && abs_pos > 0) {
+      // If we have seen an end_of_turn token, we need to rewind abs_pos by one
+      // more, because we will pre-pend it again to the prompt in
+      // WrapAndTokenize.
+      abs_pos--;
+    }
+    end_of_turn_seen = false;
   }
 }
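
To make the position bookkeeping above concrete, here is an illustrative trace of the end-of-turn rewind added to gemma/run.cc; the starting value of abs_pos is made up, and the trace assumes the model finished its reply with "<end_of_turn>" followed by EOS in a multi-turn session.

  // Illustrative only: state right after the reply finished streaming.
  size_t abs_pos = 120;          // incremented once per token, including <end_of_turn> and EOS
  bool end_of_turn_seen = true;  // stream_token swallowed "<end_of_turn>" instead of printing it

  abs_pos--;                     // EOS is never part of the dialog (Gemma-2 paper, Table 5)
  if (end_of_turn_seen && abs_pos > 0) {
    abs_pos--;                   // <end_of_turn> will be re-inserted by WrapAndTokenize
  }
  end_of_turn_seen = false;
  // abs_pos == 118: the next WrapAndTokenize(..., abs_pos, prompt) continues the
  // dialog with exactly one <end_of_turn> and no stray EOS position in the KV cache.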