mirror of https://github.com/google/gemma.cpp.git
Fix prefill for batched queries.
This lets gemma_test/GeographyBatched pass now also for gemma2-27B. PiperOrigin-RevId: 664827485
This commit is contained in:
parent
c6eb3b6f0d
commit
18e6012872
|
|
@ -149,6 +149,45 @@ TEST_F(GemmaTest, Arithmetic) {
|
||||||
TestQuestions(kQA, kNum, /*batch=*/false);
|
TestQuestions(kQA, kNum, /*batch=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GemmaTest, Multiturn) {
|
||||||
|
Gemma* model = s_env->GetModel();
|
||||||
|
ASSERT_NE(model, nullptr);
|
||||||
|
size_t abs_pos = 0;
|
||||||
|
std::string dialog;
|
||||||
|
auto stream_token = [&](int token, float) {
|
||||||
|
++abs_pos;
|
||||||
|
std::string token_text;
|
||||||
|
EXPECT_TRUE(
|
||||||
|
model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||||
|
dialog += token_text;
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
RuntimeConfig runtime_config{
|
||||||
|
.max_tokens = 128,
|
||||||
|
.max_generated_tokens = 64,
|
||||||
|
.temperature = 0.0f,
|
||||||
|
.verbosity = 2,
|
||||||
|
.gen = &s_env->MutableGen(),
|
||||||
|
.stream_token = stream_token,
|
||||||
|
};
|
||||||
|
TimingInfo timing_info{.verbosity = 0};
|
||||||
|
// First "say" something slightly unusual.
|
||||||
|
std::string mutable_prompt = "The color of my car is turquoise.";
|
||||||
|
std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
|
||||||
|
abs_pos, mutable_prompt);
|
||||||
|
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
|
||||||
|
timing_info);
|
||||||
|
mutable_prompt = "Can you repeat to me what I just said?";
|
||||||
|
tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
|
||||||
|
mutable_prompt);
|
||||||
|
// Reset the `dialog` string here, then check that the model actually has
|
||||||
|
// access to the previous turn by asking to reproduce.
|
||||||
|
dialog.clear();
|
||||||
|
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
|
||||||
|
timing_info);
|
||||||
|
EXPECT_TRUE(dialog.find("turquoise") != std::string::npos); // NOLINT
|
||||||
|
}
|
||||||
|
|
||||||
static const char kJingleBells[] = R"(
|
static const char kJingleBells[] = R"(
|
||||||
Dashing through the snow
|
Dashing through the snow
|
||||||
In a one-horse open sleigh
|
In a one-horse open sleigh
|
||||||
|
|
|
||||||
|
|
@ -637,7 +637,7 @@ using QueriesMutablePos = hwy::Span<size_t>;
|
||||||
// Populates KV cache for batches of tokens from one query at a time.
|
// Populates KV cache for batches of tokens from one query at a time.
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
HWY_NOINLINE void Prefill(
|
HWY_NOINLINE void Prefill(
|
||||||
const QueriesPromptTokens& queries_prompt, const size_t prefill_per_query,
|
const QueriesPromptTokens& queries_prompt,
|
||||||
const QueriesMutablePos& queries_pos, const size_t query_idx_start,
|
const QueriesMutablePos& queries_pos, const size_t query_idx_start,
|
||||||
const CompressedWeights<TConfig>& weights, Activations& activations,
|
const CompressedWeights<TConfig>& weights, Activations& activations,
|
||||||
const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
|
const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
|
||||||
|
|
@ -665,6 +665,7 @@ HWY_NOINLINE void Prefill(
|
||||||
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
QueriesPos single_query_pos(&queries_pos[qi], 1);
|
||||||
KVCaches single_kv_cache(&kv_caches[qi], 1);
|
KVCaches single_kv_cache(&kv_caches[qi], 1);
|
||||||
|
|
||||||
|
const size_t prefill_per_query = queries_prompt[qi].size() - 1;
|
||||||
// For each batch of tokens in the query:
|
// For each batch of tokens in the query:
|
||||||
for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
|
for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
|
||||||
tbatch_start += max_tbatch_size) {
|
tbatch_start += max_tbatch_size) {
|
||||||
|
|
@ -688,7 +689,8 @@ HWY_NOINLINE void Prefill(
|
||||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||||
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
for (size_t ti = 0; ti < tbatch_size; ++ti) {
|
||||||
const size_t pos = queries_pos[qi] + ti;
|
const size_t pos = queries_pos[qi] + ti;
|
||||||
const int token = queries_prompt[qi][pos];
|
const size_t pos_in_prompt = tbatch_start + ti;
|
||||||
|
const int token = queries_prompt[qi][pos_in_prompt];
|
||||||
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
|
runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -780,15 +782,12 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
|
||||||
// Placeholder for internal test3, do not remove
|
// Placeholder for internal test3, do not remove
|
||||||
|
|
||||||
// Returns the min and max number of tokens for all queries.
|
// Returns the min and max number of tokens for all queries.
|
||||||
static void ScanQueryLengths(const QueriesPromptTokens& queries_prompt,
|
static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
|
||||||
size_t& min_prompt_size, size_t& max_prompt_size) {
|
size_t max_prompt_size = 0;
|
||||||
const size_t num_queries = queries_prompt.size();
|
for (size_t i = 0; i < queries_prompt.size(); ++i) {
|
||||||
min_prompt_size = hwy::LimitsMax<size_t>();
|
|
||||||
max_prompt_size = 0;
|
|
||||||
for (size_t i = 0; i < num_queries; ++i) {
|
|
||||||
min_prompt_size = std::min(min_prompt_size, queries_prompt[i].size());
|
|
||||||
max_prompt_size = std::max(max_prompt_size, queries_prompt[i].size());
|
max_prompt_size = std::max(max_prompt_size, queries_prompt[i].size());
|
||||||
}
|
}
|
||||||
|
return max_prompt_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Holds "is at end of stream" state for each query.
|
// Holds "is at end of stream" state for each query.
|
||||||
|
|
@ -851,9 +850,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
HWY_ASSERT(kv_caches.size() == num_queries);
|
HWY_ASSERT(kv_caches.size() == num_queries);
|
||||||
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
||||||
|
|
||||||
size_t min_prompt_size, max_prompt_size;
|
size_t max_prompt_size = MaxQueryLength(queries_prompt);
|
||||||
ScanQueryLengths(queries_prompt, min_prompt_size, max_prompt_size);
|
|
||||||
|
|
||||||
size_t max_tokens = runtime_config.max_tokens;
|
size_t max_tokens = runtime_config.max_tokens;
|
||||||
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
size_t max_generated_tokens = runtime_config.max_generated_tokens;
|
||||||
RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
|
RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
|
||||||
|
|
@ -877,7 +874,6 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
|
|
||||||
// Prefill stops before min_prompt_size - 1 because the last prompt token is
|
// Prefill stops before min_prompt_size - 1 because the last prompt token is
|
||||||
// the first input token for generation.
|
// the first input token for generation.
|
||||||
const size_t prefill_per_query = min_prompt_size - 1;
|
|
||||||
const double prefill_start = hwy::platform::Now();
|
const double prefill_start = hwy::platform::Now();
|
||||||
// If tbatch is larger than the qbatch we already have in `activations`, then
|
// If tbatch is larger than the qbatch we already have in `activations`, then
|
||||||
// allocate prefill_activations, otherwise reuse.
|
// allocate prefill_activations, otherwise reuse.
|
||||||
|
|
@ -888,11 +884,16 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
prefill_activations.Allocate<TConfig>(runtime_config.prefill_tbatch_size,
|
prefill_activations.Allocate<TConfig>(runtime_config.prefill_tbatch_size,
|
||||||
activations.env.Pools());
|
activations.env.Pools());
|
||||||
}
|
}
|
||||||
Prefill<TConfig>(queries_prompt, prefill_per_query, queries_mutable_pos,
|
Prefill<TConfig>(queries_prompt, queries_mutable_pos, query_idx_start,
|
||||||
query_idx_start, weights,
|
weights,
|
||||||
use_prefill_activations ? prefill_activations : activations,
|
use_prefill_activations ? prefill_activations : activations,
|
||||||
runtime_config, div_seq_len, kv_caches);
|
runtime_config, div_seq_len, kv_caches);
|
||||||
timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
|
// Compute the number of tokens that were prefilled and notify timing_info.
|
||||||
|
size_t prefilled_tokens = 0;
|
||||||
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
|
prefilled_tokens += queries_prompt[qi].size() - 1;
|
||||||
|
}
|
||||||
|
timing_info.NotifyPrefill(prefilled_tokens, prefill_start);
|
||||||
// queries_pos are incremented by Prefill.
|
// queries_pos are incremented by Prefill.
|
||||||
|
|
||||||
// Storage for the last generated token from each query, passed to the next
|
// Storage for the last generated token from each query, passed to the next
|
||||||
|
|
@ -902,7 +903,9 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
|
||||||
// Stream the last prompt token from each query and fill gen_tokens.
|
// Stream the last prompt token from each query and fill gen_tokens.
|
||||||
TokenStreamer token_streamer(runtime_config);
|
TokenStreamer token_streamer(runtime_config);
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
gen_tokens[query_idx] = queries_prompt[query_idx][prefill_per_query];
|
size_t last_token_pos_in_prompt =
|
||||||
|
queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
|
||||||
|
gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt];
|
||||||
(void)token_streamer(query_idx_start + query_idx,
|
(void)token_streamer(query_idx_start + query_idx,
|
||||||
queries_mutable_pos[query_idx], gen_tokens[query_idx],
|
queries_mutable_pos[query_idx], gen_tokens[query_idx],
|
||||||
0.0f);
|
0.0f);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue