Fix prefill for batched queries.

This lets gemma_test/GeographyBatched now also pass for gemma2-27B.
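
Background: Prefill() used to receive a single prefill_per_query, computed
in GenerateT() as min_prompt_size - 1 over the whole batch, and indexed each
prompt with the absolute KV-cache position. In a batch with prompts of
different lengths, longer prompts were therefore only partially prefilled and
tokens were read from the wrong offsets. Prefill() now derives the length per
query as queries_prompt[qi].size() - 1 and indexes tokens by their position
within the prompt; GenerateT() reads each query's last prompt token at its
own offset and sums the per-query lengths for timing.

A minimal sketch of the corrected per-query accounting (illustrative code
only, not the gemma.cpp API; Prompt stands in for the real token spans):

    #include <cstddef>
    #include <vector>

    using Prompt = std::vector<int>;

    // Each query prefills all but its own last prompt token; that last
    // token becomes the first input of the generation phase.
    size_t PrefillLength(const Prompt& prompt) { return prompt.size() - 1; }

    // Total prefilled tokens across a batch, as now reported to TimingInfo.
    // The old code reported (min_prompt_size - 1) * num_queries instead.
    size_t TotalPrefilled(const std::vector<Prompt>& prompts) {
      size_t total = 0;
      for (const Prompt& p : prompts) total += PrefillLength(p);
      return total;
    }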

PiperOrigin-RevId: 664827485
Daniel Keysers 2024-08-19 08:50:12 -07:00 committed by Copybara-Service
parent c6eb3b6f0d
commit 18e6012872
2 changed files with 59 additions and 17 deletions

View File

@@ -149,6 +149,45 @@ TEST_F(GemmaTest, Arithmetic) {
   TestQuestions(kQA, kNum, /*batch=*/false);
 }
 
+TEST_F(GemmaTest, Multiturn) {
+  Gemma* model = s_env->GetModel();
+  ASSERT_NE(model, nullptr);
+  size_t abs_pos = 0;
+  std::string dialog;
+  auto stream_token = [&](int token, float) {
+    ++abs_pos;
+    std::string token_text;
+    EXPECT_TRUE(
+        model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
+    dialog += token_text;
+    return true;
+  };
+  RuntimeConfig runtime_config{
+      .max_tokens = 128,
+      .max_generated_tokens = 64,
+      .temperature = 0.0f,
+      .verbosity = 2,
+      .gen = &s_env->MutableGen(),
+      .stream_token = stream_token,
+  };
+  TimingInfo timing_info{.verbosity = 0};
+  // First "say" something slightly unusual.
+  std::string mutable_prompt = "The color of my car is turquoise.";
+  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
+                                            abs_pos, mutable_prompt);
+  model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
+                  timing_info);
+  mutable_prompt = "Can you repeat to me what I just said?";
+  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
+                           mutable_prompt);
+  // Reset the `dialog` string here, then check that the model actually has
+  // access to the previous turn by asking to reproduce.
+  dialog.clear();
+  model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
+                  timing_info);
+  EXPECT_TRUE(dialog.find("turquoise") != std::string::npos);  // NOLINT
+}
+
 static const char kJingleBells[] = R"(
 Dashing through the snow
 In a one-horse open sleigh

View File

@@ -637,7 +637,7 @@ using QueriesMutablePos = hwy::Span<size_t>;
 // Populates KV cache for batches of tokens from one query at a time.
 template <class TConfig>
 HWY_NOINLINE void Prefill(
-    const QueriesPromptTokens& queries_prompt, const size_t prefill_per_query,
+    const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_pos, const size_t query_idx_start,
     const CompressedWeights<TConfig>& weights, Activations& activations,
     const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
@@ -665,6 +665,7 @@ HWY_NOINLINE void Prefill(
     QueriesPos single_query_pos(&queries_pos[qi], 1);
     KVCaches single_kv_cache(&kv_caches[qi], 1);
+    const size_t prefill_per_query = queries_prompt[qi].size() - 1;
 
     // For each batch of tokens in the query:
     for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
          tbatch_start += max_tbatch_size) {
@@ -688,7 +689,8 @@ HWY_NOINLINE void Prefill(
       // NOTE: we unconditionally call StreamToken, even if EOS.
       for (size_t ti = 0; ti < tbatch_size; ++ti) {
         const size_t pos = queries_pos[qi] + ti;
-        const int token = queries_prompt[qi][pos];
+        const size_t pos_in_prompt = tbatch_start + ti;
+        const int token = queries_prompt[qi][pos_in_prompt];
         runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
       }
 
@@ -780,15 +782,12 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
 
 // Placeholder for internal test3, do not remove
 // Returns the min and max number of tokens for all queries.
-static void ScanQueryLengths(const QueriesPromptTokens& queries_prompt,
-                             size_t& min_prompt_size, size_t& max_prompt_size) {
-  const size_t num_queries = queries_prompt.size();
-  min_prompt_size = hwy::LimitsMax<size_t>();
-  max_prompt_size = 0;
-  for (size_t i = 0; i < num_queries; ++i) {
-    min_prompt_size = std::min(min_prompt_size, queries_prompt[i].size());
+static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
+  size_t max_prompt_size = 0;
+  for (size_t i = 0; i < queries_prompt.size(); ++i) {
     max_prompt_size = std::max(max_prompt_size, queries_prompt[i].size());
   }
+  return max_prompt_size;
 }
 
 // Holds "is at end of stream" state for each query.
@@ -851,9 +850,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   HWY_ASSERT(kv_caches.size() == num_queries);
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
 
-  size_t min_prompt_size, max_prompt_size;
-  ScanQueryLengths(queries_prompt, min_prompt_size, max_prompt_size);
-
+  size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_tokens = runtime_config.max_tokens;
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
@@ -877,7 +874,6 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
 
   // Prefill stops before min_prompt_size - 1 because the last prompt token is
   // the first input token for generation.
-  const size_t prefill_per_query = min_prompt_size - 1;
   const double prefill_start = hwy::platform::Now();
   // If tbatch is larger than the qbatch we already have in `activations`, then
   // allocate prefill_activations, otherwise reuse.
@@ -888,11 +884,16 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     prefill_activations.Allocate<TConfig>(runtime_config.prefill_tbatch_size,
                                           activations.env.Pools());
   }
-  Prefill<TConfig>(queries_prompt, prefill_per_query, queries_mutable_pos,
-                   query_idx_start, weights,
+  Prefill<TConfig>(queries_prompt, queries_mutable_pos, query_idx_start,
+                   weights,
                    use_prefill_activations ? prefill_activations : activations,
                    runtime_config, div_seq_len, kv_caches);
-  timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
+  // Compute the number of tokens that were prefilled and notify timing_info.
+  size_t prefilled_tokens = 0;
+  for (size_t qi = 0; qi < num_queries; ++qi) {
+    prefilled_tokens += queries_prompt[qi].size() - 1;
+  }
+  timing_info.NotifyPrefill(prefilled_tokens, prefill_start);
 
   // queries_pos are incremented by Prefill.
   // Storage for the last generated token from each query, passed to the next
@@ -902,7 +903,9 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   // Stream the last prompt token from each query and fill gen_tokens.
   TokenStreamer token_streamer(runtime_config);
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    gen_tokens[query_idx] = queries_prompt[query_idx][prefill_per_query];
+    size_t last_token_pos_in_prompt =
+        queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
+    gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt];
     (void)token_streamer(query_idx_start + query_idx,
                          queries_mutable_pos[query_idx], gen_tokens[query_idx],
                          0.0f);
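
Note on the final hunk: after Prefill() returns, each query's mutable position
has advanced by exactly its own prompt length minus one, so subtracting the
query's starting position recovers the index of its last prompt token no
matter how long the other prompts in the batch are. A small self-contained
check of that invariant (illustrative code, not from the repository):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
      // Two queries with prompts of different lengths.
      std::vector<std::vector<int>> prompts = {{1, 2, 3}, {4, 5, 6, 7, 8}};
      std::vector<size_t> pos_in = {0, 0};  // positions before prefill
      std::vector<size_t> pos = pos_in;     // mutable positions
      for (size_t qi = 0; qi < prompts.size(); ++qi) {
        pos[qi] += prompts[qi].size() - 1;  // per-query advance by Prefill
      }
      for (size_t qi = 0; qi < prompts.size(); ++qi) {
        const size_t last = pos[qi] - pos_in[qi];
        assert(prompts[qi][last] == prompts[qi].back());  // last prompt token
      }
      return 0;
    }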