mirror of https://github.com/google/gemma.cpp.git

commit 18e6012872 (parent c6eb3b6f0d)

Fix prefill for batched queries.

This lets gemma_test/GeographyBatched now pass for gemma2-27B as well.
PiperOrigin-RevId: 664827485
@@ -149,6 +149,45 @@ TEST_F(GemmaTest, Arithmetic) {
   TestQuestions(kQA, kNum, /*batch=*/false);
 }
 
+TEST_F(GemmaTest, Multiturn) {
+  Gemma* model = s_env->GetModel();
+  ASSERT_NE(model, nullptr);
+  size_t abs_pos = 0;
+  std::string dialog;
+  auto stream_token = [&](int token, float) {
+    ++abs_pos;
+    std::string token_text;
+    EXPECT_TRUE(
+        model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
+    dialog += token_text;
+    return true;
+  };
+  RuntimeConfig runtime_config{
+      .max_tokens = 128,
+      .max_generated_tokens = 64,
+      .temperature = 0.0f,
+      .verbosity = 2,
+      .gen = &s_env->MutableGen(),
+      .stream_token = stream_token,
+  };
+  TimingInfo timing_info{.verbosity = 0};
+  // First "say" something slightly unusual.
+  std::string mutable_prompt = "The color of my car is turquoise.";
+  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
+                                            abs_pos, mutable_prompt);
+  model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
+                  timing_info);
+  mutable_prompt = "Can you repeat to me what I just said?";
+  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
+                           mutable_prompt);
+  // Reset the `dialog` string here, then check that the model actually has
+  // access to the previous turn by asking to reproduce.
+  dialog.clear();
+  model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
+                  timing_info);
+  EXPECT_TRUE(dialog.find("turquoise") != std::string::npos);  // NOLINT
+}
+
 static const char kJingleBells[] = R"(
 Dashing through the snow
 In a one-horse open sleigh
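Context for the new test: every streamed token, prompt or generated, runs the stream_token lambda and advances abs_pos, so the second WrapAndTokenize call positions the new prompt directly after the first turn in the shared KVCache. A minimal standalone sketch of that bookkeeping, with invented token counts:

    #include <cstddef>

    int main() {
      // Hypothetical first turn: 7 prompt tokens streamed during prefill,
      // then 12 generated tokens, each advancing abs_pos via the callback.
      size_t abs_pos = 0;
      abs_pos += 7 + 12;
      // The second prompt is tokenized at abs_pos, i.e. right after the
      // first turn, so generation can attend to the earlier dialog.
      return abs_pos == 19 ? 0 : 1;
    }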
@@ -637,7 +637,7 @@ using QueriesMutablePos = hwy::Span<size_t>;
 // Populates KV cache for batches of tokens from one query at a time.
 template <class TConfig>
 HWY_NOINLINE void Prefill(
-    const QueriesPromptTokens& queries_prompt, const size_t prefill_per_query,
+    const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_pos, const size_t query_idx_start,
     const CompressedWeights<TConfig>& weights, Activations& activations,
     const RuntimeConfig& runtime_config, const hwy::Divisor& div_seq_len,
@@ -665,6 +665,7 @@ HWY_NOINLINE void Prefill(
     QueriesPos single_query_pos(&queries_pos[qi], 1);
     KVCaches single_kv_cache(&kv_caches[qi], 1);
 
+    const size_t prefill_per_query = queries_prompt[qi].size() - 1;
     // For each batch of tokens in the query:
     for (size_t tbatch_start = 0; tbatch_start < prefill_per_query;
          tbatch_start += max_tbatch_size) {
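With the parameter gone from the signature (previous hunk), each query now derives its own prefill length here: all prompt tokens except the last, which becomes the first input for generation. A minimal standalone sketch of the rule, with invented prompts:

    #include <cstddef>
    #include <vector>

    int main() {
      // Two hypothetical queries of different prompt lengths in one batch.
      std::vector<std::vector<int>> prompts = {{4, 8, 15}, {16, 23, 42, 4, 8}};
      const size_t prefill_q0 = prompts[0].size() - 1;  // 2
      const size_t prefill_q1 = prompts[1].size() - 1;  // 4, no longer clamped
                                                        // to the batch minimum
      return (prefill_q0 == 2 && prefill_q1 == 4) ? 0 : 1;
    }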
@@ -688,7 +689,8 @@ HWY_NOINLINE void Prefill(
       // NOTE: we unconditionally call StreamToken, even if EOS.
       for (size_t ti = 0; ti < tbatch_size; ++ti) {
         const size_t pos = queries_pos[qi] + ti;
-        const int token = queries_prompt[qi][pos];
+        const size_t pos_in_prompt = tbatch_start + ti;
+        const int token = queries_prompt[qi][pos_in_prompt];
         runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
       }
 
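This is the heart of the fix: pos indexes the KV cache and keeps growing across turns, while the prompt must be indexed from its own start; the two coincide only when the query begins at position zero. A standalone sketch of how they diverge on a second turn (all numbers invented):

    #include <cstddef>
    #include <cstdio>

    int main() {
      const size_t query_pos = 10;    // 10 tokens already in the KV cache
      const size_t tbatch_start = 0;  // first token batch of the new prompt
      for (size_t ti = 0; ti < 4; ++ti) {
        const size_t pos = query_pos + ti;               // KV-cache position
        const size_t pos_in_prompt = tbatch_start + ti;  // prompt position
        // Before the fix, the prompt was indexed with pos (10..13) instead
        // of pos_in_prompt (0..3), reading past the new prompt's tokens.
        std::printf("pos=%zu pos_in_prompt=%zu\n", pos, pos_in_prompt);
      }
      return 0;
    }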
@@ -780,15 +782,12 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
 // Placeholder for internal test3, do not remove
 
 // Returns the min and max number of tokens for all queries.
-static void ScanQueryLengths(const QueriesPromptTokens& queries_prompt,
-                             size_t& min_prompt_size, size_t& max_prompt_size) {
-  const size_t num_queries = queries_prompt.size();
-  min_prompt_size = hwy::LimitsMax<size_t>();
-  max_prompt_size = 0;
-  for (size_t i = 0; i < num_queries; ++i) {
-    min_prompt_size = std::min(min_prompt_size, queries_prompt[i].size());
+static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
+  size_t max_prompt_size = 0;
+  for (size_t i = 0; i < queries_prompt.size(); ++i) {
     max_prompt_size = std::max(max_prompt_size, queries_prompt[i].size());
   }
+  return max_prompt_size;
 }
 
 // Holds "is at end of stream" state for each query.
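Only the maximum prompt length remains useful (it feeds RangeChecks below); the minimum is obsolete now that prefill length is per query. For comparison only, the same reduction written with the standard library over invented lengths:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    int main() {
      const std::vector<size_t> prompt_sizes = {3, 5, 8};  // hypothetical
      const size_t max_prompt_size =
          *std::max_element(prompt_sizes.begin(), prompt_sizes.end());
      return max_prompt_size == 8 ? 0 : 1;
    }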
@@ -851,9 +850,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   HWY_ASSERT(kv_caches.size() == num_queries);
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
 
-  size_t min_prompt_size, max_prompt_size;
-  ScanQueryLengths(queries_prompt, min_prompt_size, max_prompt_size);
-
+  size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_tokens = runtime_config.max_tokens;
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks<TConfig>(max_tokens, max_generated_tokens, max_prompt_size);
@@ -877,7 +874,6 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
 
   // Prefill stops before min_prompt_size - 1 because the last prompt token is
   // the first input token for generation.
-  const size_t prefill_per_query = min_prompt_size - 1;
   const double prefill_start = hwy::platform::Now();
   // If tbatch is larger than the qbatch we already have in `activations`, then
   // allocate prefill_activations, otherwise reuse.
@@ -888,11 +884,16 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
     prefill_activations.Allocate<TConfig>(runtime_config.prefill_tbatch_size,
                                           activations.env.Pools());
   }
-  Prefill<TConfig>(queries_prompt, prefill_per_query, queries_mutable_pos,
-                   query_idx_start, weights,
+  Prefill<TConfig>(queries_prompt, queries_mutable_pos, query_idx_start,
+                   weights,
                    use_prefill_activations ? prefill_activations : activations,
                    runtime_config, div_seq_len, kv_caches);
-  timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start);
+  // Compute the number of tokens that were prefilled and notify timing_info.
+  size_t prefilled_tokens = 0;
+  for (size_t qi = 0; qi < num_queries; ++qi) {
+    prefilled_tokens += queries_prompt[qi].size() - 1;
+  }
+  timing_info.NotifyPrefill(prefilled_tokens, prefill_start);
   // queries_pos are incremented by Prefill.
 
   // Storage for the last generated token from each query, passed to the next
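The token count reported to timing_info now matches what was actually prefilled instead of assuming a uniform per-query count. A standalone check of the accumulation against the old formula, with invented prompt lengths:

    #include <cstddef>
    #include <vector>

    int main() {
      const std::vector<size_t> prompt_sizes = {3, 5, 8};  // hypothetical
      size_t prefilled_tokens = 0;
      for (size_t qi = 0; qi < prompt_sizes.size(); ++qi) {
        prefilled_tokens += prompt_sizes[qi] - 1;  // 2 + 4 + 7
      }
      // Old formula: (min_prompt_size - 1) * num_queries == 2 * 3 == 6,
      // undercounting the 13 tokens actually prefilled.
      return prefilled_tokens == 13 ? 0 : 1;
    }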
@@ -902,7 +903,9 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
   // Stream the last prompt token from each query and fill gen_tokens.
   TokenStreamer token_streamer(runtime_config);
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    gen_tokens[query_idx] = queries_prompt[query_idx][prefill_per_query];
+    size_t last_token_pos_in_prompt =
+        queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
+    gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt];
     (void)token_streamer(query_idx_start + query_idx,
                          queries_mutable_pos[query_idx], gen_tokens[query_idx],
                          0.0f);
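Since each query now prefills up to its own prompt length, the last prompt token's index can no longer be a shared constant; it is recovered as how far Prefill advanced this query's position. A standalone sketch with invented positions:

    #include <cstddef>

    int main() {
      // Hypothetical query: second turn starting at absolute position 10,
      // prompt of 5 tokens, so Prefill advanced the position by 5 - 1 = 4.
      const size_t query_pos_in = 10;       // position before Prefill
      const size_t query_mutable_pos = 14;  // position after Prefill
      const size_t last_token_pos_in_prompt = query_mutable_pos - query_pos_in;
      return last_token_pos_in_prompt == 4 ? 0 : 1;  // indexes the 5th token
    }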