Make prompt wrapping more consistent and fix duplicated tokens for multi-turn.

Do not echo <end_of_turn> tokens to the user.
Have verbosity=0 only show the dialog.

PiperOrigin-RevId: 705021391
This commit is contained in:
Daniel Keysers 2024-12-11 01:51:29 -08:00 committed by Copybara-Service
parent e69bc3bc1c
commit aed17396be
3 changed files with 78 additions and 45 deletions

View File

@ -157,13 +157,14 @@ TEST_F(GemmaTest, Multiturn) {
Gemma* model = s_env->GetModel(); Gemma* model = s_env->GetModel();
ASSERT_NE(model, nullptr); ASSERT_NE(model, nullptr);
size_t abs_pos = 0; size_t abs_pos = 0;
std::string dialog; std::string response;
auto stream_token = [&](int token, float) { auto stream_token = [&](int token, float) {
if (token == EOS_ID) return true;
++abs_pos; ++abs_pos;
std::string token_text; std::string token_text;
EXPECT_TRUE( EXPECT_TRUE(
model->Tokenizer().Decode(std::vector<int>{token}, &token_text)); model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
dialog += token_text; response += token_text;
return true; return true;
}; };
RuntimeConfig runtime_config{ RuntimeConfig runtime_config{
@ -180,18 +181,21 @@ TEST_F(GemmaTest, Multiturn) {
abs_pos, mutable_prompt); abs_pos, mutable_prompt);
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(), model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info); timing_info);
// Note: we do not rewind any <end_of_turn> tokens here. If the model
// produced one and WrapAndTokenize() inserts another one, it will just be
// duplicated.
mutable_prompt = "Please repeat all prior statements."; mutable_prompt = "Please repeat all prior statements.";
tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos, tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
mutable_prompt); mutable_prompt);
// Reset the `dialog` string here, then check that the model actually has // Reset the `response` string here, then check that the model actually has
// access to the previous turn by asking to reproduce. // access to the previous turn by asking to reproduce.
dialog.clear(); response.clear();
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(), model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info); timing_info);
fprintf(stderr, "decoded: %s\n", dialog.c_str()); fprintf(stderr, "decoded: %s\n", response.c_str());
bool remembered_turquoise = bool remembered_turquoise =
dialog.find("turquoise") != std::string::npos; // NOLINT response.find("turquoise") != std::string::npos; // NOLINT
bool remembered_car = dialog.find("car") != std::string::npos; // NOLINT bool remembered_car = response.find("car") != std::string::npos; // NOLINT
EXPECT_TRUE(remembered_turquoise || remembered_car); EXPECT_TRUE(remembered_turquoise || remembered_car);
} }

View File

