From e5c81f64a1e7680d6fa2ef1f504249cf7a3b3b18 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Mon, 16 Jun 2025 02:41:30 -0700
Subject: [PATCH] Major refactor: clarify query_idx (global) vs qi. Refs #607

Fix missing pos increment for last prefill and check that in gemma_test.
Thanks to @ufownl for pointing this out.
Change argument lists to QBatch with accessors.
Increase default seq_len to 8k.

PiperOrigin-RevId: 771937385
---
 evals/benchmark_helper.cc   |  20 +-
 evals/benchmark_helper.h    |   4 +-
 evals/gemma_test.cc         |   8 +-
 gemma/activations.h         |  19 +-
 gemma/attention.cc          |  86 +++-----
 gemma/attention.h           |  18 +-
 gemma/bindings/context.cc   |   4 +-
 gemma/gemma.cc              | 430 ++++++++++++++----------------------
 gemma/gemma.h               | 142 ++++++++++--
 gemma/gemma_args.h          |  13 +-
 gemma/griffin.cc            |  43 ++--
 gemma/griffin.h             |  14 +-
 gemma/run.cc                |   3 +-
 paligemma/paligemma_test.cc |   8 +-
 python/gemma_py.cc          |   3 +-
 15 files changed, 399 insertions(+), 416 deletions(-)

diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 8f5968b..6184d97 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -109,16 +109,19 @@ void GemmaEnv::QueryModel(
 }
 
 std::vector<QueryResult> GemmaEnv::BatchQueryModel(
-    const QueriesPromptTokens& queries_prompt) {
+    const QueriesPromptTokens& queries_prompt,
+    const hwy::Span<const size_t>& prefix_end) {
   const size_t num_queries = queries_prompt.size();
   HWY_ASSERT(num_queries != 0);
   std::vector<QueryResult> res(num_queries);
-  const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
-                                                 size_t query_index, size_t pos,
-                                                 int token, float) {
+  const BatchStreamFunc batch_stream_token = [&, this](const size_t query_index,
+                                                       const size_t pos,
+                                                       const int token, float) {
+    HWY_ASSERT(query_index < num_queries);
     std::string token_text;
     HWY_ASSERT(gemma_.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     res[query_index].response.append(token_text);
+    HWY_ASSERT(pos == res[query_index].tokens_generated);
     res[query_index].tokens_generated += 1;
     if (res[query_index].tokens_generated ==
         queries_prompt[query_index].size()) {
@@ -126,6 +129,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
     }
     return true;
   };
+  runtime_config_.batch_stream_token = batch_stream_token;
   if (runtime_config_.verbosity >= 2) {
     fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
            runtime_config_.max_generated_tokens, runtime_config_.temperature,
@@ -137,13 +141,11 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   while (kv_caches_.size() < num_queries) {
     kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
   }
+  const hwy::Span<KVCache> kv_caches(&kv_caches_[0], num_queries);
+  gcpp::AllQueries all_queries(queries_prompt, kv_caches, prefix_end);
   gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
-  runtime_config_.batch_stream_token = batch_stream_token;
-  std::vector<size_t> queries_pos(num_queries, 0);
-  gemma_.GenerateBatch(runtime_config_, queries_prompt,
-                       QueriesPos(queries_pos.data(), num_queries),
-                       KVCaches(&kv_caches_[0], num_queries), timing_info);
+  gemma_.GenerateBatch(runtime_config_, all_queries, timing_info);
   return res;
 }
 
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index 73d895b..870ad02 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -88,8 +88,10 @@ class GemmaEnv {
   // Runs inference on the given input and returns the top-1 result string and
   // the number of tokens that were generated.
   QueryResult QueryModel(const std::vector<int>& tokens);
+  // The default prefix_end means "causal attention".
   std::vector<QueryResult> BatchQueryModel(
-      const QueriesPromptTokens& queries_prompt);
+      const QueriesPromptTokens& queries_prompt,
+      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>());
   // Adds turn structure to input, tokenizes and calls the above overload.
   QueryResult QueryModel(std::string& input);
   std::vector<QueryResult> BatchQueryModel(
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index ff30338..a22ba32 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -101,9 +101,11 @@ TEST_F(GemmaTest, Multiturn) {
   const ModelConfig& config = model->GetModelConfig();
   size_t abs_pos = 0;
   std::string response;
-  auto stream_token = [&](int token, float) {
-    if (config.IsEOS(token)) return true;
+  auto stream_token = [&](size_t query_idx, size_t pos, int token, float) {
+    HWY_ASSERT(query_idx == 0);
+    HWY_ASSERT(pos == abs_pos);
     ++abs_pos;
+    if (config.IsEOS(token)) return true;
     std::string token_text;
     EXPECT_TRUE(
         model->Tokenizer().Decode(std::vector<int>{token}, &token_text));
@@ -115,7 +117,7 @@ TEST_F(GemmaTest, Multiturn) {
       .temperature = 0.0f,
       .gen = &s_env->MutableGen(),
       .verbosity = 2,
-      .stream_token = stream_token,
+      .batch_stream_token = stream_token,
   };
   TimingInfo timing_info{.verbosity = 0};
   // First "say" something slightly unusual.
diff --git a/gemma/activations.h b/gemma/activations.h
index e19926a..56c799b 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -42,13 +42,12 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 }
 
 struct Activations {
-  Activations(const ModelConfig& config, size_t batch_size,
+  Activations(const ModelConfig& config, size_t batch_size, size_t seq_len,
              std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : weights_config(config),
        layer_config(config.layer_configs[0]),
-        div_seq_len(static_cast<uint32_t>(config.max_seq_len)),
+        div_seq_len(static_cast<uint32_t>(seq_len)),
         is_griffin(config.model == Model::GRIFFIN_2B),
-        query_scale(ChooseQueryScale(config)),
 
         x("x", Extents2D(batch_size, config.model_dim), pad_),
         // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA
@@ -63,10 +62,7 @@ struct Activations {
         pre_att_rms_out("pre_att_rms_out",
                         Extents2D(batch_size, config.model_dim), pad_),
-        att("att",
-            Extents2D(batch_size,
-                      layer_config.heads * div_seq_len.GetDivisor()),
-            pad_),
+        att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad_),
         att_out(
             "att_out",
             Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
@@ -99,7 +95,7 @@ struct Activations {
                             layer_config.qkv_dim,
                             layer_config.post_qk == PostQKType::HalfRope,
                             1000000.0)),
-        gen_tokens(batch_size) {
+        query_scale(ChooseQueryScale(config)) {
     HWY_ASSERT(batch_size != 0);
 
     // For MatMul outputs, precompute their row pointers.
@@ -138,8 +134,6 @@ struct Activations {
       griffin_gate_x.OverrideRows(batch_size);
       griffin_multiplier.OverrideRows(batch_size);
     }
-
-    gen_tokens.resize(batch_size);
   }
 
   bool IsGlobalLayer(size_t layer_idx) const {
@@ -151,7 +145,6 @@ struct Activations {
   const LayerConfig& layer_config;
   hwy::Divisor div_seq_len;
   bool is_griffin;
-  float query_scale;
 
   const Extents2D none_ = Extents2D();
   const MatPadding pad_ = MatPadding::kOdd;
@@ -182,9 +175,7 @@ struct Activations {
   MatStorageT<float> inv_timescale;
   MatStorageT<float> inv_timescale_global;
 
-  // Storage for the last generated token from each query, passed to the next
-  // Transformer() call.
-  std::vector<int> gen_tokens;  // one per query in the batch
+  float query_scale;
 };
 
}  // namespace gcpp
diff --git a/gemma/attention.cc b/gemma/attention.cc
index c0cce57..87793b8 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -153,15 +153,12 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
   return pos - HWY_MIN(att_window_size - 1, pos);
 }
 
-void DotSoftmaxWeightedSum(const size_t num_tokens,
-                           const QueriesPos& queries_pos,
-                           const QueriesPos& queries_prefix_end,
-                           const size_t layer_idx,
+void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx,
                            const LayerWeightsPtrs& layer,
-                           Activations& activations, const KVCaches& kv_caches,
+                           Activations& activations, QBatch& qbatch,
                            NestedPools& pools) {
   PROFILER_ZONE("Gen.Attention.DotSoftmax");
-  const hwy::Divisor div_queries(queries_pos.size());
+  const hwy::Divisor div_qbatch(qbatch.Size());
 
   const LayerConfig& layer_config = layer.layer_config;
   const size_t qkv_dim = layer_config.qkv_dim;
@@ -176,7 +173,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
 
   // For each head/token/query, compute Q.K, softmax, and weighted V.
   // Statically partition token/query across packages.
-  const size_t num_tq = num_tokens * div_queries.GetDivisor();
+  const size_t num_tq = num_tokens * div_qbatch.GetDivisor();
   const IndexRangePartition tq_ranges =
       StaticPartition(IndexRange(0, num_tq), pools.NumPackages(), 1);
   ParallelizeOneRange(
@@ -185,17 +182,17 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
         pools.AllClusters(pkg_idx).Run(
            tq_range.begin(), tq_range.end(),
            [&](const size_t tq_idx, const size_t cluster_idx) {
-              const size_t query_idx = div_queries.Remainder(tq_idx);
-              const size_t batch_idx = div_queries.Divide(tq_idx);
-              auto& kv_cache = kv_caches[query_idx].kv_cache;
+              const size_t qi = div_qbatch.Remainder(tq_idx);
+              const size_t batch_idx = div_qbatch.Divide(tq_idx);
+              auto& kv_cache = qbatch.KV(qi).kv_cache;
 
              // Find the token position in the query and calculate
              // the range of cache positions to attend to.
-              const size_t pos = queries_pos[query_idx] + batch_idx;
+              const size_t pos = qbatch.Pos(qi) + batch_idx;
              const size_t start_pos =
                  StartPos(pos, activations.weights_config, layer_idx);
              size_t last_pos = pos;
-              const size_t prefix_end = queries_prefix_end[query_idx];
+              const size_t prefix_end = qbatch.PrefixEnd(qi);
              if (prefix_end > 0 && prefix_end - 1 > last_pos) {
                // last_pos in QDotK and WeightedSumV is inclusive.
                last_pos = prefix_end - 1;
@@ -235,14 +232,21 @@ void DotSoftmaxWeightedSum(const size_t num_tokens,
       });
 }
 
+// Different functions use different naming conventions for the number of
+// tokens. Functions that are query-independent, such as RMSNorm*, call the
+// count `num_interleaved`. Functions that are query-dependent, such as
+// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
+// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
+
 // Fills activations.q and writes to KV cache.
-static HWY_INLINE void ComputeQKV(
-    size_t num_tokens, const QueriesPos& queries_pos, const size_t layer_idx,
-    const LayerWeightsPtrs& layer, Activations& activations,
-    const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
+static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
+                                  const LayerWeightsPtrs& layer,
+                                  Activations& activations,
+                                  const QBatch& qbatch, const int flags,
+                                  MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.QKV");
-  const hwy::Divisor div_queries(queries_pos.size());
-  const size_t num_interleaved = num_tokens * div_queries.GetDivisor();
+  const hwy::Divisor div_qbatch(qbatch.Size());
+  const size_t num_interleaved = num_tokens * div_qbatch.GetDivisor();
   const LayerConfig& layer_config = layer.layer_config;
   const size_t qkv_dim = layer_config.qkv_dim;
   const size_t kv_heads = layer_config.kv_heads;
@@ -260,13 +264,12 @@ static HWY_INLINE void ComputeQKV(
                          layer.qkv_einsum_w2.Rows()));
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
-    const size_t query_idx = div_queries.Remainder(interleaved_idx);
-    const size_t batch_idx = div_queries.Divide(interleaved_idx);
+    const size_t qi = div_qbatch.Remainder(interleaved_idx);
+    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
     const size_t cache_pos =
-        activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
+        activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx);
     env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
-        kv_caches[query_idx].kv_cache.Row(cache_pos) +
-        layer_idx * cache_layer_size);
+        qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size);
   }
   kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
   CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
@@ -280,11 +283,11 @@ static HWY_INLINE void ComputeQKV(
       [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kv_heads;
         const size_t interleaved_idx = task / kv_heads;
-        const size_t query_idx = div_queries.Remainder(interleaved_idx);
-        const size_t batch_idx = div_queries.Divide(interleaved_idx);
-        const size_t pos = queries_pos[query_idx] + batch_idx;
+        const size_t qi = div_qbatch.Remainder(interleaved_idx);
+        const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
+        const size_t pos = qbatch.Pos(qi) + batch_idx;
         const size_t cache_pos = activations.div_seq_len.Remainder(pos);
-        auto& kv_cache = kv_caches[query_idx].kv_cache;
+        auto& kv_cache = qbatch.KV(qi).kv_cache;
         float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                                  layer_idx * cache_layer_size +
                                  head * qkv_dim * 2;
@@ -320,35 +323,18 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
                activations.att_sums);
 }
 
-// `queries_prefix_end` can be null (interpreted as all-zero) for standard
-// causal attention, and must be non-null for prefix-LM style attention.
-void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
-                    const QueriesPos* queries_prefix_end,
-                    const size_t layer_idx, const LayerWeightsPtrs& layer,
-                    Activations& activations, const KVCaches& kv_caches,
-                    MatMulEnv& env, int flags) {
-  const size_t num_queries = queries_pos.size();
-  HWY_DASSERT(num_queries <= kv_caches.size());
-
+void GemmaAttention(size_t num_tokens, const size_t layer_idx,
+                    const LayerWeightsPtrs& layer, Activations& activations,
+                    QBatch& qbatch, MatMulEnv& env, int flags) {
   const LayerConfig& layer_config = layer.layer_config;
   HWY_DASSERT(!layer_config.IsMHA());  // No longer supported.
   HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0,
                 "query heads must be a multiple of key-value heads");
   (void)layer_config;  // only used in HWY_DASSERT
 
-  std::vector<size_t> queries_prefix_end_vec;
-  QueriesPos queries_prefix_end_span;
-  if (queries_prefix_end == nullptr) {
-    queries_prefix_end_vec.assign(num_queries, 0);
-    queries_prefix_end_span = QueriesPos(queries_prefix_end_vec.data(),
-                                         queries_prefix_end_vec.size());
-    queries_prefix_end = &queries_prefix_end_span;
-  }
-
-  ComputeQKV(num_tokens, queries_pos, layer_idx, layer, activations, kv_caches,
-             flags, env);
-  DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, layer_idx,
-                        layer, activations, kv_caches, env.ctx.pools);
+  ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
+  DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
+                        env.ctx.pools);
   SumHeads(layer, activations, env);
 }
 
diff --git a/gemma/attention.h b/gemma/attention.h
index 3a01df2..12809e9 100644
--- a/gemma/attention.h
+++ b/gemma/attention.h
@@ -35,18 +35,14 @@ namespace gcpp {
                  const Activations& activations, float* HWY_RESTRICT att,    \
                  float* HWY_RESTRICT att_out);                                \
                                                                               \
-  void DotSoftmaxWeightedSum(const size_t num_tokens,                         \
-                             const QueriesPos& queries_pos,                   \
-                             const QueriesPos& queries_prefix_end,            \
-                             size_t layer_idx, const LayerWeightsPtrs& layer, \
-                             Activations& activations,                        \
-                             const KVCaches& kv_caches, NestedPools& pools);  \
+  void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx,       \
+                             const LayerWeightsPtrs& layer,                   \
+                             Activations& activations, QBatch& qbatch,        \
+                             NestedPools& pools);                             \
                                                                               \
-  void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,       \
-                      const QueriesPos* queries_prefix_end,                   \
-                      const size_t layer_idx, const LayerWeightsPtrs& layer,  \
-                      Activations& activations, const KVCaches& kv_caches,    \
-                      MatMulEnv& env, int flags);                             \
+  void GemmaAttention(size_t num_tokens, const size_t layer_idx,              \
+                      const LayerWeightsPtrs& layer,                          \
+                      Activations& activations, QBatch& qbatch,               \
+                      MatMulEnv& env, int flags);                             \
  /* NOLINTNEXTLINE(google-readability-namespace-comments) */                  \
  }  // namespace NAMESPACE
 
diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc
index 1fda34c..e7edeaf 100644
--- a/gemma/bindings/context.cc
+++ b/gemma/bindings/context.cc
@@ -205,7 +205,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
   // RuntimeConfig runtime_config = { ... };  // This was already defined
   double image_tokens_start = hwy::platform::Now();
   // Pass the populated image object to GenerateImageTokens
-  model.GenerateImageTokens(runtime_config, image, image_tokens);
+  model.GenerateImageTokens(runtime_config,
+                            active_conversation->kv_cache->SeqLen(), image,
+                            image_tokens);
   double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
 
   ss.str("");
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 2bcf1fa..87e241f 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -45,7 +45,6 @@
 #include "gemma/configs.h"
 #include "gemma/model_store.h"
-#include "gemma/tokenizer.h"
 #include "gemma/weights.h"
 #include "io/blob_store.h"
 #include "io/io.h"  // Path
@@ -62,14 +61,11 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-void Attention(LayerAttentionType type, size_t num_tokens,
-               const QueriesPos& queries_pos,
-               const QueriesPos& queries_prefix_end, const size_t layer_idx,
-               const LayerWeightsPtrs& layer, Activations& activations,
-               const KVCaches& kv_caches, MatMulEnv& env) {
+void Attention(LayerAttentionType type, const size_t num_tokens,
+               const size_t layer_idx, const LayerWeightsPtrs& layer,
+               Activations& activations, QBatch& qbatch, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
-    GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, layer_idx,
-                   layer, activations, kv_caches, env,
+    GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env,
                    /*flags=*/0);
   } else {
     HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
@@ -77,23 +73,23 @@ void Attention(LayerAttentionType type, size_t num_tokens,
     // so map `layer` to the Griffin layer index.
     const size_t griffin_layer =
         activations.weights_config.NumLayersOfTypeBefore(type, layer_idx);
-    GriffinRecurrent(queries_pos, num_tokens, griffin_layer, activations,
-                     &layer, kv_caches, env);
+    GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch,
+                     env);
   }
 }
 
-static HWY_NOINLINE void TransformerLayer(
-    const size_t num_tokens, const QueriesPos& queries_pos,
-    const QueriesPos& queries_prefix_end, const size_t layer_idx,
-    const LayerWeightsPtrs& layer, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env) {
+static HWY_NOINLINE void TransformerLayer(const size_t num_tokens,
+                                          const size_t layer_idx,
+                                          const LayerWeightsPtrs& layer,
+                                          Activations& activations,
+                                          QBatch& qbatch, MatMulEnv& env) {
   const LayerConfig& layer_config = layer.layer_config;
 
   RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
                  activations.pre_att_rms_out);
 
-  Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
-            layer_idx, layer, activations, kv_caches, env);
+  Attention(layer_config.type, num_tokens, layer_idx, layer, activations,
+            qbatch, env);
 
   PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
            activations.att_sums);
@@ -134,7 +130,7 @@ static float EmbeddingScaling(size_t model_dim) {
 // calling application.
 // Returns new image_token_position.
 static HWY_NOINLINE size_t
-EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt,
+EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
              const ModelConfig& model_config, const ModelWeightsPtrs& weights,
             MatStorageT<float>& x, const ImageTokens* image_tokens = nullptr,
             size_t image_token_position = 0) {
@@ -142,14 +138,14 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
   if (model_config.wrapping == PromptWrapping::GEMMA_VLM &&
       image_tokens != nullptr && token == -2 &&
       image_token_position < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
+    hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(qi),
                    x.Cols() * x.ElementBytes());
     return image_token_position + 1;
   }
 
   if (model_config.wrapping == PromptWrapping::PALIGEMMA &&
       image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
-    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
+    hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(qi),
                    x.Cols() * x.ElementBytes());
     return image_token_position;
   }
@@ -169,34 +165,27 @@ EmbedMMToken(int token, size_t qi, size_t pos, size_t pos_in_prompt,
     const auto embedding_span =
         MakeSpan(weights_t->Row(0), embedding_ofs + model_dim);
     const hn::ScalableTag<float> df;
-    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
+    DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(qi),
                          model_dim);
-    MulByConst(emb_scaling * weights_t->Scale(), x.Row(batch_idx), model_dim);
+    MulByConst(emb_scaling * weights_t->Scale(), x.Row(qi), model_dim);
   });
 
   if (model_config.absolute_pe) {
-    AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
+    AddAbsolutePositionalEmbeddings(x.Row(qi), model_dim, pos);
   }
   return image_token_position;
 }
 
-// Incremented in-place by Prefill* and DecodeStepT.
-using QueriesMutablePos = hwy::Span<size_t>;
-
 // Populates KV cache for batches of tokens from one query at a time. This is
 // called if prompts are longer than the query batch size, and also in
 // prefix-LM mode (end > 0), which must see all tokens in one batch.
-static HWY_NOINLINE void PrefillTBatch(
-    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
-    const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
-    const ModelConfig& config, const RuntimeConfig& runtime_config,
-    const ModelWeightsPtrs& weights, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
+static HWY_NOINLINE void PrefillTBatch(const ModelConfig& config,
+                                       const RuntimeConfig& runtime_config,
+                                       const ModelWeightsPtrs& weights,
+                                       Activations& activations, QBatch& qbatch,
+                                       MatMulEnv& env,
+                                       hwy::BitSet4096<>& non_eos) {
   PROFILER_ZONE("Gen.PrefillT");
-  const size_t num_queries = queries_prompt.size();
-  HWY_DASSERT(num_queries == queries_pos.size());
-  HWY_DASSERT(num_queries == queries_prefix_end.size());
-  HWY_DASSERT(num_queries == kv_caches.size());
 
   // Batches are important for amortizing loading weights over multiple tokens.
   // This is possible in prefill because we know all tokens beforehand, whereas
@@ -210,19 +199,16 @@ static HWY_NOINLINE void PrefillTBatch(
   const size_t max_tbatch_size = runtime_config.prefill_tbatch_size;
 
   // For each query. `qi` is within the batch, not the global query index.
-  for (size_t qi = 0; qi < num_queries; ++qi) {
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
     non_eos.Set(qi);
 
-    // Single query at a time, so pass slices of the spans because
-    // GemmaAttention will only access the first KV cache and position.
-    QueriesPos single_query_pos(&queries_pos[qi], 1);
-    QueriesPos single_query_prefix_end(&queries_prefix_end[qi], 1);
-    KVCaches single_kv_cache(&kv_caches[qi], 1);
+    // One query at a time; batches are formed from the query's prompt tokens.
+    QBatch qbatch_1 = qbatch.Single(qi);
 
-    const size_t prompt_size = queries_prompt[qi].size();
+    const size_t prompt_size = qbatch_1.Prompt(0).size();
     // In autoregressive mode, we don't need to prefill the last token, so - 1.
     size_t prefill_this_query = prompt_size - 1;
-    const size_t prefix_end_this_query = queries_prefix_end[qi];
+    const size_t prefix_end_this_query = qbatch_1.PrefixEnd(0);
     // We can't attend beyond the prompt_size.
     HWY_ASSERT(prefix_end_this_query <= prompt_size);
     // Special case: if the prefix includes the last token, we need to prefill
@@ -251,9 +237,9 @@ static HWY_NOINLINE void PrefillTBatch(
       // Fill activations.x (much faster than TransformerLayer).
       size_t image_token_position = 0;
       for (size_t ti = 0; ti < tbatch_size; ++ti) {
-        const size_t pos = queries_pos[qi] + ti;
+        const size_t pos = qbatch_1.Pos(0) + ti;
         const size_t pos_in_prompt = tbatch_start + ti;
-        const int token = queries_prompt[qi][pos_in_prompt];
+        const int token = qbatch_1.Prompt(0)[pos_in_prompt];
         image_token_position = EmbedMMToken(
             token, ti, pos, pos_in_prompt, config, weights, activations.x,
             runtime_config.image_tokens, image_token_position);
@@ -262,18 +248,17 @@ static HWY_NOINLINE void PrefillTBatch(
       // Transformer with one batch of tokens from a single query.
       for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
            ++layer_idx) {
-        TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
-                         layer_idx, *weights.GetLayer(layer_idx), activations,
-                         single_kv_cache, env);
+        TransformerLayer(tbatch_size, layer_idx, *weights.GetLayer(layer_idx),
+                         activations, qbatch_1, env);
       }
 
       // NOTE: we unconditionally call StreamToken, even if EOS.
       for (size_t ti = 0; ti < tbatch_size; ++ti) {
-        const size_t pos = queries_pos[qi] + ti;
+        const size_t pos = qbatch_1.Pos(0) + ti;
         const size_t pos_in_prompt = tbatch_start + ti;
-        const int token = queries_prompt[qi][pos_in_prompt];
+        const int token = qbatch_1.Prompt(0)[pos_in_prompt];
         if (pos_in_prompt < prompt_size - 1) {
-          runtime_config.StreamToken(query_idx_start + qi, pos, token, 0.0f);
+          runtime_config.StreamToken(qbatch_1.QueryIdx(0), pos, token, 0.0f);
         } else {
           // The last token will be streamed later and we should only get here
           // if we need to attend to the last token because it is in the prefix.
@@ -281,7 +266,7 @@ static HWY_NOINLINE void PrefillTBatch(
         }
       }
 
-      queries_pos[qi] += tbatch_size;
+      qbatch_1.MutablePos(0) += tbatch_size;
     }  // for tbatch_start
     if (attend_to_last_token) {
       // We need to rewind the position for the last token that we only
@@ -290,148 +275,125 @@ static HWY_NOINLINE void PrefillTBatch(
       // decoding. Alternatives: (1) real masking; (2) always prefill the last
      // token and only generate the next one from the already prefilled
      // activations.
-      queries_pos[qi] -= 1;
+      qbatch_1.MutablePos(0) -= 1;
    }
   }
 }
 
-// Embeds token and calls each TransformerLayer. `queries_token` is the previous
-// token from each query, and `queries_pos` are their position in the sequence.
+// Embeds PrevToken (one from each query) and calls each TransformerLayer.
 // Called by query-batched `PrefillQBatch` and `DecodeStepT`, but not the
 // token-batched `PrefillTBatch`.
-static HWY_NOINLINE void Transformer(
-    const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
-    const QueriesPos& queries_prefix_end, const ModelConfig& config,
-    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
-    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
-  const size_t num_queries = queries_token.size();
-  HWY_DASSERT(num_queries == queries_pos.size());
-  HWY_DASSERT(num_queries == queries_prefix_end.size());
-
+static HWY_NOINLINE void Transformer(const ModelConfig& config,
+                                     const RuntimeConfig& runtime_config,
+                                     const ModelWeightsPtrs& weights,
+                                     Activations& activations, QBatch& qbatch,
+                                     MatMulEnv& env) {
   if (HWY_UNLIKELY(runtime_config.layers_output)) {
-    for (size_t qi = 0; qi < num_queries; ++qi) {
-      const float token_f = queries_token[qi];
-      runtime_config.layers_output(qi, queries_pos[qi], "tokens", -1, &token_f,
-                                   1);
+    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+      const float token_f = qbatch.PrevToken(qi);
+      runtime_config.layers_output(qbatch.QueryIdx(qi), qbatch.Pos(qi),
+                                   "tokens", -1, &token_f, 1);
     }
   }
 
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    EmbedMMToken(queries_token[qi], qi, queries_pos[qi],
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    EmbedMMToken(qbatch.PrevToken(qi), qi, qbatch.Pos(qi),
                 /*pos_in_prompt=*/0, config, weights, activations.x);
   }
 
   for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
-    TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
-                     layer_idx, *weights.GetLayer(layer_idx), activations,
-                     kv_caches, env);
+    TransformerLayer(/*num_tokens=*/1, layer_idx, *weights.GetLayer(layer_idx),
+                     activations, qbatch, env);
 
     if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-      runtime_config.activations_observer(queries_pos, layer_idx, activations);
+      runtime_config.activations_observer(
+          QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), layer_idx,
+          activations);
     }
   }
 }
 
 // Populates KV cache for the batch queries, one token at a time. Only called
 // for autoregressive (non-prefix-LM) prefill, so `queries_prefix_end` == 0.
-static HWY_NOINLINE void PrefillQBatch(
-    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
-    const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
-    const size_t max_prompt_size, const ModelConfig& config,
-    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
-    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
-    hwy::BitSet4096<>& non_eos) {
+static HWY_NOINLINE void PrefillQBatch(const size_t max_prompt_size,
+                                       const ModelConfig& config,
+                                       const RuntimeConfig& runtime_config,
+                                       const ModelWeightsPtrs& weights,
+                                       Activations& activations, QBatch& qbatch,
+                                       MatMulEnv& env,
+                                       hwy::BitSet4096<>& non_eos) {
   PROFILER_ZONE("Gen.Prefill");
-  const size_t num_queries = queries_prompt.size();
-  HWY_DASSERT(num_queries == queries_pos.size());
-  HWY_DASSERT(num_queries == queries_prefix_end.size());
-  HWY_DASSERT(num_queries == activations.x.Rows());
-  HWY_DASSERT(num_queries == kv_caches.size());
 
-  hwy::BitSet4096<> prefill_active;
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    prefill_active.Set(qi);
-
-    HWY_DASSERT(queries_prefix_end[qi] == 0);
-    (void)queries_prefix_end;
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    non_eos.Set(qi);
+    HWY_DASSERT(qbatch.PrefixEnd(qi) == 0);
   }
-  non_eos = prefill_active;
 
   // In autoregressive mode, we don't prefill the last token, hence - 1.
   for (size_t pos_in_prompt = 0; pos_in_prompt < max_prompt_size - 1;
        ++pos_in_prompt) {
-    // Streams that have already finished prefill no longer interleave/stream.
-    for (size_t qi = 0; qi < num_queries; ++qi) {
-      if (pos_in_prompt >= queries_prompt[qi].size() - 1) {
-        prefill_active.Clear(qi);
-        activations.gen_tokens[qi] = config.eos_id;
+    for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+      int token = config.eos_id;
+      if (pos_in_prompt < qbatch.Prompt(qi).size() - 1) {
+        token = qbatch.Prompt(qi)[pos_in_prompt];
+        // Ignore StreamToken return value because requesting to stop does not
+        // make sense during prefill.
+        (void)runtime_config.StreamToken(qbatch.QueryIdx(qi), qbatch.Pos(qi),
+                                         token, 0.0f);
       }
+
+      qbatch.PrevToken(qi) = token;
     }
 
-    // Batch := interleaved tokens, one from each non-EOS query.
-    prefill_active.Foreach([&](size_t qi) {
-      activations.gen_tokens[qi] = queries_prompt[qi][pos_in_prompt];
-    });
-
-    // One token from each query in the batch. Increments queries_pos.
+    // The input (PrevToken) is one token from each query in the batch.
     // Do not call DecodeStepT because it computes logits for token
     // probabilities, which are not required for the prompt tokens.
-    Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
-                queries_pos, queries_prefix_end, config, runtime_config,
-                weights, activations, kv_caches, env);
-
-    prefill_active.Foreach([&](size_t qi) {
-      const int token = queries_prompt[qi][pos_in_prompt];
-      // Ignore any user request to stop during prefill.
-      (void)runtime_config.StreamToken(query_idx_start + qi, queries_pos[qi],
-                                       token, 0.0f);
-      queries_pos[qi] += 1;
-    });
-  }  // pos_in_prompt
+    Transformer(config, runtime_config, weights, activations, qbatch, env);
+  }
 }
 
-// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
-// and updates `non_eos` if the query is at the end of its sequence.
-static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
-                               const float prob, const ModelConfig& config,
+// Calls `StreamToken`, writes the token to `PrevToken` for use by subsequent
+// `DecodeStepT`, and increments `MutablePos`. Also updates `non_eos` if the
+// query is at the end of its sequence.
+static void StreamAndUpdateEOS(const size_t qi, int token, const float prob,
+                               const ModelConfig& config,
                                const RuntimeConfig& runtime_config,
-                               Activations& activations,
-                               hwy::BitSet4096<>& non_eos) {
-  HWY_DASSERT(non_eos.Get(qi));
+                               QBatch& qbatch, hwy::BitSet4096<>& non_eos) {
+  HWY_DASSERT(non_eos.Get(qi));  // otherwise, should not be called.
 
-  // User decided to stop: set next token to primary EOS.
-  if (HWY_UNLIKELY(!runtime_config.StreamToken(qi, pos, token, prob))) {
+  if (HWY_UNLIKELY(!runtime_config.StreamToken(qbatch.QueryIdx(qi),
+                                               qbatch.Pos(qi), token, prob))) {
+    // User decided to stop: set token to primary EOS to trigger IsEOS below.
     token = config.eos_id;
     HWY_DASSERT(config.IsEOS(token));
   }
 
-  // Primary or secondary EOS: mark query as EOS.
-  if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
+  qbatch.PrevToken(qi) = token;
+  qbatch.MutablePos(qi) += 1;
 
-  activations.gen_tokens[qi] = token;
+  // Primary or secondary EOS: mark query as EOS, but still increment (for
+  // multi-turn, we should still keep the prior EOS).
+  if (HWY_UNLIKELY(config.IsEOS(token))) non_eos.Clear(qi);
 }
 
 // For a batch of queries, runs Transformer, computes logits, samples and
 // streams the token.
-static void DecodeStepT(
-    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
-    const QueriesMutablePos& queries_mutable_pos,
-    const QueriesPos& queries_prefix_end, const ModelConfig& config,
-    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
-    const SampleFunc& sample_token, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos,
-    TimingInfo& timing_info) {
-  const size_t num_queries = queries_prompt.size();
-  HWY_DASSERT(num_queries == activations.x.Rows());
+static void DecodeStepT(const ModelConfig& config,
+                        const RuntimeConfig& runtime_config,
+                        const ModelWeightsPtrs& weights,
+                        const SampleFunc& sample_token,
+                        Activations& activations, QBatch& qbatch,
+                        MatMulEnv& env, hwy::BitSet4096<>& non_eos,
+                        TimingInfo& timing_info) {
+  HWY_DASSERT(qbatch.Size() == activations.x.Rows());
 
-  Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
-              queries_mutable_pos, queries_prefix_end, config, runtime_config,
-              weights, activations, kv_caches, env);
+  Transformer(config, runtime_config, weights, activations, qbatch, env);
 
   RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
 
   if (HWY_UNLIKELY(runtime_config.activations_observer)) {
-    runtime_config.activations_observer(queries_mutable_pos, -1, activations);
+    runtime_config.activations_observer(
+        QueriesPos(&qbatch.MutablePos(0), qbatch.Size()), -1, activations);
   }
 
   {
@@ -447,10 +409,8 @@ static void DecodeStepT(
       const TokenAndProb tp = sample_token(logits, config.vocab_size);
       timing_info.NotifyGenerated();
 
-      StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi], tp.token,
-                         tp.prob, config, runtime_config, activations, non_eos);
-
-      if (non_eos.Get(qi)) queries_mutable_pos[qi] += 1;
+      StreamAndUpdateEOS(qi, tp.token, tp.prob, config, runtime_config, qbatch,
+                         non_eos);
     });
 }
 
@@ -477,46 +437,24 @@ ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   };
 }
 
-// Generates one continuation for each query in `queries_prompt`, which is one
-// qbatch whose size is at most the `batch_size` passed to `activations` ctor.
-//
-// `queries_pos` stores the KV cache position for each query. In the first turn
-// of a chat, pos = 0; we increment each query's position after each token.
-//
-// `query_idx_start` is the query_idx of the first query in the batch, so that
-// `StreamFunc` gets the global query index, not relative to the batch.
-static void GenerateT(
-    const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
-    const QueriesPos& queries_pos_in, const QueriesPos& queries_prefix_end,
-    const ModelConfig& config, const RuntimeConfig& runtime_config,
-    const ModelWeightsPtrs& weights, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env, TimingInfo& timing_info) {
-  const size_t num_queries = queries_prompt.size();
-  HWY_ASSERT(num_queries <= 4096);  // non_eos uses `BitSet4096`.
-  HWY_ASSERT(num_queries == queries_pos_in.size());
-  HWY_ASSERT(num_queries == queries_prefix_end.size());
-  HWY_ASSERT(num_queries <= activations.x.Rows());
-  HWY_ASSERT(num_queries == kv_caches.size());
-
+// Generates one continuation for each query in `qbatch`.
+static void GenerateT(const ModelConfig& config,
+                      const RuntimeConfig& runtime_config,
+                      const ModelWeightsPtrs& weights, Activations& activations,
+                      QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) {
   // Griffin assumes that the recurrent block cache is zero-initialized.
-  for (size_t i = 0; i < kv_caches.size(); ++i) {
-    if (queries_pos_in[i] == 0) {
-      kv_caches[i].ZeroGriffinCache();  // No-op for non-Griffin models.
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    if (qbatch.MutablePos(qi) == 0) {
+      qbatch.KV(qi).ZeroGriffinCache();  // No-op for non-Griffin models.
     }
   }
 
-  // Copy so we can increment without requiring users to pass in a mutable span.
-  std::vector<size_t> queries_pos_copy(queries_pos_in.cbegin(),
-                                       queries_pos_in.cend());
-  const QueriesMutablePos queries_mutable_pos(queries_pos_copy.data(),
-                                              queries_pos_copy.size());
-
   size_t max_prompt_size = 0;
   bool all_prefix_end_are_zero = true;
-  size_t prefill_tokens = 0;
-  const size_t seq_len = kv_caches[0].SeqLen();
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    const PromptTokens& prompt = queries_prompt[qi];
+  size_t prefill_tokens = 0;  // only for timing.
+  const size_t seq_len = qbatch.KV(0).SeqLen();
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    const PromptTokens& prompt = qbatch.Prompt(qi);
     max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
 
     // Prefill stops before size - 1 because the last prompt token is the
@@ -526,43 +464,38 @@ static void GenerateT(
     // Sanity check: prompts should not be empty, nor start with EOS.
     HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
 
-    all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
+    all_prefix_end_are_zero &= qbatch.PrefixEnd(qi) == 0;
     // We use a single divisor, so all sequence lengths must be the same.
-    HWY_ASSERT(kv_caches[qi].SeqLen() == seq_len);
+    HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len);
   }
   HWY_ASSERT(prefill_tokens < seq_len);
   activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
 
   // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
   // qi loops anyway.
-  hwy::BitSet4096<> non_eos;
+  hwy::BitSet4096<> non_eos;  // indexed by qi
 
   timing_info.prefill_start = hwy::platform::Now();
 
   // Batch over the larger of prompt length, or queries.
-  if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
-    activations.SetBatchSize(num_queries);  // required before PrefillQBatch
-    PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
-                  queries_prefix_end, max_prompt_size, config, runtime_config,
-                  weights, activations, kv_caches, env, non_eos);
+  if ((qbatch.Size() > max_prompt_size) && all_prefix_end_are_zero) {
+    activations.SetBatchSize(qbatch.Size());  // required before PrefillQBatch
+    PrefillQBatch(max_prompt_size, config, runtime_config, weights, activations,
+                  qbatch, env, non_eos);
   } else {
-    PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
-                  queries_prefix_end, config, runtime_config, weights,
-                  activations, kv_caches, env, non_eos);
-    activations.SetBatchSize(num_queries);  // Restore after PrefillTBatch.
+    PrefillTBatch(config, runtime_config, weights, activations, qbatch, env,
+                  non_eos);
+    activations.SetBatchSize(qbatch.Size());  // Restore after PrefillTBatch.
   }
-  HWY_DASSERT(num_queries == non_eos.Count());
+  HWY_DASSERT(non_eos.Count() == qbatch.Size());
   timing_info.NotifyPrefill(prefill_tokens);
   // queries_pos have been incremented by Prefill.
 
   // Stream the last prompt token from each query, fill activations.gen_tokens.
-  for (size_t qi = 0; qi < num_queries; ++qi) {
-    const size_t last_token_pos_in_prompt =
-        queries_mutable_pos[qi] - queries_pos_in[qi];
-    StreamAndUpdateEOS(query_idx_start + qi, queries_mutable_pos[qi],
-                       queries_prompt[qi][last_token_pos_in_prompt], 0.0f,
-                       config, runtime_config, activations, non_eos);
-    // No incrementing queries_mutable_pos[qi].
+  for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
+    const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
+    StreamAndUpdateEOS(qi, qbatch.Prompt(qi)[last_pos_in_prompt], 0.0f, config,
+                       runtime_config, qbatch, non_eos);
   }
 
   size_t max_gen_steps = runtime_config.max_generated_tokens;
@@ -577,10 +510,8 @@ static void GenerateT(
   {
     timing_info.generate_start = hwy::platform::Now();
     for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
-      DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
-                  queries_prefix_end, config, runtime_config, weights,
-                  sample_token, activations, kv_caches, env, non_eos,
-                  timing_info);
+      DecodeStepT(config, runtime_config, weights, sample_token, activations,
+                  qbatch, env, non_eos, timing_info);
     }
     timing_info.NotifyGenerateDone();
   }
}
 
@@ -591,61 +522,38 @@ void GenerateSingleT(const PromptTokens& prompt, size_t pos, size_t prefix_end,
                      const RuntimeConfig& runtime_config,
                      const ModelWeightsPtrs& weights, KVCache& kv_cache,
                      MatMulEnv& env, TimingInfo& timing_info) {
-  constexpr size_t kNumQueries = 1;
-  const size_t qbatch_start = 0;
+  Activations activations(config, runtime_config.prefill_tbatch_size,
+                          kv_cache.SeqLen(), env.row_ptrs);
 
-  const size_t max_batch_size =
-      HWY_MAX(kNumQueries, runtime_config.prefill_tbatch_size);
-  // TODO: move into Gemma?
-  Activations activations(config, max_batch_size, env.row_ptrs);
-
-  const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
-  QueriesPos queries_pos(&pos, kNumQueries);
-  const QueriesPos queries_prefix_end(&prefix_end, kNumQueries);
-  const KVCaches kv_caches{&kv_cache, kNumQueries};
-
-  GenerateT(qbatch_start, queries_prompt, queries_pos, queries_prefix_end,
-            config, runtime_config, weights, activations, kv_caches, env,
+  AllQueries all_queries(prompt, pos, prefix_end,
+                         hwy::Span<KVCache>(&kv_cache, 1));
+  QBatch qbatch(/*start=*/0, /*max_size=*/1, all_queries);
+  GenerateT(config, runtime_config, weights, activations, qbatch, env,
             timing_info);
 }
 
 // Splits the input into batches of at most `runtime_config.decode_qbatch_size`
 // queries, and calls `GenerateT` on each batch.
-void GenerateBatchT(const QueriesPromptTokens& queries_prompt,
-                    const QueriesPos& queries_pos,
-                    const QueriesPos& queries_prefix_end,
-                    const ModelConfig& config,
+void GenerateBatchT(const ModelConfig& config,
                     const RuntimeConfig& runtime_config,
-                    const ModelWeightsPtrs& weights, const KVCaches& kv_caches,
+                    const ModelWeightsPtrs& weights, AllQueries& all_queries,
                     MatMulEnv& env, TimingInfo& timing_info) {
-  const size_t num_queries = queries_prompt.size();
-  HWY_ASSERT(queries_pos.size() == num_queries);
-  HWY_ASSERT(kv_caches.size() >= num_queries);
+  const size_t max_batch_size = HWY_MAX(runtime_config.decode_qbatch_size,
+                                        runtime_config.prefill_tbatch_size);
+  Activations activations(config, max_batch_size,
+                          all_queries[0].kv_cache.SeqLen(), env.row_ptrs);
 
-  const size_t max_qbatch_size = runtime_config.decode_qbatch_size;
-  const size_t max_batch_size =
-      HWY_MAX(max_qbatch_size, runtime_config.prefill_tbatch_size);
-  Activations activations(config, max_batch_size, env.row_ptrs);
-
-  for (size_t qbatch_start = 0; qbatch_start < num_queries;
-       qbatch_start += max_qbatch_size) {
-    // Generate one batch of tokens from `qbatch_size` queries.
-    const size_t qbatch_size =
-        HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
-    const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
-                                             qbatch_size);
-    QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
-    const QueriesPos qbatch_prefix_end(&queries_prefix_end[qbatch_start],
-                                       qbatch_size);
-    const KVCaches qbatch_kv(&kv_caches[qbatch_start], qbatch_size);
-    GenerateT(qbatch_start, qbatch_prompts, qbatch_pos, qbatch_prefix_end,
-              config, runtime_config, weights, activations, qbatch_kv, env,
+  for (size_t start = 0; start < all_queries.NumQueries();
+       start += runtime_config.decode_qbatch_size) {
+    QBatch qbatch(start, runtime_config.decode_qbatch_size, all_queries);
+    // Generate continuations for the `qbatch.Size()` queries in this batch.
+    GenerateT(config, runtime_config, weights, activations, qbatch, env,
              timing_info);
   }
 }
 
 void GenerateImageTokensT(const ModelConfig& config,
-                          const RuntimeConfig& runtime_config,
+                          const RuntimeConfig& runtime_config, size_t seq_len,
                           const ModelWeightsPtrs& weights, const Image& image,
                           ImageTokens& image_tokens, MatMulEnv& env) {
   if (config.vit_config.layer_configs.empty()) {
@@ -656,7 +564,8 @@ void GenerateImageTokensT(const ModelConfig& config,
     const size_t num_tokens = vit_config.max_seq_len;
     prefill_runtime_config.prefill_tbatch_size =
         num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
-    Activations prefill_activations(vit_config, num_tokens, env.row_ptrs);
+    Activations prefill_activations(vit_config, num_tokens, num_tokens,
+                                    env.row_ptrs);
     // Weights are for the full PaliGemma model, not just the ViT part.
     PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
                prefill_activations, env);
@@ -714,36 +623,25 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
 }
 
 void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
-                          const QueriesPromptTokens& queries_prompt,
-                          const QueriesPos& queries_pos,
-                          const QueriesPos& queries_prefix_end,
-                          const KVCaches& kv_caches,
+                          AllQueries& all_queries,
                          TimingInfo& timing_info) const {
-  // If we did not get passed prefix ends (size 0), assume 0 and pass that on.
-  QueriesPos queries_prefix_end_or_zeros = queries_prefix_end;
-  std::vector<size_t> prefix_end_vec;
-  if (queries_prefix_end.size() == 0) {  // hwy::Span lacks empty()
-    prefix_end_vec.resize(queries_prompt.size(), 0);
-    queries_prefix_end_or_zeros =
-        QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
-  }
-
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
-  HWY_DYNAMIC_DISPATCH(GenerateBatchT)(
-      queries_prompt, queries_pos, queries_prefix_end_or_zeros, model_.Config(),
-      runtime_config, weights_, kv_caches, env_, timing_info);
+  HWY_DYNAMIC_DISPATCH(GenerateBatchT)(model_.Config(), runtime_config,
+                                       weights_, all_queries, env_,
+                                       timing_info);
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 
 void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
-                                const Image& image,
+                                size_t seq_len, const Image& image,
                                 ImageTokens& image_tokens) const {
   env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
-  HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(
-      model_.Config(), runtime_config, weights_, image, image_tokens, env_);
+  HWY_DYNAMIC_DISPATCH(GenerateImageTokensT)(model_.Config(), runtime_config,
+                                             seq_len, weights_, image,
+                                             image_tokens, env_);
   env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
 }
 
diff --git a/gemma/gemma.h b/gemma/gemma.h
index b866c41..6971802 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -38,6 +38,129 @@
 namespace gcpp {
 
+struct PerQuery {
+  PromptTokens prompt;
+
+  // Position in the KV cache: initially zero for the first turn, or when
+  // multi-turn is NOT desired. Incremented by prefill and `StreamAndUpdateEOS`.
+  size_t mutable_pos;
+  // Allows computing the last prefill token as `mutable_pos - initial_pos`,
+  // which might differ from `prompt.size() - 1` for prefix-LM.
+  size_t initial_pos;
+  // Zero for causal attention, or the end of the prefix for prefix-LM style
+  // attention in Paligemma.
+  size_t prefix_end;
+
+  KVCache& kv_cache;
+
+  // Previous token generated for this query, or the last prompt token. Will be
+  // fed into the next Transformer() call.
+  int prev_token = 0;
+};
+
+// Array of `PerQuery`. Referenced by `QBatch` and passed to `GenerateBatch`.
+struct AllQueries {
+  // For `GenerateSingleT`: same prompt/pos, replicated for each KV cache.
+  AllQueries(const PromptTokens& prompt, size_t pos, size_t prefix_end,
+             const hwy::Span<KVCache>& kv_caches) {
+    per_query_.reserve(kv_caches.size());
+    for (size_t i = 0; i < kv_caches.size(); ++i) {
+      HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
+      per_query_.push_back(PerQuery{
+          .prompt = prompt,
+          .mutable_pos = pos,
+          .initial_pos = pos,
+          .prefix_end = prefix_end,
+          .kv_cache = kv_caches[i],
+      });
+    }
+  }
+
+  // Batch of queries with initial position set to zero. Causal attention
+  // is requested via empty or all-zero `prefix_end`.
+  AllQueries(
+      const hwy::Span<const PromptTokens>& prompts,
+      const hwy::Span<KVCache>& kv_caches,
+      const hwy::Span<const size_t>& prefix_end = hwy::Span<const size_t>()) {
+    HWY_ASSERT(prompts.size() == kv_caches.size());
+    HWY_ASSERT(prompts.size() == prefix_end.size() || prefix_end.size() == 0);
+    per_query_.reserve(kv_caches.size());
+    for (size_t i = 0; i < kv_caches.size(); ++i) {
+      HWY_ASSERT(kv_caches[i].SeqLen() == kv_caches[0].SeqLen());
+      per_query_.push_back(PerQuery{
+          .prompt = prompts[i],
+          .mutable_pos = 0,
+          .initial_pos = 0,
+          .prefix_end = prefix_end.size() == 0 ? 0 : prefix_end[i],
+          .kv_cache = kv_caches[i],
+      });
+    }
+  }
+
+  size_t NumQueries() const { return per_query_.size(); }
+
+  PerQuery& operator[](size_t query_idx) {
+    HWY_DASSERT(query_idx < NumQueries());
+    return per_query_[query_idx];
+  }
+  const PerQuery& operator[](size_t query_idx) const {
+    HWY_DASSERT(query_idx < NumQueries());
+    return per_query_[query_idx];
+  }
+
+ private:
+  std::vector<PerQuery> per_query_;
+};
+
+// View into AllQueries: either a batch of queries, or a single query for use
+// in PrefillTBatch or GenerateSingleT. Cheap to create because it holds a
+// reference to AllQueries.
+class QBatch {
+ public:
+  QBatch(size_t start, size_t max_size, AllQueries& queries)
+      : start_(start),
+        max_size_(max_size),
+        queries_(queries),
+        size_(HWY_MIN(max_size_, queries_.NumQueries() - start_)) {
+    HWY_ASSERT(max_size_ <= 4096);  // non_eos uses `BitSet4096`.
+    HWY_DASSERT(size_ != 0);
+    HWY_DASSERT(start_ + size_ <= queries_.NumQueries());
+  }
+
+  // Returns a single-query view starting at `qi` relative to this batch.
+  QBatch Single(size_t qi) const { return QBatch(start_ + qi, 1, queries_); }
+
+  // How many queries in this batch, <= `queries_.NumQueries()` and `max_size_`.
+  size_t Size() const { return size_; }
+
+  // Returns index for use with `AllQueries` and `BatchStreamToken`.
+  size_t QueryIdx(size_t qi) const {
+    HWY_DASSERT(qi < size_);
+    return start_ + qi;
+  }
+
+  // Accessor functions to bridge the previous SoA and current AoS layout.
+  const PromptTokens& Prompt(size_t qi) const {
+    return queries_[QueryIdx(qi)].prompt;
+  }
+  size_t Pos(size_t qi) const { return queries_[QueryIdx(qi)].mutable_pos; }
+  size_t& MutablePos(size_t qi) { return queries_[QueryIdx(qi)].mutable_pos; }
+  size_t InitialPos(size_t qi) const {
+    return queries_[QueryIdx(qi)].initial_pos;
+  }
+  size_t PrefixEnd(size_t qi) const {
+    return queries_[QueryIdx(qi)].prefix_end;
+  }
+  KVCache& KV(size_t qi) const { return queries_[QueryIdx(qi)].kv_cache; }
+  int& PrevToken(size_t qi) { return queries_[QueryIdx(qi)].prev_token; }
+
+ private:
+  size_t start_;
+  size_t max_size_;
+  AllQueries& queries_;
+  size_t size_;
+};
+
 struct TimingInfo {
   // be sure to populate prefill_start before calling NotifyPrefill.
   void NotifyPrefill(size_t tokens) {
@@ -100,8 +223,6 @@ struct TimingInfo {
 
 // Returns the `MatMulEnv` after calling `SetArgs`.
 MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
 
-using KVCaches = hwy::Span<KVCache>;
-
 class Gemma {
  public:
  // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
@@ -133,24 +254,11 @@ class Gemma {
                       size_t pos, size_t prefix_end, KVCache& kv_cache,
                       TimingInfo& timing_info) const;
 
-  // `queries_pos` are the positions in the KV cache. Users are responsible for
-  // incrementing them in `BatchStreamFunc`, or setting to zero for single-turn.
   void GenerateBatch(const RuntimeConfig& runtime_config,
-                     const QueriesPromptTokens& queries_prompt,
-                     const QueriesPos& queries_pos, const KVCaches& kv_caches,
-                     TimingInfo& timing_info) const {
-    GenerateBatch(runtime_config, queries_prompt, queries_pos,
-                  /*queries_prefix_end=*/{}, kv_caches, timing_info);
-  }
-  // For prefix-LM style attention, we can pass the ends of the prefixes.
-  void GenerateBatch(const RuntimeConfig& runtime_config,
-                     const QueriesPromptTokens& queries_prompt,
-                     const QueriesPos& queries_pos,
-                     const QueriesPos& queries_prefix_end,
-                     const KVCaches& kv_caches, TimingInfo& timing_info) const;
+                     AllQueries& all_queries, TimingInfo& timing_info) const;
 
   // Generates the image tokens by running the image encoder ViT.
-  void GenerateImageTokens(const RuntimeConfig& runtime_config,
+  void GenerateImageTokens(const RuntimeConfig& runtime_config, size_t seq_len,
                           const Image& image, ImageTokens& image_tokens) const;
 
  private:
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index e4c8a33..7f8b11b 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -82,8 +82,9 @@ using ImageTokens = MatStorageT<float>;
 // true to continue generation.
 using StreamFunc = std::function<bool(int, float)>;
 // BatchStreamFunc is called with (query_idx, pos, token, probability).
-// For prompt tokens, probability is 0.0f.
-// StreamFunc should return false to stop generation and true to continue.
+// For prompt tokens, probability is 0.0f. Generation continues if this returns
+// true and stops if it returns false. Note that query_idx is absolute, not
+// relative to the batch.
 using BatchStreamFunc = std::function<bool(size_t, size_t, int, float)>;
 // If not empty, AcceptFunc is called with token. It should return false for
 // tokens you don't want to generate and true for tokens you want to generate.
@@ -112,8 +113,8 @@ using ActivationsObserverFunc =
 // RuntimeConfig holds configuration for a single generation run.
 // TODO: move into InferenceArgs, use that directly.
 struct RuntimeConfig {
-  // If not empty, batch_stream_token is called for each token in the batch,
-  // instead of stream_token.
+  // If non-null, `batch_stream_token` is called for each token in the batch,
+  // otherwise `stream_token`. `query_idx` is absolute, not batch-relative.
   bool StreamToken(size_t query_idx, size_t pos, int token, float prob) const {
     if (batch_stream_token) {
       return batch_stream_token(query_idx, pos, token, prob);
     }
@@ -189,9 +190,9 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
            "developer/debug info).\n    Default = 1.",
            1);  // Changed verbosity level to 1 since it's user-facing
 
-    visitor(seq_len, "seq_len", size_t{2048},
+    visitor(seq_len, "seq_len", size_t{8192},
            "Sequence length, capped by ModelConfig.max_seq_len.");
-    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
+    visitor(max_generated_tokens, "max_generated_tokens", size_t{4096},
            "Maximum number of tokens to generate.");
 
    visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256},
diff --git a/gemma/griffin.cc b/gemma/griffin.cc
index 59faca4..f7b02e2 100644
--- a/gemma/griffin.cc
+++ b/gemma/griffin.cc
@@ -39,16 +39,10 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-// Different functions use different naming conventions for the number of
-// tokens. Functions that are query-independent, such as RMSNorm*, call the
-// count `num_interleaved`. Functions that are query-dependent, such as
-// `Attention`, use separate `num_tokens` and `num_queries`. `num_tokens` is the
-// number of tokens from one query: 1 for decode, otherwise prefill_tbatch_size.
-
-void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,
-                      size_t griffin_layer, Activations& activations,
+void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
                       const LayerWeightsPtrs* layer_weights,
-                      const KVCaches& kv_caches, MatMulEnv& env) {
+                      Activations& activations, QBatch& qbatch,
+                      MatMulEnv& env) {
   PROFILER_ZONE("Gen.Griffin");
   hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
   namespace hn = hwy::HWY_NAMESPACE;
@@ -64,9 +58,8 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
   const size_t kHeadDim = model_dim / heads;
   const size_t kMatrixSize = kHeadDim * kHeadDim;
 
-  const size_t num_queries = queries_pos.size();
-  const hwy::Divisor div_num_q(static_cast<uint32_t>(num_queries));
-  const size_t num_interleaved = num_tokens * num_queries;
+  const size_t num_interleaved = num_tokens * qbatch.Size();
+  const hwy::Divisor div_qbatch(static_cast<uint32_t>(qbatch.Size()));
 
   // X / Y linear layers.
   // TODO: MatMul
@@ -91,17 +84,17 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
   // Conv1D.
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
-    const size_t query_idx = div_num_q.Remainder(interleaved_idx);
-    const size_t batch_idx = div_num_q.Divide(interleaved_idx);
-    const size_t pos = queries_pos[query_idx] + batch_idx;
-    float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
+    const size_t qi = div_qbatch.Remainder(interleaved_idx);
+    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
+    const size_t pos = qbatch.Pos(qi) + batch_idx;
+    float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
 
    // cache[i] = input at time t-i.
    float* HWY_RESTRICT cache[kMaxConv1DWidth];
    cache[0] = x;
    for (size_t i = 1; i < conv_1d_width; i++) {
      cache[i] =
-          kv_caches[query_idx].conv1d_cache.Row(griffin_layer) +
+          qbatch.KV(qi).conv1d_cache.Row(griffin_layer) +
          ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
    }
    for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
@@ -127,16 +120,16 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,
   // RGLRU
   for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
        ++interleaved_idx) {
-    const size_t query_idx = div_num_q.Remainder(interleaved_idx);
-    const size_t batch_idx = div_num_q.Divide(interleaved_idx);
-    const size_t pos = queries_pos[query_idx] + batch_idx;
+    const size_t qi = div_qbatch.Remainder(interleaved_idx);
+    const size_t batch_idx = div_qbatch.Divide(interleaved_idx);
+    const size_t pos = qbatch.Pos(qi) + batch_idx;
 
-    float* HWY_RESTRICT x = activations.griffin_x.Row(query_idx);
-    float* HWY_RESTRICT y = activations.griffin_y.Row(query_idx);
-    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(query_idx);
-    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(query_idx);
+    float* HWY_RESTRICT x = activations.griffin_x.Row(qi);
+    float* HWY_RESTRICT y = activations.griffin_y.Row(qi);
+    float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(qi);
+    float* HWY_RESTRICT a = activations.griffin_multiplier.Row(qi);
     float* HWY_RESTRICT rnn_state =
-        kv_caches[query_idx].rglru_cache.Row(griffin_layer);
+        qbatch.KV(qi).rglru_cache.Row(griffin_layer);
 
     pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
       size_t head_offset = head * kHeadDim;
diff --git a/gemma/griffin.h b/gemma/griffin.h
index fea514c..0ba6a23 100644
--- a/gemma/griffin.h
+++ b/gemma/griffin.h
@@ -26,13 +26,13 @@
 namespace gcpp {
 
 // Passed to HWY_VISIT_TARGETS; declares for one target.
-#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE)                              \
-  namespace NAMESPACE {                                                    \
-  void GriffinRecurrent(const QueriesPos& queries_pos, size_t num_tokens,  \
-                        size_t griffin_layer, Activations& activations,    \
-                        const LayerWeightsPtrs* layer_weights,             \
-                        const KVCaches& kv_caches, MatMulEnv& env);        \
-  /* NOLINTNEXTLINE(google-readability-namespace-comments) */              \
+#define GEMMA_DECL_GRIFFIN(TARGET, NAMESPACE)                              \
+  namespace NAMESPACE {                                                    \
+  void GriffinRecurrent(size_t num_tokens, size_t griffin_layer,           \
+                        const LayerWeightsPtrs* layer_weights,             \
+                        Activations& activations, QBatch& qbatch,          \
+                        MatMulEnv& env);                                   \
+  /* NOLINTNEXTLINE(google-readability-namespace-comments) */              \
   }  // namespace NAMESPACE
 
 // Function declarations for each SIMD target. Allows direct call from the
diff --git a/gemma/run.cc b/gemma/run.cc
index dbdeb61..9d6df42 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -120,7 +120,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                    .verbosity = inference.verbosity,
                                    .use_spinning = threading.spin};
     double image_tokens_start = hwy::platform::Now();
-    gemma.GenerateImageTokens(runtime_config, image, image_tokens);
+    gemma.GenerateImageTokens(runtime_config, kv_cache.SeqLen(), image,
+                              image_tokens);
     if (inference.verbosity >= 1) {
       double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
       fprintf(stderr,
diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc
index 9491475..e883379 100644
--- a/paligemma/paligemma_test.cc
+++ b/paligemma/paligemma_test.cc
@@ -57,7 +57,8 @@ class PaliGemmaTest : public ::testing::Test {
     image.Resize(image_size, image_size);
     RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(),
                                     .verbosity = 0};
-    gemma.GenerateImageTokens(runtime_config, image, *image_tokens_);
+    gemma.GenerateImageTokens(runtime_config, s_env->MutableKVCache().SeqLen(),
+                              image, *image_tokens_);
   }
 
   std::string GemmaReply(const std::string& prompt_text) const {
@@ -107,12 +108,11 @@ class PaliGemmaTest : public ::testing::Test {
 TEST_F(PaliGemmaTest, QueryObjects) {
   ASSERT_NE(s_env->GetGemma(), nullptr);
   const char* question = "answer en What objects are in the image?";
-  const char* expected_substring = "Building, Tower";  // 3B PT 224, 10B Mix 224
+  // 3B PT/Mix 224, 10B Mix 224
+  const char* expected_substring = "Building, Tower";
   const Model model = s_env->GetGemma()->GetModelConfig().model;
   if (model == Model::PALIGEMMA2_3B_448) {
     expected_substring = "Lake.";
-  } else if (model == Model::PALIGEMMA2_3B_224) {
-    expected_substring = "Cloud, Water.";
   } else if (model == Model::PALIGEMMA2_10B_224) {
     expected_substring = "Building.";
   }
diff --git a/python/gemma_py.cc b/python/gemma_py.cc
index 23b9b99..c8f5192 100644
--- a/python/gemma_py.cc
+++ b/python/gemma_py.cc
@@ -190,7 +190,8 @@ class GemmaModel {
                                       gcpp::MatPadding::kOdd));
     gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
                                           .verbosity = 0};
-    gemma.GenerateImageTokens(runtime_config, c_image, *image_tokens_);
+    gemma.GenerateImageTokens(runtime_config, gemma_.MutableKVCache().SeqLen(),
+                              c_image, *image_tokens_);
   }
 
   // Generates a response to the given prompt, using the last set image.
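
Usage sketch (reviewer note, not part of the patch): a minimal sketch of how a
caller might drive the new AllQueries/QBatch API after this change. It assumes
a loaded gcpp::Gemma named `gemma` and already-tokenized `prompts`; the KVCache
construction follows the benchmark_helper.cc pattern above, and everything else
uses only types and calls introduced or shown in this diff.

    using namespace gcpp;

    // One KVCache per query; the AllQueries ctor asserts they all share the
    // same SeqLen().
    std::vector<KVCache> kv_caches;
    for (size_t i = 0; i < prompts.size(); ++i) {
      kv_caches.push_back(KVCache(gemma.GetModelConfig(), gemma.Inference()));
    }

    // Omitting prefix_end (empty span) requests causal attention; the initial
    // position of every query is zero (first turn).
    AllQueries all_queries(
        hwy::Span<const PromptTokens>(prompts.data(), prompts.size()),
        hwy::Span<KVCache>(kv_caches.data(), kv_caches.size()));

    RuntimeConfig runtime_config = {.verbosity = 0};
    // query_idx is absolute (QBatch::QueryIdx), not relative to the qbatch;
    // see the asserts in BatchQueryModel above. Returning false requests EOS
    // for that query during decode.
    runtime_config.batch_stream_token = [&](size_t query_idx, size_t pos,
                                            int token, float /*prob*/) {
      // ... detokenize and collect `token` for `query_idx` at `pos` ...
      return true;
    };

    TimingInfo timing_info = {.verbosity = 0};
    // GenerateBatch splits all_queries into batches of decode_qbatch_size.
    // Prefill and StreamAndUpdateEOS advance each PerQuery::mutable_pos, so a
    // later turn could reuse the same AllQueries for multi-turn chat.
    gemma.GenerateBatch(runtime_config, all_queries, timing_info);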