6x large-batch, short-prompt prefill speedup

Parallelize over queries instead of tokens
Introduce non_eos so that we only iterate over queries that have not yet reached EOS; remove TokenStreamer.
Move RMSNormInplaceBatched out of Transformer so that prefill can call Transformer directly (the final norm is only needed before computing logits).
Consistent arg order.

Fix gemma_test EOS handling (caught by MSan); remove SECONDARY_EOS_ID from tokenizer.h.
Also add output to gemma_batch_bench; fix its test fixture name.

PiperOrigin-RevId: 769676106
Jan Wassenberg 2025-06-10 09:55:51 -07:00 committed by Copybara-Service
parent d7b23d532a
commit ec02726cf7
5 changed files with 253 additions and 199 deletions
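
The speedup comes from switching the prefill batch dimension when there are many queries with short prompts. A minimal sketch of the two loop shapes, with TransformerStep as a hypothetical placeholder for one batched Transformer pass (the real functions are PrefillTBatch and PrefillQBatch in the diff below):

#include <cstddef>
#include <vector>

// Sketch-only placeholder for one Transformer pass over `rows` batch rows.
void TransformerStep(size_t rows) { (void)rows; }

// Token-batched prefill (the previous default): one query at a time, batches
// of its own tokens. Efficient when prompts are long relative to the number
// of queries.
void PrefillTBatchSketch(const std::vector<size_t>& prompt_len,
                         size_t tbatch_size) {
  for (size_t qi = 0; qi < prompt_len.size(); ++qi) {
    for (size_t t = 0; t + 1 < prompt_len[qi]; t += tbatch_size) {
      TransformerStep(/*rows=*/tbatch_size);  // tokens of query qi only
    }
  }
}

// Query-batched prefill (new): one token position at a time, batched across
// all queries, so each pass amortizes weight loads over num_queries rows.
void PrefillQBatchSketch(const std::vector<size_t>& prompt_len,
                         size_t max_prompt_size) {
  for (size_t pos = 0; pos + 1 < max_prompt_size; ++pos) {
    TransformerStep(/*rows=*/prompt_len.size());  // one token per query
  }
}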

View File

@@ -33,7 +33,7 @@ namespace {
// non-local static variables with dtors.
GemmaEnv* s_env = nullptr;
class GemmaTest : public ::testing::Test {
class GemmaBatchBench : public ::testing::Test {
protected:
std::vector<std::string> BatchGemmaReply(
const std::vector<std::string>& inputs) {
@@ -48,7 +48,7 @@ class GemmaTest : public ::testing::Test {
}
};
TEST_F(GemmaTest, RandomQuestionsBatched) {
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
const std::vector<std::string> questions = {
{"Write me a poem about Australia?"},
{"What's the history of Denmark?"},
@@ -103,6 +103,7 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
} // namespace gcpp
int main(int argc, char** argv) {
fprintf(stderr, "GemmaEnv setup..\n");
gcpp::GemmaEnv env(argc, argv);
gcpp::s_env = &env;

View File

@@ -102,7 +102,7 @@ TEST_F(GemmaTest, Multiturn) {
size_t abs_pos = 0;
std::string response;
auto stream_token = [&](int token, float) {
if (token == EOS_ID) return true;
if (config.IsEOS(token)) return true;
++abs_pos;
std::string token_text;
EXPECT_TRUE(

View File

@@ -27,6 +27,7 @@
#include "util/allocator.h" // Allocator
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT
#include "hwy/profiler.h"
namespace gcpp {
@@ -48,6 +49,7 @@ struct Activations {
seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()),
is_griffin(config.model == Model::GRIFFIN_2B),
query_scale(ChooseQueryScale(config)),
x("x", Extents2D(batch_size, config.model_dim), pad_),
// `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
@@ -96,7 +98,7 @@ struct Activations {
layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope,
1000000.0)),
query_scale(ChooseQueryScale(config)) {
gen_tokens(batch_size) {
HWY_ASSERT(batch_size != 0);
// For MatMul outputs, precompute their row pointers.
@@ -114,6 +116,7 @@ struct Activations {
}
void SetBatchSize(size_t batch_size) {
PROFILER_ZONE("SetBatchSize");
x.OverrideRows(batch_size);
q.OverrideRows(batch_size);
logits.OverrideRows(batch_size);
@@ -134,13 +137,16 @@ struct Activations {
griffin_gate_x.OverrideRows(batch_size);
griffin_multiplier.OverrideRows(batch_size);
}
gen_tokens.resize(batch_size);
}
const ModelConfig& weights_config;
const LayerConfig& layer_config;
size_t seq_len;
size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
bool is_griffin = false;
bool is_griffin;
float query_scale;
const Extents2D none_ = Extents2D();
const MatPadding pad_ = MatPadding::kOdd;
@@ -171,7 +177,9 @@ struct Activations {
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
float query_scale;
// Storage for the last generated token from each query, passed to the next
// Transformer() call.
std::vector<int> gen_tokens; // one per query in the batch
};
} // namespace gcpp
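
The new gen_tokens member carries each query's previous token between Transformer() calls; a simplified sketch of its lifecycle across one GenerateT call (the real call sites appear in gemma.cc later in this diff):

#include <cstddef>
#include <vector>

// Simplified: how Activations::gen_tokens is filled and consumed.
void GenTokensLifecycleSketch(std::vector<int>& gen_tokens,
                              const std::vector<std::vector<int>>& prompts) {
  gen_tokens.resize(prompts.size());  // SetBatchSize does this
  for (size_t qi = 0; qi < prompts.size(); ++qi) {
    // After prefill, the last prompt token is streamed and stored here; it
    // becomes the input for the first decode step.
    gen_tokens[qi] = prompts[qi].back();
  }
  // Each DecodeStepT then reads gen_tokens as the Transformer() input batch,
  // and StreamAndUpdateEOS overwrites each slot with the newly sampled token
  // (or eos_id when the stream callback asks to stop).
}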

View File

@@ -181,21 +181,24 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
return image_token_position;
}
// Prefill() and Transformer() increment positions in-place.
// Incremented in-place by Prefill* and DecodeStepT.
using QueriesMutablePos = hwy::Span<size_t>;
// Populates KV cache for batches of tokens from one query at a time.
static HWY_NOINLINE void Prefill(
// Populates KV cache for batches of tokens from one query at a time. This is
// called if prompts are longer than the query batch size, and also in
// prefix-LM mode (end > 0), which must see all tokens in one batch.
static HWY_NOINLINE void PrefillTBatch(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const hwy::Divisor& div_seq_len, const ModelConfig& config,
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
PROFILER_ZONE("Gen.Prefill");
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
hwy::BitSet4096<>& non_eos) {
PROFILER_ZONE("Gen.PrefillT");
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(queries_pos.size() == num_queries);
HWY_DASSERT(queries_prefix_end.size() == num_queries);
HWY_DASSERT(kv_caches.size() == num_queries);
HWY_DASSERT(num_queries == queries_pos.size());
HWY_DASSERT(num_queries == queries_prefix_end.size());
HWY_DASSERT(num_queries == kv_caches.size());
// Batches are important for amortizing loading weights over multiple tokens.
// This is possible in prefill because we know all tokens beforehand, whereas
@@ -203,13 +206,15 @@ static HWY_NOINLINE void Prefill(
// a query requires that preceding batches already wrote to the KV cache,
// hence we sequentially loop over token batches. We can reduce the number of
// iterations by increasing the batch size, but this also increases arithmetic
// intensity, and so we are eventually compute-limited. We could devote some
// threads to parallelizing over queries, but for simplicity we assign them
// all to MatMul.
// intensity, and so we are eventually compute-limited. TransformerLayer uses
// all available threads, so we do not also parallelize over queries, but note
// that PrefillQBatch uses queries as the batch dimension.
const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
// For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < num_queries; ++qi) {
non_eos.Set(qi);
// Single query at a time, so pass slices of the spans because
// GemmaAttention will only access the first KV cache and position.
QueriesPos single_query_pos(&queries_pos[qi], 1);
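
GenerateT (later in this diff) chooses between the two prefill paths by comparing the batch dimensions. The condition, condensed into a standalone helper for clarity (the helper name is illustrative, not in the commit):

#include <cstddef>

// Query batching pays off only when there are more queries than prompt
// tokens, and prefix-LM mode is not in use: PrefillQBatch asserts that
// queries_prefix_end[qi] == 0 for every query.
bool UseQueryBatchedPrefill(size_t num_queries, size_t max_prompt_size,
                            bool all_prefix_end_are_zero) {
  return num_queries > max_prompt_size && all_prefix_end_are_zero;
}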
@@ -292,30 +297,34 @@
}
}
// Generates one token for each query. `queries_token` is the previous token
// from each query, and `queries_pos` are their position in the sequence.
// Embeds token and calls each TransformerLayer. `queries_token` is the previous
// token from each query, and `queries_pos` are their position in the sequence.
// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
// token-batched `PrefillTBatch`.
static HWY_NOINLINE void Transformer(
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
const ModelConfig& config, const ModelWeightsPtrs& weights,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
const LayersOutputFunc& layers_output,
const ActivationsObserverFunc& activations_observer) {
const ModelConfig& config, const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env) {
const size_t num_queries = queries_token.size();
HWY_DASSERT(queries_pos.size() == num_queries);
HWY_DASSERT(queries_prefix_end.size() == num_queries);
HWY_DASSERT(num_queries == queries_pos.size());
HWY_DASSERT(num_queries == queries_prefix_end.size());
if (layers_output) {
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
const float token_f = queries_token[query_idx];
layers_output(query_idx, queries_pos[query_idx], "tokens", -1, &token_f,
if (HWY_UNLIKELY(runtime_config.layers_output)) {
for (size_t qi = 0; qi < num_queries; ++qi) {
const float token_f = queries_token[qi];
runtime_config.layers_output(qi, queries_pos[qi], "tokens", -1, &token_f,
1);
}
}
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
/*pos_in_prompt=*/0, config, weights, activations.x);
size_t image_token_position = 0;
for (size_t qi = 0; qi < num_queries; ++qi) {
image_token_position =
EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
/*pos_in_prompt=*/0, config, weights, activations.x,
runtime_config.image_tokens, image_token_position);
}
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
@@ -323,21 +332,71 @@ static HWY_NOINLINE void Transformer(
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
activations, kv_caches, env);
if (activations_observer) {
activations_observer(queries_pos, layer_idx, activations);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(queries_pos, layer_idx, activations);
}
}
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
if (activations_observer) {
activations_observer(queries_pos, -1, activations);
}
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
queries_pos[query_idx] += 1;
}
}
// Populates KV cache for the batch queries, one token at a time. Only called
// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
static HWY_NOINLINE void PrefillQBatch(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const size_t max_prompt_size, const hwy::Divisor& div_seq_len,
const ModelConfig& config, const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(num_queries == queries_pos.size());
HWY_DASSERT(num_queries == queries_prefix_end.size());
HWY_DASSERT(num_queries == activations.x.Rows());
HWY_DASSERT(num_queries == kv_caches.size());
hwy::BitSet4096<> prefill_active;
for (size_t qi = 0; qi < num_queries; ++qi) {
prefill_active.Set(qi);
HWY_DASSERT(queries_prefix_end[qi] == 0);
(void)queries_prefix_end;
}
non_eos = prefill_active;
// In autoregressive mode, we don't prefill the last token, hence - 1.
for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
++pos_in_prompt) {
// Streams that have already finished prefill no longer interleave/stream.
for (size_t qi = 0; qi < num_queries; ++qi) {
if (pos_in_prompt >= queries_prompt[qi].size() - 1) {
prefill_active.Clear(qi);
activations.gen_tokens[qi] = config.eos_id;
}
}
// Batch := interleaved tokens, one from each non-EOS query.
prefill_active.Foreach([&](size_t qi) {
activations.gen_tokens[qi] = queries_prompt[qi][pos_in_prompt];
});
// One token from each query in the batch. Increments queries_pos.
// Do not call DecodeStepT because it computes logits for token
// probabilities, which are not required for the prompt tokens.
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
queries_pos, queries_prefix_end, div_seq_len, config,
runtime_config, weights, activations, kv_caches, env);
prefill_active.Foreach([&](size_t qi) {
const int token = queries_prompt[qi][pos_in_prompt];
// Ignore any user request to stop during prefill.
(void)runtime_config.StreamToken(query_idx_start + qi, queries_pos[qi],
token, 0.0f);
queries_pos[qi] += 1;
});
} // pos_in_prompt
}
// TODO: inline.
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, const size_t prompt_size) {
if (!weights_config.use_local_attention) {
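
For intuition about PrefillQBatch's schedule: with prompts of lengths 4, 2 and 3, each step feeds one token from every query that still has prompt tokens left, skipping the last token of each prompt since it is the first generation input. A standalone simulation of which tokens each step batches (in the real code the Transformer batch stays num_queries wide; finished queries simply carry eos_id):

#include <cstdio>
#include <vector>

int main() {
  // Three example prompts; prefill stops at size() - 1 for each query.
  std::vector<std::vector<int>> prompts = {
      {11, 12, 13, 14}, {21, 22}, {31, 32, 33}};
  const size_t max_prompt_size = 4;  // longest prompt
  for (size_t pos = 0; pos + 1 < max_prompt_size; ++pos) {
    std::printf("step %zu:", pos);
    for (size_t qi = 0; qi < prompts.size(); ++qi) {
      if (pos + 1 < prompts[qi].size()) {
        std::printf(" q%zu=%d", qi, prompts[qi][pos]);
      }
    }
    std::printf("\n");
  }
  // Prints:
  // step 0: q0=11 q1=21 q2=31
  // step 1: q0=12 q2=32
  // step 2: q0=13
  return 0;
}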
@@ -350,56 +409,49 @@ void RangeChecks(const ModelConfig& weights_config,
HWY_ASSERT(prompt_size > 0);
}
// Holds "is at end of stream" state for each query.
class TokenStreamer {
public:
TokenStreamer(const RuntimeConfig& runtime_config,
const ModelConfig& model_config)
: runtime_config_(runtime_config), model_config_(model_config) {}
// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
// and updates `non_eos` if the query is at the end of its sequence.
static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
const float prob, const ModelConfig& config,
const RuntimeConfig& runtime_config,
Activations& activations,
hwy::BitSet4096<>& non_eos) {
HWY_DASSERT(non_eos.Get(qi));
// Returns whether the query was already at, or has just reached, the end of
// the stream: either via token == eos_id, or StreamToken returning false.
bool operator()(size_t query_idx, size_t pos, int token, float prob) {
if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
model_config_.IsEOS(token)) {
is_eos_.Set(query_idx);
return true;
// User decided to stop: set next token to primary EOS.
if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
token = config.eos_id;
}
return false;
}
// Primary or secondary EOS: mark query as EOS.
if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
private:
const RuntimeConfig& runtime_config_;
const ModelConfig& model_config_;
hwy::BitSet4096<> is_eos_;
};
activations.gen_tokens[qi] = token;
}
// Runs one decode step for all the queries in the batch. Returns true if all
// queries are at <end_of_sentence>.
static bool DecodeStepT(
const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const size_t query_idx_start,
const QueriesPromptTokens& queries_prompt,
// For a batch of queries, runs Transformer, computes logits, samples and
// streams the token.
static void DecodeStepT(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_mutable_pos,
const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
const size_t vocab_size, const SampleFunc& sample_token,
Activations& activations, const KVCaches& kv_caches,
TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
TimingInfo& timing_info, MatMulEnv& env) {
const ModelConfig& config, const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, const SampleFunc& sample_token,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
// Decode generates one token per query and increments
// queries_mutable_pos.
Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
queries_prefix_end, div_seq_len, config, weights, activations,
kv_caches, env, runtime_config.layers_output,
runtime_config.activations_observer);
// queries_pos are incremented by Transformer.
HWY_DASSERT(num_queries == activations.x.Rows());
bool all_queries_eos = true;
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
queries_mutable_pos, queries_prefix_end, div_seq_len, config,
runtime_config, weights, activations, kv_caches, env);
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(queries_mutable_pos, -1, activations);
}
{
PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations.
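
StreamAndUpdateEOS inverts TokenStreamer's bookkeeping: instead of setting an is_eos_ bit, each query's non_eos bit stays set while it is still generating. A minimal stand-in for the three outcomes, using std::bitset and a plain callback instead of hwy::BitSet4096 and RuntimeConfig::StreamToken (the sketch names are illustrative; the EOS ids match tokenizer.h):

#include <bitset>
#include <cstddef>
#include <functional>
#include <vector>

constexpr int kEosId = 1;             // primary EOS
constexpr int kSecondaryEosId = 106;  // Gemma 3 secondary EOS

bool IsEOSSketch(int token) {
  return token == kEosId || token == kSecondaryEosId;
}

void StreamAndUpdateEOSSketch(size_t qi, int token,
                              const std::function<bool(int)>& stream_token,
                              std::vector<int>& gen_tokens,
                              std::bitset<4096>& non_eos) {
  // 1) The user callback returned false: treat it as if EOS were sampled.
  if (!stream_token(token)) token = kEosId;
  // 2) Primary or secondary EOS: this query stops decoding.
  if (IsEOSSketch(token)) non_eos.reset(qi);
  // 3) The (possibly replaced) token feeds the next Transformer() call.
  gen_tokens[qi] = token;
}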
@@ -407,19 +459,17 @@ static bool DecodeStepT(
/*add=*/nullptr, env, activations.logits);
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Row(query_idx);
MaybeLogitsSoftCap(config.final_cap, logits, vocab_size);
const TokenAndProb tp = sample_token(logits, vocab_size);
non_eos.Foreach([&](size_t qi) {
float* HWY_RESTRICT logits = activations.logits.Row(qi);
MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size);
const TokenAndProb tp = sample_token(logits, config.vocab_size);
timing_info.NotifyGenerated();
const bool is_eos =
token_streamer(query_idx_start + query_idx,
queries_mutable_pos[query_idx], tp.token, tp.prob);
all_queries_eos &= is_eos;
gen_tokens[query_idx] = is_eos ? config.eos_id : tp.token;
}
return all_queries_eos;
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi], tp.token,
tp.prob, config, runtime_config, activations, non_eos);
if (non_eos.Get(qi)) queries_mutable_pos[qi] += 1;
});
}
static HWY_INLINE SampleFunc
@@ -445,33 +495,26 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
};
}
// Returns the min and max number of tokens for all queries.
static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
size_t max_prompt_size = 0;
for (size_t i = 0; i < queries_prompt.size(); ++i) {
max_prompt_size = HWY_MAX(max_prompt_size, queries_prompt[i].size());
}
return max_prompt_size;
}
// Generates one continuation for each query in `queries_prompt`, which is one
// qbatch whose size is at most the `batch_size` passed to
// `activations.Allocate`.
// qbatch whose size is at most the `batch_size` passed to `activations` ctor.
//
// `queries_pos` stores the KV cache position for each query. In the first turn
// of a chat, pos = 0; we increment each query's position after each token.
//
// `query_idx_start` is the query_idx of the first query in the batch, so that
// `StreamFunc` gets the global query index, not relative to the batch.
//
// `kv_caches` is for the batch, size must match `queries_prompt`.
static void GenerateT(
const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config, const size_t query_idx_start,
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos_in,
const QueriesPos& queries_prefix_end, Activations& activations,
const KVCaches& kv_caches, TimingInfo& timing_info, MatMulEnv& env) {
HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos_in, const QueriesPos& queries_prefix_end,
const ModelConfig& config, const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env, TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries <= 4096); // non_eos uses `BitSet4096`.
HWY_ASSERT(num_queries == queries_pos_in.size());
HWY_ASSERT(num_queries == queries_prefix_end.size());
HWY_ASSERT(num_queries <= activations.x.Rows());
HWY_ASSERT(num_queries == kv_caches.size());
// Griffin assumes that the recurrent block cache is zero-initialized.
for (size_t i = 0; i < kv_caches.size(); ++i) {
@@ -486,75 +529,79 @@ static void GenerateT(
const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
queries_pos_copy.size());
size_t max_prompt_size = 0;
bool all_prefix_end_are_zero = true;
size_t prefill_tokens = 0;
for (size_t qi = 0; qi < num_queries; ++qi) {
const PromptTokens& prompt = queries_prompt[qi];
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
// Prefill stops before size - 1 because the last prompt token is the
// first input token for generation.
prefill_tokens += prompt.size() - 1;
// Sanity check: prompts should not be empty, nor start with EOS.
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
const PromptTokens& prompt = queries_prompt[query_idx];
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
}
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
HWY_ASSERT(num_queries <= activations.x.Rows());
HWY_ASSERT(queries_pos_in.size() == num_queries);
HWY_ASSERT(kv_caches.size() == num_queries);
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
size_t max_prompt_size = MaxQueryLength(queries_prompt);
size_t max_generated_tokens = runtime_config.max_generated_tokens;
RangeChecks(config, max_generated_tokens, max_prompt_size);
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
// qi loops anyway.
hwy::BitSet4096<> non_eos;
timing_info.prefill_start = hwy::platform::Now();
// Batch over the larger of prompt length, or queries.
if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
activations.SetBatchSize(num_queries); // required before PrefillQBatch
PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, max_prompt_size, div_seq_len, config,
runtime_config, weights, activations, kv_caches, env,
non_eos);
} else {
PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, div_seq_len, config, runtime_config,
weights, activations, kv_caches, env, non_eos);
activations.SetBatchSize(num_queries); // Restore after PrefillTBatch.
}
HWY_DASSERT(num_queries == non_eos.Count());
timing_info.NotifyPrefill(prefill_tokens);
// queries_pos have been incremented by Prefill.
// Stream the last prompt token from each query, fill activations.gen_tokens.
for (size_t qi = 0; qi < num_queries; ++qi) {
const size_t last_token_pos_in_prompt =
queries_mutable_pos[qi] - queries_pos_in[qi];
StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi],
queries_prompt[qi][last_token_pos_in_prompt], 0.0f,
config, runtime_config, activations, non_eos);
// No incrementing queries_mutable_pos[qi].
}
size_t max_gen_steps = runtime_config.max_generated_tokens;
RangeChecks(config, max_gen_steps, max_prompt_size);
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
// Prefill stops before min_prompt_size - 1 because the last prompt
// token is the first input token for generation.
timing_info.prefill_start = hwy::platform::Now();
// Note that Prefill calls activations.SetBatchSize, so we reset it below.
Prefill(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, div_seq_len, config, runtime_config, weights,
activations, kv_caches, env);
// Compute the number of tokens that were prefilled and notify timing_info.
size_t prefilled_tokens = 0;
for (size_t qi = 0; qi < num_queries; ++qi) {
prefilled_tokens += queries_prompt[qi].size() - 1;
}
timing_info.NotifyPrefill(prefilled_tokens);
// queries_pos are incremented by Prefill.
activations.SetBatchSize(num_queries);
// Storage for the last generated token from each query, passed to the next
// Transformer() call.
std::vector<int> gen_tokens(num_queries);
// Stream the last prompt token from each query and fill gen_tokens.
TokenStreamer token_streamer(runtime_config, config);
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
size_t last_token_pos_in_prompt =
queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt];
(void)token_streamer(query_idx_start + query_idx,
queries_mutable_pos[query_idx], gen_tokens[query_idx],
0.0f);
}
{
const size_t vocab_size = config.vocab_size;
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
bool all_queries_eos =
DecodeStepT(config, weights, runtime_config, query_idx_start,
queries_prompt, queries_mutable_pos, queries_prefix_end,
div_seq_len, vocab_size, sample_token, activations,
kv_caches, token_streamer, gen_tokens, timing_info, env);
if (all_queries_eos) break;
} // foreach token to generate
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, div_seq_len, config, runtime_config,
weights, sample_token, activations, kv_caches, env, non_eos,
timing_info);
}
timing_info.NotifyGenerateDone();
}
}
void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
const ModelConfig& config,
const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) {
const ModelWeightsPtrs& weights, KVCache& kv_cache,
MatMulEnv& env, TimingInfo& timing_info) {
constexpr size_t kNumQueries = 1;
const size_t qbatch_start = 0;
@@ -568,25 +615,27 @@ void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
const KVCaches kv_caches{&kv_cache, kNumQueries};
GenerateT(config, weights, runtime_config, qbatch_start, queries_prompt,
queries_pos, queries_prefix_end, activations, kv_caches,
timing_info, env);
GenerateT(qbatch_start, queries_prompt, queries_pos, queries_prefix_end,
config, runtime_config, weights, activations, kv_caches, env,
timing_info);
}
void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
// Splits the input into batches of at most `runtime_config.decode_qbatch_size`
// queries, and calls `GenerateT` on each batch.
void GenerateBatchT(const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const KVCaches& kv_caches, MatMulEnv& env,
TimingInfo& timing_info) {
const ModelConfig& config,
const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, const KVCaches& kv_caches,
MatMulEnv& env, TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() >= num_queries);
const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
const size_t max_batch_size =
HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
Activations activations(config, max_batch_size, env.row_ptrs);
for (size_t qbatch_start = 0; qbatch_start < num_queries;
@@ -600,17 +649,16 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
qbatch_size);
const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
GenerateT(config, weights, runtime_config, qbatch_start, qbatch_prompts,
qbatch_pos, qbatch_prefix_end, activations, qbatch_kv,
timing_info, env);
GenerateT(qbatch_start, qbatch_prompts, qbatch_pos, qbatch_prefix_end,
config, runtime_config, weights, activations, qbatch_kv, env,
timing_info);
}
}
void GenerateImageTokensT(const ModelConfig& config,
const ModelWeightsPtrs& weights,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
MatMulEnv& env) {
const ModelWeightsPtrs& weights, const Image& image,
ImageTokens& image_tokens, MatMulEnv& env) {
if (config.vit_config.layer_configs.empty()) {
HWY_ABORT("Model does not support generating image tokens.");
}
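
GenerateBatchT's splitting (above) is plain fixed-stride chunking: for example, 10 queries with decode_qbatch_size = 4 yield GenerateT calls on query ranges [0, 4), [4, 8) and [8, 10). A standalone sketch with example values:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t num_queries = 10;     // example values, not from the commit
  const size_t max_qbatch_size = 4;  // runtime_config.decode_qbatch_size
  for (size_t qbatch_start = 0; qbatch_start < num_queries;
       qbatch_start += max_qbatch_size) {
    const size_t qbatch_size =
        std::min(max_qbatch_size, num_queries - qbatch_start);
    std::printf("GenerateT on queries [%zu, %zu)\n", qbatch_start,
                qbatch_start + qbatch_size);
  }
  return 0;
}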
@@ -667,9 +715,9 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
KVCache& kv_cache, TimingInfo& timing_info) const {
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_,
runtime_config, prompt, pos, prefix_end,
kv_cache, env_, timing_info);
HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
model_.Config(), runtime_config,
weights_, kv_cache, env_, timing_info);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
@@ -681,19 +729,19 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
const KVCaches& kv_caches,
TimingInfo& timing_info) const {
// If we did not get passed prefix ends (size 0), assume 0 and pass that on.
QueriesPos mutable_queries_prefix_end = queries_prefix_end;
QueriesPos queries_prefix_end_or_zeros = queries_prefix_end;
std::vector<size_t> prefix_end_vec;
if (queries_prefix_end.size() == 0) { // hwy::Span lacks empty()
prefix_end_vec.resize(queries_prompt.size(), 0);
mutable_queries_prefix_end =
queries_prefix_end_or_zeros =
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
}
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
model_.Config(), weights_, runtime_config, queries_prompt, queries_pos,
mutable_queries_prefix_end, kv_caches, env_, timing_info);
queries_prompt, queries_pos, queries_prefix_end_or_zeros, model_.Config(),
runtime_config, weights_, kv_caches, env_, timing_info);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
@@ -704,7 +752,7 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
model_.Config(), weights_, runtime_config, image, image_tokens, env_);
model_.Config(), runtime_config, weights_, image, image_tokens, env_);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

View File

@@ -26,10 +26,7 @@
namespace gcpp {
// The tokenizer's end of sentence and beginning of sentence token ids.
constexpr int EOS_ID = 1;
constexpr int SECONDARY_EOS_ID = 106; // for Gemma 3
constexpr int BOS_ID = 2;
constexpr int BOS_ID = 2; // beginning of sequence
// To avoid the complexity of storing the tokenizer into testdata/ or
// downloading from gs://, while still always writing a blob for the tokenizer,