@ -1249,16 +1249,9 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
// Copy so we can increment without requiring users to pass in a mutable span. // Copy so we can increment without requiring users to pass in a mutable span.
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(), std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
queries_pos_in.cend()); queries_pos_in.cend());
QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(), const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
queries_pos_copy.size()); queries_pos_copy.size());
// For the first turn, qpos remains 0. Otherwise, rewind the previous EOS.
// Background: for multiturn, Gemma 2 expects only <end_of_turn>, not EOS. The
// previous `Generate` called `StreamToken` for the last token (EOS), hence
// our caller's qpos is 1 too high. This must be corrected because we didn't
// write to the KV cache at that position, so MSAN would complain.
for (size_t& qpos : queries_mutable_pos) {
qpos = qpos == 0 ? 0 : qpos - 1;
}
// Sanity check: prompts should not be empty, nor start with EOS. // Sanity check: prompts should not be empty, nor start with EOS.
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) { for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
const PromptTokens& prompt = queries_prompt[query_idx]; const PromptTokens& prompt = queries_prompt[query_idx];

View File

@ -85,6 +85,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
size_t abs_pos = 0; // across turns size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0; size_t prompt_size = 0;
bool end_of_turn_seen = false;
std::mt19937 gen; std::mt19937 gen;
InitGenerator(args, gen); InitGenerator(args, gen);
@ -114,37 +115,44 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
// callback function invoked for each generated token. // callback function invoked for each generated token.
auto stream_token = [&](int token, float) { auto stream_token = [&](int token, float) {
++abs_pos; ++abs_pos;
++tokens_generated_this_turn; if (token == EOS_ID) {
// <= since position is incremented before
if (tokens_generated_this_turn <= prompt_size) {
std::cerr << "." << std::flush;
} else if (token == EOS_ID) {
if (!args.multiturn) {
abs_pos = 0;
InitGenerator(args, gen);
}
if (app.verbosity >= 2) { if (app.verbosity >= 2) {
std::cout << "\n[ End ]\n"; std::cout << "\n[ End ]\n";
} }
} else { return true;
std::string token_text;
HWY_ASSERT(
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
// +1 since position is incremented above
if (tokens_generated_this_turn == prompt_size + 1) {
// first token of response
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
if (app.verbosity >= 1) {
std::cout << "\n\n";
}
}
std::cout << token_text << std::flush;
} }
const bool in_prompt = tokens_generated_this_turn < prompt_size;
const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn;
if (in_prompt) {
if (app.verbosity >= 1) {
std::cerr << "." << std::flush;
}
return true;
}
std::string token_text;
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
if (first_response_token) {
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
if (app.verbosity >= 1) {
std::cout << "\n\n";
}
}
if (token_text == "<end_of_turn>") {
// We don't want to show the <end_of_turn> token to the user.
// We also need to remember that we've seen it, so that we can rewind
// abs_pos appropriately. We expect EOS as the next token.
end_of_turn_seen = true;
return true;
}
std::cout << token_text << std::flush;
return true; return true;
}; };
while (true) { // Loop until user quits. while (true) { // Loop until user quits.
tokens_generated_this_turn = 0; tokens_generated_this_turn = 0;
// Read prompt and handle special commands.
std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line); std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
if (!std::cin) return; if (!std::cin) return;
// If !eot_line.empty(), we append \n, so only look at the first 2 chars. // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
@ -155,23 +163,22 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
continue; continue;
} }
} }
if (prompt_string.empty()) {
if (have_image && abs_pos != 0) { std::cout << "Use '%q' to quit.\n";
// This occurs when we have hit max_generated. continue;
abs_pos = 0;
} }
// Wrap, tokenize and maybe log prompt tokens.
std::vector<int> prompt = WrapAndTokenize( std::vector<int> prompt = WrapAndTokenize(
model.Tokenizer(), model.Info(), abs_pos, prompt_string); model.Tokenizer(), model.Info(), abs_pos, prompt_string);
prompt_size = prompt.size(); prompt_size = prompt.size();
std::cerr << "\n"
<< "[ Reading prompt ] " << std::flush;
if constexpr (kVerboseLogTokens) { if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) { for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
} }
} }
// Set up runtime config.
TimingInfo timing_info = {.verbosity = app.verbosity}; TimingInfo timing_info = {.verbosity = app.verbosity};
RuntimeConfig runtime_config = {.gen = &gen, RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = app.verbosity, .verbosity = app.verbosity,
@ -190,9 +197,38 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
// We need to look at all the tokens for the prefix. // We need to look at all the tokens for the prefix.
runtime_config.prefill_tbatch_size = prompt_size; runtime_config.prefill_tbatch_size = prompt_size;
} }
// Generate until EOS or max_generated_tokens.
if (app.verbosity >= 1) {
std::cerr << "\n[ Reading prompt ] " << std::flush;
}
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
timing_info); timing_info);
std::cout << "\n\n"; std::cout << "\n\n";
// Prepare for the next turn.
if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0.
InitGenerator(args, gen);
} else {
// The last token was either EOS, then it should be ignored because it is
// never part of the dialog, see Table 5 in the Gemma-2 paper:
// https://arxiv.org/pdf/2408.00118
// Or we have hit max_generated_tokens, then the last token will be lost.
// (We could store it in stream_token, and then prepend to the next turn,
// but it's not worth the complexity, as multi-turn with max_generated is
// not a common use case.)
// In either case, we need to rewind abs_pos by one.
HWY_ASSERT(abs_pos > 0);
abs_pos--;
}
if (end_of_turn_seen && abs_pos > 0) {
// If we have seen an end_of_turn token, we need to rewind abs_pos by one
// more, because we will pre-pend it again to the prompt in
// WrapAndTokenize.
abs_pos--;
}
end_of_turn_seen = false;
} }
} }