diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index a0405b3..c2b7793 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -149,6 +149,45 @@ TEST_F(GemmaTest, Arithmetic) {
   TestQuestions(kQA, kNum, /*batch=*/false);
 }
 
+TEST_F(GemmaTest, Multiturn) {
+  Gemma* model = s_env->GetModel();
+  ASSERT_NE(model, nullptr);
+  size_t abs_pos = 0;
+  std::string dialog;
+  auto stream_token = [&](int token, float) {
+    ++abs_pos;
+    std::string token_text;
+    EXPECT_TRUE(
+        model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
+    dialog += token_text;
+    return true;
+  };
+  RuntimeConfig runtime_config{
+      .max_tokens = 128,
+      .max_generated_tokens = 64,
+      .temperature = 0.0f,
+      .verbosity = 2,
+      .gen = &s_env->MutableGen(),
+      .stream_token = stream_token,
+  };
+  TimingInfo timing_info{.verbosity = 0};
+  // First "say" something slightly unusual.
+  std::string mutable_prompt = "The color of my car is turquoise.";
+  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
+                                            abs_pos, mutable_prompt);
+  model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
+                  timing_info);
+  mutable_prompt = "Can you repeat to me what I just said?";
+  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
+                           mutable_prompt);
+  // Reset the `dialog` string here, then check that the model actually has
+  // access to the previous turn by asking to reproduce.
+  dialog.clear();
+  model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
+                  timing_info);
+  EXPECT_TRUE(dialog.find("turquoise") != std::string::npos);  // NOLINT
+}
+
 static const char kJingleBells[] = R"(
 Dashing through the snow
 In a one-horse open sleigh
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 0f14790..24ac462 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -637,7 +637,7 @@ using QueriesMutablePos = hwy::Span<size_t>;
 // Populates KV cache for batches of tokens from one query at a time.
 template <class TConfig>
 HWY_NOINLINE void Prefill(
-    const QueriesPromptTokens& queries_prompt, const size_t prefill_per_query,
+    const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_pos, const size_t query_idx_start,
     const CompressedWeights<TConfig>& weights, Activations& activations,
     const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
@@ -665,6 +665,7 @@ HWY_NOINLINE void Prefill(
     QueriesPos single_query_pos(&queries_pos[qi], 1);
     KVCaches single_kv_cache(&kv_caches[qi], 1);
 
+    const size_t prefill_per_query = queries_prompt[qi].size() - 1;
     // For each batch of tokens in the query:
     for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
          tbatch_start += max_tbatch_size) {
@@ -688,7 +689,8 @@ HWY_NOINLINE void Prefill(
       // NOTE: we unconditionally call StreamToken, even if EOS.
       for (size_t ti = 0; ti < tbatch_size; ++ti) {
         const size_t pos = queries_pos[qi] + ti;
-        const int token = queries_prompt[qi][pos];
+        const size_t pos_in_prompt = tbatch_start + ti;
+        const int token = queries_prompt[qi][pos_in_prompt];
         runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
       }
 
@@ -780,15 +782,12 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
 // Placeholder for internal test3, do not remove
 
-// Returns the min and max number of tokens for all queries.
-static void ScanQueryLengths(const QueriesPromptTokens& queries_prompt,
-                             size_t& min_prompt_size, size_t& max_prompt_size) {
-  const size_t num_queries = queries_prompt.size();
-  min_prompt_size = hwy::LimitsMax<size_t>();
-  max_prompt_size = 0;
-  for (size_t i = 0; i < num_queries; ++i) {
-    min_prompt_size = std::min(min_prompt_size, queries_prompt[i].size());
+// Returns the max number of tokens across all queries.
+static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
+  size_t max_prompt_size = 0;
+  for (size_t i = 0; i < queries_prompt.size(); ++i) {
     max_prompt_size = std::max(max_prompt_size, queries_prompt[i].size());
   }
+  return max_prompt_size;
 }
 
 // Holds "is at end of stream" state for each query.
@@ -851,9 +850,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
 
   HWY_ASSERT(kv_caches.size() == num_queries);
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
-  size_t min_prompt_size, max_prompt_size;
-  ScanQueryLengths(queries_prompt, min_prompt_size, max_prompt_size);
-
+  size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_tokens = runtime_config.max_tokens;
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(max_tokens, max_generated_tokens, max_prompt_size);
@@ -877,7 +874,6 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
 
-  // Prefill stops before min_prompt_size - 1 because the last prompt token is
-  // the first input token for generation.
-  const size_t prefill_per_query = min_prompt_size - 1;
+  // Prefill stops before each query's last prompt token, which is the
+  // first input token for generation.
   const double prefill_start = hwy::platform::Now();
   // If tbatch is larger than the qbatch we already have in `activations`, then
   // allocate prefill_activations, otherwise reuse.
@@ -888,11 +884,16 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     prefill_activations.Allocate(runtime_config.prefill_tbatch_size,
                                  activations.env.Pools());
   }
-  Prefill(queries_prompt, prefill_per_query, queries_mutable_pos,
-          query_idx_start, weights,
+  Prefill(queries_prompt, queries_mutable_pos, query_idx_start,
+          weights,
           use_prefill_activations ? prefill_activations : activations,
          runtime_config, div_seq_len, kv_caches);
-  timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
+  // Compute the number of tokens that were prefilled and notify timing_info.
+  size_t prefilled_tokens = 0;
+  for (size_t qi = 0; qi < num_queries; ++qi) {
+    prefilled_tokens += queries_prompt[qi].size() - 1;
+  }
+  timing_info.NotifyPrefill(prefilled_tokens, prefill_start);
   // queries_pos are incremented by Prefill.
 
   // Storage for the last generated token from each query, passed to the next
@@ -902,7 +903,9 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   // Stream the last prompt token from each query and fill gen_tokens.
   TokenStreamer token_streamer(runtime_config);
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    gen_tokens[query_idx] = queries_prompt[query_idx][prefill_per_query];
+    size_t last_token_pos_in_prompt =
+        queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
+    gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt];
     (void)token_streamer(query_idx_start + query_idx,
                          queries_mutable_pos[query_idx], gen_tokens[query_idx],
                          0.0f);
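
Review note: the core of this change is per-query position bookkeeping. Each query now prefills queries_prompt[qi].size() - 1 tokens (the final prompt token instead seeds generation), and the mutable position advances past the prefilled tokens, so a later Generate call continues from where the previous turn stopped. This is also why Prefill no longer takes a single prefill_per_query: with multiturn and mixed-length batches, each query derives its own length. The standalone sketch below illustrates only that arithmetic; PrefillLength, pos, and the token values are illustrative and not part of the patch.

#include <cstddef>
#include <cstdio>
#include <vector>

// Mirrors the patch's rule: prefill covers all but the last prompt token,
// which instead seeds generation. (Hypothetical helper, not from the patch.)
static size_t PrefillLength(const std::vector<int>& prompt) {
  return prompt.size() - 1;
}

int main() {
  size_t pos = 0;  // Carried across turns, like `abs_pos` in the new test.
  const std::vector<std::vector<int>> turns = {
      {10, 11, 12, 13},  // Turn 1: four prompt tokens.
      {20, 21, 22},      // Turn 2: three prompt tokens.
  };
  for (const std::vector<int>& prompt : turns) {
    const size_t prefill = PrefillLength(prompt);
    // The last prompt token sits at index `prefill` and is streamed as the
    // first generation input, matching `last_token_pos_in_prompt` above.
    const int seed_token = prompt[prefill];
    pos += prefill;  // Prefill advances the per-query position.
    std::printf("prefilled %zu tokens, seed token %d, pos now %zu\n", prefill,
                seed_token, pos);
    // Generation would stream further tokens and keep advancing `pos`, so the
    // next turn's prompt lands at the correct offset in the KV cache.
  }
  return 0;
}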