diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc
index 9d6f827..02a6a7c 100644
--- a/evals/gemma_batch_bench.cc
+++ b/evals/gemma_batch_bench.cc
@@ -33,7 +33,7 @@ namespace {
 // non-local static variables with dtors.
 GemmaEnv* s_env = nullptr;
 
-class GemmaTest : public ::testing::Test {
+class GemmaBatchBench : public ::testing::Test {
  protected:
  std::vector<std::string> BatchGemmaReply(
      const std::vector<std::string>& inputs) {
@@ -48,7 +48,7 @@ class GemmaTest : public ::testing::Test {
   }
 };
 
-TEST_F(GemmaTest, RandomQuestionsBatched) {
+TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
   const std::vector<std::string> questions = {
       {"Write me a poem about Australia?"},
       {"What's the history of Denmark?"},
@@ -103,6 +103,7 @@ TEST_F(GemmaTest, RandomQuestionsBatched) {
 }  // namespace gcpp
 
 int main(int argc, char** argv) {
+  fprintf(stderr, "GemmaEnv setup..\n");
   gcpp::GemmaEnv env(argc, argv);
   gcpp::s_env = &env;
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 22958d9..ff30338 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -102,7 +102,7 @@ TEST_F(GemmaTest, Multiturn) {
   size_t abs_pos = 0;
   std::string response;
   auto stream_token = [&](int token, float) {
-    if (token == EOS_ID) return true;
+    if (config.IsEOS(token)) return true;
     ++abs_pos;
     std::string token_text;
     EXPECT_TRUE(
diff --git a/gemma/activations.h b/gemma/activations.h
index fa04c0b..3d07538 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -27,6 +27,7 @@
 #include "util/allocator.h"  // Allocator
 #include "util/basics.h"     // BF16
 #include "util/mat.h"        // MatStorageT
+#include "hwy/profiler.h"
 
 namespace gcpp {
 
@@ -48,6 +49,7 @@ struct Activations {
         seq_len(config.seq_len),
         cache_pos_size(config.CachePosSize()),
         is_griffin(config.model == Model::GRIFFIN_2B),
+        query_scale(ChooseQueryScale(config)),
         x("x", Extents2D(batch_size, config.model_dim), pad_),
 
         // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
@@ -96,7 +98,7 @@ struct Activations {
                           layer_config.qkv_dim,
                           layer_config.post_qk == PostQKType::HalfRope,
                           1000000.0)),
-        query_scale(ChooseQueryScale(config)) {
+        gen_tokens(batch_size) {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.
@@ -114,6 +116,7 @@ struct Activations {
   }
 
   void SetBatchSize(size_t batch_size) {
+    PROFILER_ZONE("SetBatchSize");
     x.OverrideRows(batch_size);
     q.OverrideRows(batch_size);
     logits.OverrideRows(batch_size);
@@ -134,13 +137,16 @@ struct Activations {
       griffin_gate_x.OverrideRows(batch_size);
       griffin_multiplier.OverrideRows(batch_size);
     }
+
+    gen_tokens.resize(batch_size);
   }
 
   const ModelConfig& weights_config;
   const LayerConfig& layer_config;
   size_t seq_len;
   size_t cache_pos_size = 0;  // TODO: after moving KVCache to MatStorageT.
-  bool is_griffin = false;
+  bool is_griffin;
+  float query_scale;
 
   const Extents2D none_ = Extents2D();
   const MatPadding pad_ = MatPadding::kOdd;
@@ -171,7 +177,9 @@ struct Activations {
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
 
-  float query_scale;
+  // Storage for the last generated token from each query, passed to the next
+  // Transformer() call.
+  std::vector<int> gen_tokens;  // one per query in the batch
 };
 
 }  // namespace gcpp
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 2bebb11..8f49717 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -181,21 +181,24 @@ EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
   return image_token_position;
 }
 
-// Prefill() and Transformer() increment positions in-place.
+// Incremented in-place by Prefill* and DecodeStepT.
 using QueriesMutablePos = hwy::Span<size_t>;
 
-// Populates KV cache for batches of tokens from one query at a time.
-static HWY_NOINLINE void Prefill(
+// Populates KV cache for batches of tokens from one query at a time. This is
+// called if prompts are longer than the query batch size, and also in
+// prefix-LM mode (end > 0), which must see all tokens in one batch.
+static HWY_NOINLINE void PrefillTBatch(
     const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
    const hwy::Divisor& div_seq_len, const ModelConfig& config,
     const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
-    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
-  PROFILER_ZONE("Gen.Prefill");
+    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
+    hwy::BitSet4096<>& non_eos) {
+  PROFILER_ZONE("Gen.PrefillT");
   const size_t num_queries = queries_prompt.size();
-  HWY_DASSERT(queries_pos.size() == num_queries);
-  HWY_DASSERT(queries_prefix_end.size() == num_queries);
-  HWY_DASSERT(kv_caches.size() == num_queries);
+  HWY_DASSERT(num_queries == queries_pos.size());
+  HWY_DASSERT(num_queries == queries_prefix_end.size());
+  HWY_DASSERT(num_queries == kv_caches.size());
 
   // Batches are important for amortizing loading weights over multiple tokens.
   // This is possible in prefill because we know all tokens beforehand, whereas
@@ -203,13 +206,15 @@ static HWY_NOINLINE void Prefill(
   // a query requires that preceding batches already wrote to the KV cache,
   // hence we sequentially loop over token batches. We can reduce the number of
   // iterations by increasing the batch size, but this also increases arithmetic
-  // intensity, and so we are eventually compute-limited. We could devote some
-  // threads to parallelizing over queries, but for simplicity we assign them
-  // all to MatMul.
+  // intensity, and so we are eventually compute-limited. TransformerLayer uses
+  // all available threads, so we do not also parallelize over queries, but note
+  // that PrefillQBatch uses queries as the batch dimension.
   const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
 
   // For each query. `qi` is within the batch, not the global query index.
   for (size_t qi = 0; qi < num_queries; ++qi) {
+    non_eos.Set(qi);
+
     // Single query at a time, so pass slices of the spans because
     // GemmaAttention will only access the first KV cache and position.
     QueriesPos single_query_pos(&queries_pos[qi], 1);
@@ -292,30 +297,34 @@ static HWY_NOINLINE void Prefill(
   }
 }
 
-// Generates one token for each query. `queries_token` is the previous token
-// from each query, and `queries_pos` are their position in the sequence.
+// Embeds token and calls each TransformerLayer. `queries_token` is the previous
+// token from each query, and `queries_pos` are their position in the sequence.
+// Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
+// token-batched `PrefillTBatch`.
 static HWY_NOINLINE void Transformer(
     const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
     const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
-    const ModelConfig& config, const ModelWeightsPtrs& weights,
-    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
-    const LayersOutputFunc& layers_output,
-    const ActivationsObserverFunc& activations_observer) {
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const ModelWeightsPtrs& weights, Activations& activations,
+    const KVCaches& kv_caches, MatMulEnv& env) {
   const size_t num_queries = queries_token.size();
-  HWY_DASSERT(queries_pos.size() == num_queries);
-  HWY_DASSERT(queries_prefix_end.size() == num_queries);
+  HWY_DASSERT(num_queries == queries_pos.size());
+  HWY_DASSERT(num_queries == queries_prefix_end.size());
 
-  if (layers_output) {
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      const float token_f = queries_token[query_idx];
-      layers_output(query_idx, queries_pos[query_idx], "tokens", -1, &token_f,
-                    1);
+  if (HWY_UNLIKELY(runtime_config.layers_output)) {
+    for (size_t qi = 0; qi < num_queries; ++qi) {
+      const float token_f = queries_token[qi];
+      runtime_config.layers_output(qi, queries_pos[qi], "tokens", -1, &token_f,
+                                   1);
     }
   }
 
-  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
-                 /*pos_in_prompt=*/0, config, weights, activations.x);
+  size_t image_token_position = 0;
+  for (size_t qi = 0; qi < num_queries; ++qi) {
+    image_token_position =
+        EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
+                     /*pos_in_prompt=*/0, config, weights, activations.x,
+                     runtime_config.image_tokens, image_token_position);
   }
 
   for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
@@ -323,21 +332,71 @@ static HWY_NOINLINE void Transformer(
                      div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
                      activations, kv_caches, env);
 
-    if (activations_observer) {
-      activations_observer(queries_pos, layer_idx, activations);
+    if (HWY_UNLIKELY(runtime_config.activations_observer)) {
+      runtime_config.activations_observer(queries_pos, layer_idx, activations);
     }
   }
-
-  RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
-
-  if (activations_observer) {
-    activations_observer(queries_pos, -1, activations);
-  }
-  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    queries_pos[query_idx] += 1;
-  }
 }
 
+// Populates KV cache for the batch's queries, one token at a time. Only called
+// for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
+static HWY_NOINLINE void PrefillQBatch(
+    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
+    const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
+    const size_t max_prompt_size, const hwy::Divisor& div_seq_len,
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const ModelWeightsPtrs& weights, Activations& activations,
+    const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
+  PROFILER_ZONE("Gen.Prefill");
+  const size_t num_queries = queries_prompt.size();
+  HWY_DASSERT(num_queries == queries_pos.size());
+  HWY_DASSERT(num_queries == queries_prefix_end.size());
+  HWY_DASSERT(num_queries == activations.x.Rows());
+  HWY_DASSERT(num_queries == kv_caches.size());
+
+  hwy::BitSet4096<> prefill_active;
+  for (size_t qi = 0; qi < num_queries; ++qi) {
+    prefill_active.Set(qi);
+
+    HWY_DASSERT(queries_prefix_end[qi] == 0);
+    (void)queries_prefix_end;
+  }
+  non_eos = prefill_active;
+
+  // In autoregressive mode, we don't prefill the last token, hence - 1.
+  for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
+       ++pos_in_prompt) {
+    // Streams that have already finished prefill no longer interleave/stream.
+    for (size_t qi = 0; qi < num_queries; ++qi) {
+      if (pos_in_prompt >= queries_prompt[qi].size() - 1) {
+        prefill_active.Clear(qi);
+        activations.gen_tokens[qi] = config.eos_id;
+      }
+    }
+
+    // Batch := interleaved tokens, one from each non-EOS query.
+    prefill_active.Foreach([&](size_t qi) {
+      activations.gen_tokens[qi] = queries_prompt[qi][pos_in_prompt];
+    });
+
+    // One token from each query in the batch. Increments queries_pos.
+    // Do not call DecodeStepT because it computes logits for token
+    // probabilities, which are not required for the prompt tokens.
+    Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
+                queries_pos, queries_prefix_end, div_seq_len, config,
+                runtime_config, weights, activations, kv_caches, env);
+
+    prefill_active.Foreach([&](size_t qi) {
+      const int token = queries_prompt[qi][pos_in_prompt];
+      // Ignore any user request to stop during prefill.
+      (void)runtime_config.StreamToken(query_idx_start + qi, queries_pos[qi],
+                                       token, 0.0f);
+      queries_pos[qi] += 1;
+    });
+  }  // pos_in_prompt
+}
+
+// TODO: inline.
 void RangeChecks(const ModelConfig& weights_config,
                  size_t& max_generated_tokens, const size_t prompt_size) {
   if (!weights_config.use_local_attention) {
@@ -350,56 +409,49 @@ void RangeChecks(const ModelConfig& weights_config,
   HWY_ASSERT(prompt_size > 0);
 }
 
-// Holds "is at end of stream" state for each query.
-class TokenStreamer {
- public:
-  TokenStreamer(const RuntimeConfig& runtime_config,
-                const ModelConfig& model_config)
-      : runtime_config_(runtime_config), model_config_(model_config) {}
+// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
+// and updates `non_eos` if the query is at the end of its sequence.
+static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
+                               const float prob, const ModelConfig& config,
+                               const RuntimeConfig& runtime_config,
+                               Activations& activations,
+                               hwy::BitSet4096<>& non_eos) {
+  HWY_DASSERT(non_eos.Get(qi));
 
-  // Returns whether the query was already at, or has just reached, the end of
-  // the stream: either via token == eos_id, or StreamToken returning false.
-  bool operator()(size_t query_idx, size_t pos, int token, float prob) {
-    if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
-
-    if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
-        model_config_.IsEOS(token)) {
-      is_eos_.Set(query_idx);
-      return true;
-    }
-
-    return false;
+  // User decided to stop: set next token to primary EOS.
+  if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
+    token = config.eos_id;
   }
 
- private:
-  const RuntimeConfig& runtime_config_;
-  const ModelConfig& model_config_;
-  hwy::BitSet4096<> is_eos_;
-};
+  // Primary or secondary EOS: mark query as EOS.
+  if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
 
-// Runs one decode step for all the queries in the batch. Returns true if all
-// queries are at <EOS>.
-static bool DecodeStepT(
-    const ModelConfig& config, const ModelWeightsPtrs& weights,
-    const RuntimeConfig& runtime_config, const size_t query_idx_start,
-    const QueriesPromptTokens& queries_prompt,
+  activations.gen_tokens[qi] = token;
+}
+
+// For a batch of queries, runs Transformer, computes logits, samples and
+// streams the token.
+static void DecodeStepT(
+    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_mutable_pos,
     const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
-    const size_t vocab_size, const SampleFunc& sample_token,
-    Activations& activations, const KVCaches& kv_caches,
-    TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
-    TimingInfo& timing_info, MatMulEnv& env) {
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const ModelWeightsPtrs& weights, const SampleFunc& sample_token,
+    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
+    hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
   const size_t num_queries = queries_prompt.size();
-  // Decode generates one token per query and increments
-  // queries_mutable_pos.
-  Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
-              queries_prefix_end, div_seq_len, config, weights, activations,
-              kv_caches, env, runtime_config.layers_output,
-              runtime_config.activations_observer);
-  // queries_pos are incremented by Transformer.
-  HWY_DASSERT(num_queries == activations.x.Rows());
-  bool all_queries_eos = true;
+
+  Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
+              queries_mutable_pos, queries_prefix_end, div_seq_len, config,
+              runtime_config, weights, activations, kv_caches, env);
+
+  RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
+
+  if (HWY_UNLIKELY(runtime_config.activations_observer)) {
+    runtime_config.activations_observer(queries_mutable_pos, -1, activations);
+  }
+
   {
     PROFILER_ZONE("Gen.EmbeddingMatmul");
     // Compute logits from last layer activations.
@@ -407,19 +459,17 @@ static bool DecodeStepT(
                /*add=*/nullptr, env, activations.logits);
   }
   PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    float* HWY_RESTRICT logits = activations.logits.Row(query_idx);
-    MaybeLogitsSoftCap(config.final_cap, logits, vocab_size);
-    const TokenAndProb tp = sample_token(logits, vocab_size);
+  non_eos.Foreach([&](size_t qi) {
+    float* HWY_RESTRICT logits = activations.logits.Row(qi);
+    MaybeLogitsSoftCap(config.final_cap, logits, config.vocab_size);
+    const TokenAndProb tp = sample_token(logits, config.vocab_size);
     timing_info.NotifyGenerated();
-    const bool is_eos =
-        token_streamer(query_idx_start + query_idx,
-                       queries_mutable_pos[query_idx], tp.token, tp.prob);
-    all_queries_eos &= is_eos;
-    gen_tokens[query_idx] = is_eos ? config.eos_id : tp.token;
-  }
-  return all_queries_eos;
+    StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi], tp.token,
+                       tp.prob, config, runtime_config, activations, non_eos);
+
+    if (non_eos.Get(qi)) queries_mutable_pos[qi] += 1;
+  });
 }
 
 static HWY_INLINE SampleFunc
@@ -445,33 +495,26 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   };
 }
 
-// Returns the min and max number of tokens for all queries.
-static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
-  size_t max_prompt_size = 0;
-  for (size_t i = 0; i < queries_prompt.size(); ++i) {
-    max_prompt_size = HWY_MAX(max_prompt_size, queries_prompt[i].size());
-  }
-  return max_prompt_size;
-}
-
 // Generates one continuation for each query in `queries_prompt`, which is one
-// qbatch whose size is at most the `batch_size` passed to
-// `activations.Allocate`.
+// qbatch whose size is at most the `batch_size` passed to `activations` ctor.
 //
 // `queries_pos` stores the KV cache position for each query. In the first turn
 // of a chat, pos = 0; we increment each query's position after each token.
 //
 // `query_idx_start` is the query_idx of the first query in the batch, so that
 // `StreamFunc` gets the global query index, not relative to the batch.
-//
-// `kv_caches` is for the batch, size must match `queries_prompt`.
 static void GenerateT(
-    const ModelConfig& config, const ModelWeightsPtrs& weights,
-    const RuntimeConfig& runtime_config, const size_t query_idx_start,
-    const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos_in,
-    const QueriesPos& queries_prefix_end, Activations& activations,
-    const KVCaches& kv_caches, TimingInfo& timing_info, MatMulEnv& env) {
-  HWY_ASSERT(queries_pos_in.size() == kv_caches.size());
+    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
+    const QueriesPos& queries_pos_in, const QueriesPos& queries_prefix_end,
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const ModelWeightsPtrs& weights, Activations& activations,
+    const KVCaches& kv_caches, MatMulEnv& env, TimingInfo& timing_info) {
+  const size_t num_queries = queries_prompt.size();
+  HWY_ASSERT(num_queries <= 4096);  // non_eos uses `BitSet4096`.
+  HWY_ASSERT(num_queries == queries_pos_in.size());
+  HWY_ASSERT(num_queries == queries_prefix_end.size());
+  HWY_ASSERT(num_queries <= activations.x.Rows());
+  HWY_ASSERT(num_queries == kv_caches.size());
 
   // Griffin assumes that the recurrent block cache is zero-initialized.
   for (size_t i = 0; i < kv_caches.size(); ++i) {
@@ -486,75 +529,79 @@ static void GenerateT(
   const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
                                               queries_pos_copy.size());
 
-  // Sanity check: prompts should not be empty, nor start with EOS.
-  for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
-    const PromptTokens& prompt = queries_prompt[query_idx];
+  size_t max_prompt_size = 0;
+  bool all_prefix_end_are_zero = true;
+  size_t prefill_tokens = 0;
+  for (size_t qi = 0; qi < num_queries; ++qi) {
+    const PromptTokens& prompt = queries_prompt[qi];
+    max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
+
+    // Prefill stops before size - 1 because the last prompt token is the
+    // first input token for generation.
+    prefill_tokens += prompt.size() - 1;
+
+    // Sanity check: prompts should not be empty, nor start with EOS.
     HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
+
+    all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
   }
-  const size_t num_queries = queries_prompt.size();
-  HWY_ASSERT(num_queries <= 4096);  // TokenStreamer uses BitSet4096.
-  HWY_ASSERT(num_queries <= activations.x.Rows());
-  HWY_ASSERT(queries_pos_in.size() == num_queries);
-  HWY_ASSERT(kv_caches.size() == num_queries);
   const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
-  size_t max_prompt_size = MaxQueryLength(queries_prompt);
-  size_t max_generated_tokens = runtime_config.max_generated_tokens;
-  RangeChecks(config, max_generated_tokens, max_prompt_size);
+
+  // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
+  // qi loops anyway.
+  hwy::BitSet4096<> non_eos;
+
+  timing_info.prefill_start = hwy::platform::Now();
+  // Batch over the larger of prompt length, or queries.
+  if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
+    activations.SetBatchSize(num_queries);  // required before PrefillQBatch
+    PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
+                  queries_prefix_end, max_prompt_size, div_seq_len, config,
+                  runtime_config, weights, activations, kv_caches, env,
+                  non_eos);
+  } else {
+    PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
+                  queries_prefix_end, div_seq_len, config, runtime_config,
+                  weights, activations, kv_caches, env, non_eos);
+    activations.SetBatchSize(num_queries);  // Restore after PrefillTBatch.
+  }
+  HWY_DASSERT(num_queries == non_eos.Count());
+  timing_info.NotifyPrefill(prefill_tokens);
+  // queries_pos have been incremented by Prefill.
+
+  // Stream the last prompt token from each query, fill activations.gen_tokens.
+  for (size_t qi = 0; qi < num_queries; ++qi) {
+    const size_t last_token_pos_in_prompt =
+        queries_mutable_pos[qi] - queries_pos_in[qi];
+    StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi],
+                       queries_prompt[qi][last_token_pos_in_prompt], 0.0f,
+                       config, runtime_config, activations, non_eos);
+    // No incrementing queries_mutable_pos[qi].
+  }
+
+  size_t max_gen_steps = runtime_config.max_generated_tokens;
+  RangeChecks(config, max_gen_steps, max_prompt_size);
+
   const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
 
-  // Prefill stops before min_prompt_size - 1 because the last prompt
-  // token is the first input token for generation.
-  timing_info.prefill_start = hwy::platform::Now();
-  // Note that Prefill calls activations.SetBatchSize, so we reset it below.
-  Prefill(query_idx_start, queries_prompt, queries_mutable_pos,
-          queries_prefix_end, div_seq_len, config, runtime_config, weights,
-          activations, kv_caches, env);
-  // Compute the number of tokens that were prefilled and notify timing_info.
-  size_t prefilled_tokens = 0;
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    prefilled_tokens += queries_prompt[qi].size() - 1;
-  }
-  timing_info.NotifyPrefill(prefilled_tokens);
-  // queries_pos are incremented by Prefill.
-  activations.SetBatchSize(num_queries);
-
-  // Storage for the last generated token from each query, passed to the next
-  // Transformer() call.
-  std::vector<int> gen_tokens(num_queries);
-
-  // Stream the last prompt token from each query and fill gen_tokens.
-  TokenStreamer token_streamer(runtime_config, config);
-  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-    size_t last_token_pos_in_prompt =
-        queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
-    gen_tokens[query_idx] = queries_prompt[query_idx][last_token_pos_in_prompt];
-    (void)token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], gen_tokens[query_idx],
-                         0.0f);
-  }
-
   {
-    const size_t vocab_size = config.vocab_size;
     timing_info.generate_start = hwy::platform::Now();
-    for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-      bool all_queries_eos =
-          DecodeStepT(config, weights, runtime_config, query_idx_start,
-                      queries_prompt, queries_mutable_pos, queries_prefix_end,
-                      div_seq_len, vocab_size, sample_token, activations,
-                      kv_caches, token_streamer, gen_tokens, timing_info, env);
-
-      if (all_queries_eos) break;
-    }  // foreach token to generate
+    for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
+      DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
+                  queries_prefix_end, div_seq_len, config, runtime_config,
+                  weights, sample_token, activations, kv_caches, env, non_eos,
+                  timing_info);
+    }
     timing_info.NotifyGenerateDone();
   }
 }
 
-void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
+void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
+                     const ModelConfig& config,
                      const RuntimeConfig& runtime_config,
-                     const PromptTokens& prompt, size_t pos, size_t prefix_end,
-                     KVCache& kv_cache, MatMulEnv& env,
-                     TimingInfo& timing_info) {
+                     const ModelWeightsPtrs& weights, KVCache& kv_cache,
+                     MatMulEnv& env, TimingInfo& timing_info) {
   constexpr size_t kNumQueries = 1;
   const size_t qbatch_start = 0;
 
@@ -568,25 +615,27 @@ void GenerateSingleT(const ModelConfig& config, const ModelWeightsPtrs& weights,
   const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
   const KVCaches kv_caches{&kv_cache, kNumQueries};
 
-  GenerateT(config, weights, runtime_config, qbatch_start, queries_prompt,
-            queries_pos, queries_prefix_end, activations, kv_caches,
-            timing_info, env);
+  GenerateT(qbatch_start, queries_prompt, queries_pos, queries_prefix_end,
+            config, runtime_config, weights, activations, kv_caches, env,
+            timing_info);
 }
 
-void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
-                    const RuntimeConfig& runtime_config,
-                    const QueriesPromptTokens& queries_prompt,
+// Splits the input into batches of at most `runtime_config.decode_qbatch_size`
+// queries, and calls `GenerateT` on each batch.
+void GenerateBatchT(const QueriesPromptTokens& queries_prompt,
                     const QueriesPos& queries_pos,
                     const QueriesPos& queries_prefix_end,
-                    const KVCaches& kv_caches, MatMulEnv& env,
-                    TimingInfo& timing_info) {
+                    const ModelConfig& config,
+                    const RuntimeConfig& runtime_config,
+                    const ModelWeightsPtrs& weights, const KVCaches& kv_caches,
+                    MatMulEnv& env, TimingInfo& timing_info) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(queries_pos.size() == num_queries);
   HWY_ASSERT(kv_caches.size() >= num_queries);
+
   const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
   const size_t max_batch_size =
       HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
-
   Activations activations(config, max_batch_size, env.row_ptrs);
 
   for (size_t qbatch_start = 0; qbatch_start < num_queries;
@@ -600,17 +649,16 @@ void GenerateBatchT(const ModelConfig& config, const ModelWeightsPtrs& weights,
     const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
                                        qbatch_size);
     const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
-    GenerateT(config, weights, runtime_config, qbatch_start, qbatch_prompts,
-              qbatch_pos, qbatch_prefix_end, activations, qbatch_kv,
-              timing_info, env);
+    GenerateT(qbatch_start, qbatch_prompts, qbatch_pos, qbatch_prefix_end,
+              config, runtime_config, weights, activations, qbatch_kv, env,
+              timing_info);
   }
 }
 
 void GenerateImageTokensT(const ModelConfig& config,
-                          const ModelWeightsPtrs& weights,
                           const RuntimeConfig& runtime_config,
-                          const Image& image, ImageTokens& image_tokens,
-                          MatMulEnv& env) {
+                          const ModelWeightsPtrs& weights, const Image& image,
+                          ImageTokens& image_tokens, MatMulEnv& env) {
   if (config.vit_config.layer_configs.empty()) {
     HWY_ABORT("Model does not support generating image tokens.");
   }
@@ -667,9 +715,9 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
                      KVCache& kv_cache, TimingInfo& timing_info) const {
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
-  HWY_DYNAMIC_DISPATCH(GenerateSingleT)(model_.Config(), weights_,
-                                        runtime_config, prompt, pos, prefix_end,
-                                        kv_cache, env_, timing_info);
+  HWY_DYNAMIC_DISPATCH(GenerateSingleT)(prompt, pos, prefix_end,
+                                        model_.Config(), runtime_config,
+                                        weights_, kv_cache, env_, timing_info);
 
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
@@ -681,19 +729,19 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
                           const KVCaches& kv_caches,
                           TimingInfo& timing_info) const {
   // If we did not get passed prefix ends (size 0), assume 0 and pass that on.
-  QueriesPos mutable_queries_prefix_end = queries_prefix_end;
+  QueriesPos queries_prefix_end_or_zeros = queries_prefix_end;
   std::vector<size_t> prefix_end_vec;
   if (queries_prefix_end.size() == 0) {  // hwy::Span lacks empty()
     prefix_end_vec.resize(queries_prompt.size(), 0);
-    mutable_queries_prefix_end =
+    queries_prefix_end_or_zeros =
         QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
   }
 
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
   HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
-      model_.Config(), weights_, runtime_config, queries_prompt, queries_pos,
-      mutable_queries_prefix_end, kv_caches, env_, timing_info);
+      queries_prompt, queries_pos, queries_prefix_end_or_zeros, model_.Config(),
+      runtime_config, weights_, kv_caches, env_, timing_info);
 
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
@@ -704,7 +752,7 @@ void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
 
   HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
-      model_.Config(), weights_, runtime_config, image, image_tokens, env_);
+      model_.Config(), runtime_config, weights_, image, image_tokens, env_);
 
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index a8c081b..aca01f9 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -26,10 +26,7 @@
 
 namespace gcpp {
 
-// The tokenizer's end of sentence and beginning of sentence token ids.
-constexpr int EOS_ID = 1;
-constexpr int SECONDARY_EOS_ID = 106;  // for Gemma 3
-constexpr int BOS_ID = 2;
+constexpr int BOS_ID = 2;  // beginning of sequence
 
 // To avoid the complexity of storing the tokenizer into testdata/ or
 // downloading from gs://, while still always writing a blob for the tokenizer,
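Usage note (illustrative sketch, not part of the patch): call sites that previously compared tokens against the removed EOS_ID / SECONDARY_EOS_ID constants migrate to ModelConfig::IsEOS, as in the gemma_test.cc hunk above. Here `config` is assumed to be the ModelConfig available at the call site, and AppendTokenText is a hypothetical detokenization helper.

  auto stream_token = [&](int token, float /*prob*/) {
    // ModelConfig::IsEOS covers the primary EOS id and the Gemma 3 secondary
    // EOS id, so callers no longer hard-code token ids.
    if (config.IsEOS(token)) return true;  // skip decoding the EOS token
    AppendTokenText(token, response);      // hypothetical: detokenize + append
    return true;  // returning false asks generation to stop for this query
  };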