mirror of https://github.com/google/gemma.cpp.git
Make prompt wrapping more consistent and fix duplicated tokens for multi-turn.
Do not echo <end_of_turn> tokens to the user. Have verbosity=0 only show the dialog. PiperOrigin-RevId: 705021391
This commit is contained in:
parent
e69bc3bc1c
commit
aed17396be
|
|
@ -157,13 +157,14 @@ TEST_F(GemmaTest, Multiturn) {
|
|||
Gemma* model = s_env->GetModel();
|
||||
ASSERT_NE(model, nullptr);
|
||||
size_t abs_pos = 0;
|
||||
std::string dialog;
|
||||
std::string response;
|
||||
auto stream_token = [&](int token, float) {
|
||||
if (token == EOS_ID) return true;
|
||||
++abs_pos;
|
||||
std::string token_text;
|
||||
EXPECT_TRUE(
|
||||
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
dialog += token_text;
|
||||
response += token_text;
|
||||
return true;
|
||||
};
|
||||
RuntimeConfig runtime_config{
|
||||
|
|
@ -180,18 +181,21 @@ TEST_F(GemmaTest, Multiturn) {
|
|||
abs_pos, mutable_prompt);
|
||||
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
|
||||
timing_info);
|
||||
// Note: we do not rewind any <end_of_turn> tokens here. If the model
|
||||
// produced one and WrapAndTokenize() inserts another one, it will just be
|
||||
// duplicated.
|
||||
mutable_prompt = "Please repeat all prior statements.";
|
||||
tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
|
||||
mutable_prompt);
|
||||
// Reset the `dialog` string here, then check that the model actually has
|
||||
// Reset the `response` string here, then check that the model actually has
|
||||
// access to the previous turn by asking to reproduce.
|
||||
dialog.clear();
|
||||
response.clear();
|
||||
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
|
||||
timing_info);
|
||||
fprintf(stderr, "decoded: %s\n", dialog.c_str());
|
||||
fprintf(stderr, "decoded: %s\n", response.c_str());
|
||||
bool remembered_turquoise =
|
||||
dialog.find("turquoise") != std::string::npos; // NOLINT
|
||||
bool remembered_car = dialog.find("car") != std::string::npos; // NOLINT
|
||||
response.find("turquoise") != std::string::npos; // NOLINT
|
||||
bool remembered_car = response.find("car") != std::string::npos; // NOLINT
|
||||
EXPECT_TRUE(remembered_turquoise || remembered_car);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1249,16 +1249,9 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
|||
// Copy so we can increment without requiring users to pass in a mutable span.
|
||||
std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
|
||||
queries_pos_in.cend());
|
||||
QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
|
||||
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
|
||||
queries_pos_copy.size());
|
||||
// For the first turn, qpos remains 0. Otherwise, rewind the previous EOS.
|
||||
// Background: for multiturn, Gemma 2 expects only <end_of_turn>, not EOS. The
|
||||
// previous `Generate` called `StreamToken` for the last token (EOS), hence
|
||||
// our caller's qpos is 1 too high. This must be corrected because we didn't
|
||||
// write to the KV cache at that position, so MSAN would complain.
|
||||
for (size_t& qpos : queries_mutable_pos) {
|
||||
qpos = qpos == 0 ? 0 : qpos - 1;
|
||||
}
|
||||
|
||||
// Sanity check: prompts should not be empty, nor start with EOS.
|
||||
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
|
||||
const PromptTokens& prompt = queries_prompt[query_idx];
|
||||
|
|
|
|||
80
gemma/run.cc
80
gemma/run.cc
|
|
@ -85,6 +85,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
size_t abs_pos = 0; // across turns
|
||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||
size_t prompt_size = 0;
|
||||
bool end_of_turn_seen = false;
|
||||
|
||||
std::mt19937 gen;
|
||||
InitGenerator(args, gen);
|
||||
|
|
@ -114,37 +115,44 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
// callback function invoked for each generated token.
|
||||
auto stream_token = [&](int token, float) {
|
||||
++abs_pos;
|
||||
++tokens_generated_this_turn;
|
||||
// <= since position is incremented before
|
||||
if (tokens_generated_this_turn <= prompt_size) {
|
||||
std::cerr << "." << std::flush;
|
||||
} else if (token == EOS_ID) {
|
||||
if (!args.multiturn) {
|
||||
abs_pos = 0;
|
||||
InitGenerator(args, gen);
|
||||
}
|
||||
if (token == EOS_ID) {
|
||||
if (app.verbosity >= 2) {
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
const bool in_prompt = tokens_generated_this_turn < prompt_size;
|
||||
const bool first_response_token = tokens_generated_this_turn == prompt_size;
|
||||
++tokens_generated_this_turn;
|
||||
if (in_prompt) {
|
||||
if (app.verbosity >= 1) {
|
||||
std::cerr << "." << std::flush;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
// +1 since position is incremented above
|
||||
if (tokens_generated_this_turn == prompt_size + 1) {
|
||||
// first token of response
|
||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
if (first_response_token) {
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
if (app.verbosity >= 1) {
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
}
|
||||
std::cout << token_text << std::flush;
|
||||
if (token_text == "<end_of_turn>") {
|
||||
// We don't want to show the <end_of_turn> token to the user.
|
||||
// We also need to remember that we've seen it, so that we can rewind
|
||||
// abs_pos appropriately. We expect EOS as the next token.
|
||||
end_of_turn_seen = true;
|
||||
return true;
|
||||
}
|
||||
std::cout << token_text << std::flush;
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) { // Loop until user quits.
|
||||
tokens_generated_this_turn = 0;
|
||||
|
||||
// Read prompt and handle special commands.
|
||||
std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
|
||||
if (!std::cin) return;
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
|
|
@ -155,23 +163,22 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (have_image && abs_pos != 0) {
|
||||
// This occurs when we have hit max_generated.
|
||||
abs_pos = 0;
|
||||
if (prompt_string.empty()) {
|
||||
std::cout << "Use '%q' to quit.\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Wrap, tokenize and maybe log prompt tokens.
|
||||
std::vector<int> prompt = WrapAndTokenize(
|
||||
model.Tokenizer(), model.Info(), abs_pos, prompt_string);
|
||||
prompt_size = prompt.size();
|
||||
std::cerr << "\n"
|
||||
<< "[ Reading prompt ] " << std::flush;
|
||||
if constexpr (kVerboseLogTokens) {
|
||||
for (int i = 0; i < prompt_size; ++i) {
|
||||
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Set up runtime config.
|
||||
TimingInfo timing_info = {.verbosity = app.verbosity};
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = app.verbosity,
|
||||
|
|
@ -190,9 +197,38 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
// We need to look at all the tokens for the prefix.
|
||||
runtime_config.prefill_tbatch_size = prompt_size;
|
||||
}
|
||||
|
||||
// Generate until EOS or max_generated_tokens.
|
||||
if (app.verbosity >= 1) {
|
||||
std::cerr << "\n[ Reading prompt ] " << std::flush;
|
||||
}
|
||||
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
|
||||
timing_info);
|
||||
std::cout << "\n\n";
|
||||
|
||||
// Prepare for the next turn.
|
||||
if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) {
|
||||
abs_pos = 0; // Start a new turn at position 0.
|
||||
InitGenerator(args, gen);
|
||||
} else {
|
||||
// The last token was either EOS, then it should be ignored because it is
|
||||
// never part of the dialog, see Table 5 in the Gemma-2 paper:
|
||||
// https://arxiv.org/pdf/2408.00118
|
||||
// Or we have hit max_generated_tokens, then the last token will be lost.
|
||||
// (We could store it in stream_token, and then prepend to the next turn,
|
||||
// but it's not worth the complexity, as multi-turn with max_generated is
|
||||
// not a common use case.)
|
||||
// In either case, we need to rewind abs_pos by one.
|
||||
HWY_ASSERT(abs_pos > 0);
|
||||
abs_pos--;
|
||||
}
|
||||
if (end_of_turn_seen && abs_pos > 0) {
|
||||
// If we have seen an end_of_turn token, we need to rewind abs_pos by one
|
||||
// more, because we will pre-pend it again to the prompt in
|
||||
// WrapAndTokenize.
|
||||
abs_pos--;
|
||||
}
|
||||
end_of_turn_seen = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue