From c027a45a2e862e97e34eb464e7bcb429369e5929 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Wed, 11 Jun 2025 09:48:48 -0700
Subject: [PATCH] MatPtr-ify KV, shared div_seq_len, --seq_len flag

PiperOrigin-RevId: 770194455
---
 BUILD.bazel                         |   1 +
 DEVELOPERS.md                       |  14 +---
 README.md                           |   7 +-
 evals/benchmark.cc                  |   4 +-
 evals/benchmark_helper.cc           |   5 +-
 examples/hello_world/run.cc         |   2 +-
 examples/simplified_gemma/gemma.hpp |  12 +--
 gemma/activations.h                 |  15 ++--
 gemma/attention.cc                  |  87 ++++++++++-----------
 gemma/attention.h                   |  15 ++--
 gemma/bindings/context.cc           |  16 ++--
 gemma/bindings/context.h            |  25 ++----
 gemma/configs.cc                    |  39 +++++-----
 gemma/configs.h                     |   6 +-
 gemma/gemma.cc                      | 113 +++++++++++++---------------
 gemma/gemma.h                       |   2 +
 gemma/gemma_args.h                  |  11 ++-
 gemma/kv_cache.cc                   |  66 ++++++++--------
 gemma/kv_cache.h                    |  27 ++++---
 gemma/run.cc                        |   2 +-
 gemma/vit.cc                        |  14 ++--
 python/configs.cc                   |   2 +-
 22 files changed, 226 insertions(+), 259 deletions(-)

diff --git a/BUILD.bazel b/BUILD.bazel
index 82e8878..0e4df32 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -447,6 +447,7 @@ cc_library(
     hdrs = ["gemma/kv_cache.h"],
     deps = [
         ":configs",
+        ":gemma_args",
         ":mat",
         "@highway//:hwy",
     ],
diff --git a/DEVELOPERS.md b/DEVELOPERS.md
index a19bf9c..5d70fdb 100644
--- a/DEVELOPERS.md
+++ b/DEVELOPERS.md
@@ -101,18 +101,6 @@
 directly.
 
 For other models, `gemma_export_main.py` is not yet open sourced.
 
-## Compile-Time Flags (Advanced)
-
-There are several compile-time flags to be aware of (note these may or may not
-be exposed to the build system):
-
-- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
-  Cache. The default is 4096 tokens but can be overridden. This is not exposed
-  through `CMakeLists.txt` yet.
-
-In the medium term this will likely be deprecated in favor of handling options
-at runtime - dynamically resizing the KV cache as needed.
-
 ## Using gemma.cpp as a Library (Advanced)
 
 Unless you are doing lower level implementations or research, from an
@@ -165,7 +153,7 @@
 constrained decoding type of use cases where you want to force the generation
 to fit a grammar. If you're not doing this, you can send an empty lambda or
 `std::function` as a no-op which is what `run.cc` does.
 
-### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
+### `Transformer()` implements inference (i.e. `forward()` in PyTorch or Jax)
 
 For high-level applications, you might only call `model.Generate()` and never
 interact directly with the neural network, but if you're doing something a bit
diff --git a/README.md b/README.md
index 39ba907..363c173 100644
--- a/README.md
+++ b/README.md
@@ -322,9 +322,10 @@ model (any model with a `-pt` suffix).
 
 **What sequence lengths are supported?**
 
-See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
-typically 32K but 128K would also work given enough RAM. Note that long
-sequences will be slow due to the quadratic cost of attention.
+See `max_seq_len` in `configs.cc` and `InferenceArgs.seq_len`. For the Gemma 3
+models larger than 1B, this is typically 32K but 128K would also work given
+enough RAM. Note that long sequences will be slow due to the quadratic cost of
+attention.
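A note on the README paragraph above: the runtime knob it refers to is the new
`--seq_len` flag (`InferenceArgs::seq_len`, default 2048 per the gemma_args.h
hunk later in this patch), which the KVCache constructor clamps to
`ModelConfig::max_seq_len`. A minimal usage sketch; the `gemma` object and the
chosen value are hypothetical, the names come from this patch:

    gcpp::InferenceArgs inference;  // normally parsed from argv
    inference.seq_len = 8192;       // request an 8K context window
    // The KVCache ctor warns and caps at max_seq_len, then allocates
    // seq_len rows of layers * kv_heads * qkv_dim * 2 floats.
    gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);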
 
 **How do I convert my fine-tune to a `.sbs` compressed model file?**
 
diff --git a/evals/benchmark.cc b/evals/benchmark.cc
index c899642..aceee59 100644
--- a/evals/benchmark.cc
+++ b/evals/benchmark.cc
@@ -62,6 +62,7 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
 
 int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
                           size_t batch_tokens) {
+  const Gemma& gemma = *env.GetGemma();
   std::string input = ReadFileToString(text);
   std::vector<int> prompt = env.Tokenize(input);
   std::cout << "Number of input tokens: " << prompt.size() << "\n";
@@ -73,8 +74,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache(env.GetGemma()->GetModelConfig(),
-                     env.MutableConfig().prefill_tbatch_size);
+    KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
     float entropy = ComputeCrossEntropy(
         *env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;
diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index 6ebf930..8f5968b 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -52,7 +52,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
     : env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
   const ModelConfig& config = gemma_.GetModelConfig();
   // Only allocate one for starters because GenerateBatch might not be called.
-  kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
+  kv_caches_.push_back(KVCache(config, inference));
 
   if (inference.verbosity >= 2) {
     ShowConfig(loader, threading, inference, config);
@@ -135,8 +135,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
   // Ensure we have at least one KVCache per query.
   while (kv_caches_.size() < num_queries) {
-    kv_caches_.push_back(
-        KVCache(gemma_.GetModelConfig(), runtime_config_.prefill_tbatch_size));
+    kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
   }
 
   gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index faf3f42..931de8a 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -53,7 +53,7 @@ int main(int argc, char** argv) {
   // Instantiate model and KV Cache
   gcpp::MatMulEnv env(MakeMatMulEnv(threading));
   gcpp::Gemma gemma(loader, inference, env);
-  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
+  gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
   size_t generated = 0;
 
   // Initialize random number generator
diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp
index 2f6f5be..551cab4 100644
--- a/examples/simplified_gemma/gemma.hpp
+++ b/examples/simplified_gemma/gemma.hpp
@@ -35,12 +35,9 @@ class SimplifiedGemma {
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
                   const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
                   const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
-      : loader_(loader),
-        threading_(threading),
-        inference_(inference),
-        env_(MakeMatMulEnv(threading_)),
-        gemma_(loader_, inference_, env_),
-        kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
+      : env_(MakeMatMulEnv(threading)),
+        gemma_(loader, inference, env_),
+        kv_cache_(gemma_.GetModelConfig(), inference) {
     // Initialize random number generator
     std::random_device rd;
    gen_.seed(rd());
@@ -91,9 +88,6 @@ class SimplifiedGemma {
   ~SimplifiedGemma() = default;
 
  private:
-  gcpp::LoaderArgs loader_;
-  gcpp::ThreadingArgs threading_;
-  gcpp::InferenceArgs inference_;
   gcpp::MatMulEnv env_;
   gcpp::Gemma gemma_;
   gcpp::KVCache kv_cache_;
diff --git a/gemma/activations.h b/gemma/activations.h
index 3d07538..e19926a 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -46,8 +46,7 @@ struct Activations {
               std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : weights_config(config),
         layer_config(config.layer_configs[0]),
-        seq_len(config.seq_len),
-        cache_pos_size(config.CachePosSize()),
+        div_seq_len(static_cast<uint32_t>(config.max_seq_len)),
         is_griffin(config.model == Model::GRIFFIN_2B),
         query_scale(ChooseQueryScale(config)),
@@ -64,7 +63,9 @@ struct Activations {
         pre_att_rms_out("pre_att_rms_out",
                         Extents2D(batch_size, config.model_dim), pad_),
-        att("att", Extents2D(batch_size, layer_config.heads * config.seq_len),
-            pad_),
+        att("att",
+            Extents2D(batch_size,
+                      layer_config.heads * div_seq_len.GetDivisor()),
+            pad_),
         att_out(
             "att_out",
@@ -141,10 +142,14 @@ struct Activations {
     gen_tokens.resize(batch_size);
   }
 
+  bool IsGlobalLayer(size_t layer_idx) const {
+    return weights_config.attention_window_sizes[layer_idx] ==
+           div_seq_len.GetDivisor();
+  }
+
   const ModelConfig& weights_config;
   const LayerConfig& layer_config;
-  size_t seq_len;
-  size_t cache_pos_size = 0;  // TODO: after moving KVCache to MatStorageT.
+  hwy::Divisor div_seq_len;
   bool is_griffin;
   float query_scale;
   const Extents2D none_ = Extents2D();
diff --git a/gemma/attention.cc b/gemma/attention.cc
index f7f6ec1..55ac12b 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -70,9 +70,7 @@ static void PositionalEncodingQK(U* qk, const size_t qkv_dim,
   const PostQKType& post_qk = layer.layer_config.post_qk;
   // qk is either q or k, so qkv_dim is the length we operate on.
   const float* inv_timescale = activations.inv_timescale.PackedScale1();
-  bool is_global_layer =
-      activations.weights_config.attention_window_sizes[layer_idx] ==
-      activations.seq_len;
+  bool is_global_layer = activations.IsGlobalLayer(layer_idx);
   // TODO: add a config flag instead of hardcoding the model.
   if (is_global_layer && IsVLM(activations.weights_config.model)) {
     inv_timescale = activations.inv_timescale_global.PackedScale1();
@@ -116,13 +114,15 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
 // Calculates the attention outputs for a single q.
 void SingleDotSoftmaxWeightedSum(
     const size_t pos, const size_t start_pos, const size_t last_pos,
-    const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q,
-    const MatPtrT<float>& k, const MatPtrT<float>& v, const size_t layer_idx,
-    const LayerWeightsPtrs& layer, const Activations& activations,
-    float* HWY_RESTRICT att, float* HWY_RESTRICT att_out) {
+    float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
+    const size_t layer_idx, const LayerWeightsPtrs& layer,
+    const Activations& activations, float* HWY_RESTRICT att,
+    float* HWY_RESTRICT att_out) {
   const size_t qkv_dim = layer.layer_config.qkv_dim;
   const float att_cap = activations.weights_config.att_cap;
   const float query_scale = activations.query_scale;
+  const size_t seq_len =
+      static_cast<size_t>(activations.div_seq_len.GetDivisor());
 
   // Apply rope and scaling to Q.
   if (layer.query_norm_scale.HasPtr()) {
@@ -133,15 +133,14 @@ void SingleDotSoftmaxWeightedSum(
   PositionalEncodingQK(q, qkv_dim, layer_idx, layer, activations, pos,
                        query_scale);
 
-  QDotK(start_pos, last_pos, div_seq_len, q, k, att);
+  QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
 
   // SoftMax with optional SoftCap yields "probabilities" in att.
-  const size_t att_len =
-      HWY_MIN(last_pos + 1, static_cast<size_t>(div_seq_len.GetDivisor()));
+  const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
   MaybeLogitsSoftCap(att_cap, att, att_len);
   Softmax(att, att_len);
 
-  WeightedSumV(start_pos, last_pos, div_seq_len, att, v, att_out);
+  WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out);
 }
 
 // The attention window usually starts at 0 unless `pos` is larger than
@@ -152,11 +151,13 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
   return pos - HWY_MIN(att_window_size - 1, pos);
 }
 
-void DotSoftmaxWeightedSum(
-    const size_t num_tokens, const QueriesPos& queries_pos,
-    const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    Activations& activations, const KVCaches& kv_caches, NestedPools& pools) {
+void DotSoftmaxWeightedSum(const size_t num_tokens,
+                           const QueriesPos& queries_pos,
+                           const QueriesPos& queries_prefix_end,
+                           const size_t layer_idx,
+                           const LayerWeightsPtrs& layer,
+                           Activations& activations, const KVCaches& kv_caches,
+                           NestedPools& pools) {
   const size_t num_queries = queries_pos.size();
   const LayerConfig& layer_config = layer.layer_config;
   PROFILER_ZONE("Gen.Attention.DotSoftmax");
@@ -166,7 +167,8 @@ void DotSoftmaxWeightedSum(
 
   const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
   const size_t cache_layer_size = layer_config.CacheLayerSize();
-  const size_t cache_pos_size = activations.cache_pos_size;
+  const size_t seq_len =
+      static_cast<size_t>(activations.div_seq_len.GetDivisor());
 
   // For each head (token, query), compute Q.K, softmax, and weighted V.
   // TODO: nested parallelism to use more threads.
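The "Q.K, softmax, and weighted V" step that the next hunk rewires is easy to
lose in the view plumbing. A toy scalar version for orientation, using plain
arrays instead of cache views and omitting the RoPE, soft-cap, and ring-buffer
wraparound that the real SingleDotSoftmaxWeightedSum applies:

    #include <cmath>
    #include <cstddef>

    // att[i] = softmax_i(q . k_i); att_out[d] = sum_i att[i] * v[i][d].
    void ToyDotSoftmaxWeightedSum(const float* q, const float* k,
                                  const float* v, size_t num_pos,
                                  size_t qkv_dim, float* att, float* att_out) {
      float max = -1e30f;
      for (size_t i = 0; i < num_pos; ++i) {  // dot product per position
        float dot = 0.0f;
        for (size_t d = 0; d < qkv_dim; ++d) dot += q[d] * k[i * qkv_dim + d];
        att[i] = dot;
        if (dot > max) max = dot;
      }
      float sum = 0.0f;
      for (size_t i = 0; i < num_pos; ++i) {  // numerically stable softmax
        att[i] = std::exp(att[i] - max);
        sum += att[i];
      }
      for (size_t d = 0; d < qkv_dim; ++d) att_out[d] = 0.0f;
      for (size_t i = 0; i < num_pos; ++i) {  // weighted sum of V rows
        const float w = att[i] / sum;
        for (size_t d = 0; d < qkv_dim; ++d)
          att_out[d] += w * v[i * qkv_dim + d];
      }
    }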
@@ -183,21 +185,19 @@ void DotSoftmaxWeightedSum(
         float* HWY_RESTRICT q =
             activations.q.Row(interleaved_idx) + head * qkv_dim;
         float* HWY_RESTRICT att =
-            activations.att.Row(interleaved_idx) + head * activations.seq_len;
+            activations.att.Row(interleaved_idx) + head * seq_len;
         float* HWY_RESTRICT att_out =
             activations.att_out.Row(interleaved_idx) + head * qkv_dim;
 
         // Make strided views into the kv cache entries for the current
         // query and head.
-        KVCache& kv_cache = kv_caches[query_idx];
+        auto& kv_cache = kv_caches[query_idx].kv_cache;
         const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset;
-        MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
-        k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
-                 /*stride=*/cache_pos_size);
-        MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
-        v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
-                 /*stride=*/cache_pos_size);
+        MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
+        k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
+        MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
+        v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
 
         // Find the token position in the query and calculate the range
         // of cache positions to attend to.
@@ -211,16 +211,15 @@ void DotSoftmaxWeightedSum(
           last_pos = prefix_end - 1;
         }
 
-        SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, div_seq_len, q, k,
-                                    v, layer_idx, layer, activations, att,
+        SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
+                                    layer_idx, layer, activations, att,
                                     att_out);
       });
 }
 
 // Fills activations.q and writes to KV cache.
 static HWY_INLINE void ComputeQKV(
-    size_t num_tokens, const QueriesPos& queries_pos,
-    const hwy::Divisor& div_seq_len, const size_t layer_idx,
+    size_t num_tokens, const QueriesPos& queries_pos, const size_t layer_idx,
     const LayerWeightsPtrs& layer, Activations& activations,
     const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
   PROFILER_ZONE("Gen.Attention.QKV");
@@ -230,7 +229,6 @@ static HWY_INLINE void ComputeQKV(
   const size_t qkv_dim = layer_config.qkv_dim;
   const size_t kv_heads = layer_config.kv_heads;
   const size_t cache_layer_size = layer_config.CacheLayerSize();
-  const size_t cache_pos_size = activations.cache_pos_size;
 
   // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
   // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
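The k_view/v_view objects in the hunk above are the core of the
MatPtr-ification: they alias the per-query cache matrix instead of copying K
and V out of it. A sketch of the idea, assuming only that `SetPtr(ptr, stride)`
makes row r of the view read from `ptr + r * stride`, which is how this patch
uses it:

    // Row r of the cache holds [layer0: head0 K | head0 V | head1 K | ...].
    // A K view for one (layer, head) selects qkv_dim columns of every row:
    MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
    k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
    // Now k.Row(pos)[d] aliases kv_cache.Row(pos)[kv_head_offset + d];
    // the V view is identical except for an extra +qkv_dim column offset.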
@@ -247,11 +245,10 @@ static HWY_INLINE void ComputeQKV(
       const size_t query_idx = interleaved_idx % num_queries;
       const size_t batch_idx = interleaved_idx / num_queries;
       const size_t cache_pos =
-          div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
-      const size_t kv_offset =
-          cache_pos * cache_pos_size + layer_idx * cache_layer_size;
+          activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
       env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
-          kv_caches[query_idx].kv_cache.get() + kv_offset);
+          kv_caches[query_idx].kv_cache.Row(cache_pos) +
+          layer_idx * cache_layer_size);
     }
     kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
     CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
@@ -267,12 +264,11 @@ static HWY_INLINE void ComputeQKV(
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
         const size_t pos = queries_pos[query_idx] + batch_idx;
-        const size_t cache_pos = div_seq_len.Remainder(pos);
-        const size_t kv_offset = cache_pos * cache_pos_size +
+        const size_t cache_pos = activations.div_seq_len.Remainder(pos);
+        auto& kv_cache = kv_caches[query_idx].kv_cache;
+        float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
                                  layer_idx * cache_layer_size + head * qkv_dim * 2;
-        KVCache& kv_cache = kv_caches[query_idx];
-        float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
 
         // Apply further processing to K.
         if (layer.key_norm_scale.HasPtr()) {
@@ -309,9 +305,9 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
 // causal attention, and must be non-null for prefix-LM style attention.
 void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
                     const QueriesPos* queries_prefix_end,
-                    const hwy::Divisor& div_seq_len, const size_t layer_idx,
-                    const LayerWeightsPtrs& layer, Activations& activations,
-                    const KVCaches& kv_caches, MatMulEnv& env, int flags) {
+                    const size_t layer_idx, const LayerWeightsPtrs& layer,
+                    Activations& activations, const KVCaches& kv_caches,
+                    MatMulEnv& env, int flags) {
   const size_t num_queries = queries_pos.size();
   HWY_DASSERT(num_queries <= kv_caches.size());
 
@@ -330,11 +326,10 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
     queries_prefix_end = &queries_prefix_end_span;
   }
 
-  ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer,
-             activations, kv_caches, flags, env);
-  DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end,
-                        div_seq_len, layer_idx, layer, activations, kv_caches,
-                        env.ctx.pools);
+  ComputeQKV(num_tokens, queries_pos, layer_idx, layer, activations, kv_caches,
+             flags, env);
+  DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, layer_idx,
+                        layer, activations, kv_caches, env.ctx.pools);
   SumHeads(layer, activations, env);
 }
diff --git a/gemma/attention.h b/gemma/attention.h
index b43aeae..3a01df2 100644
--- a/gemma/attention.h
+++ b/gemma/attention.h
@@ -30,24 +30,23 @@ namespace gcpp {
   namespace NAMESPACE {                                                      \
                                                                              \
   void SingleDotSoftmaxWeightedSum(                                          \
       const size_t pos, const size_t start_pos, const size_t last_pos,       \
-      const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q,                \
-      const MatPtrT<float>& k, const MatPtrT<float>& v, size_t layer_idx,    \
-      const LayerWeightsPtrs& layer, const Activations& activations,         \
-      float* HWY_RESTRICT att, float* HWY_RESTRICT att_out);                 \
+      float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
+      size_t layer_idx, const LayerWeightsPtrs& layer,                       \
+      const Activations& activations, float* HWY_RESTRICT att,               \
+      float* HWY_RESTRICT att_out);                                          \
                                                                              \
   void DotSoftmaxWeightedSum(const size_t num_tokens,                        \
                              const QueriesPos& queries_pos,                  \
                              const QueriesPos& queries_prefix_end,           \
-                             const hwy::Divisor& div_seq_len,                \
                              size_t layer_idx, const LayerWeightsPtrs& layer, \
                              Activations& activations,                       \
                              const KVCaches& kv_caches, NestedPools& pools); \
                                                                              \
   void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,      \
                       const QueriesPos* queries_prefix_end,                  \
-                      const hwy::Divisor& div_seq_len, const size_t layer_idx, \
-                      const LayerWeightsPtrs& layer, Activations& activations, \
-                      const KVCaches& kv_caches, MatMulEnv& env, int flags); \
+                      const size_t layer_idx, const LayerWeightsPtrs& layer, \
+                      Activations& activations, const KVCaches& kv_caches,   \
+                      MatMulEnv& env, int flags);                            \
   /* NOLINTNEXTLINE(google-readability-namespace-comments) */                \
   }  // namespace NAMESPACE
diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc
index 47f4ad7..1fda34c 100644
--- a/gemma/bindings/context.cc
+++ b/gemma/bindings/context.cc
@@ -43,21 +43,15 @@ namespace gcpp {
 
 // ConversationData constructor implementation
 ConversationData::ConversationData(const ModelConfig& model_config,
-                                   size_t prefill_tbatch_size)
-    : model_config_ref_(model_config),
-      prefill_tbatch_size_(prefill_tbatch_size),
-      kv_cache(std::make_unique<KVCache>(model_config, prefill_tbatch_size)),
+                                   const InferenceArgs& inference_args)
+    : kv_cache(std::make_unique<KVCache>(model_config, inference_args)),
       abs_pos(0) {}
 
 // ConversationData copy constructor implementation
 ConversationData::ConversationData(const ConversationData& other)
-    : model_config_ref_(other.model_config_ref_),
-      prefill_tbatch_size_(other.prefill_tbatch_size_),
-      kv_cache(nullptr),
-      abs_pos(other.abs_pos) {
+    : kv_cache(nullptr), abs_pos(other.abs_pos) {
   if (other.kv_cache) {
-    kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy(
-        other.model_config_ref_, other.prefill_tbatch_size_));
+    kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy());
   }
 }
 
@@ -115,7 +109,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
   LogDebug("Creating initial ConversationData");
   // Create the initial ConversationData object using make_shared
   active_conversation = std::make_shared<ConversationData>(
-      model.GetModelConfig(), inference_args.prefill_tbatch_size);
+      model.GetModelConfig(), inference_args);
 
   LogDebug(
       "Storing initial ConversationData in conversation_cache[\"default\"]");
diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h
index 9b6fe94..c954da3 100644
--- a/gemma/bindings/context.h
+++ b/gemma/bindings/context.h
@@ -31,26 +31,19 @@
 
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
+#include "gemma/kv_cache.h"
 #include "ops/matmul.h"  // MatMulEnv
 #include "hwy/base.h"
 #include "hwy/highway.h"
 
 namespace gcpp {
 
-// Forward declaration - use 'struct' to match definition tag
-struct KVCache;
-
 // Struct to hold data for a single conversation thread
 struct ConversationData {
- public:
-  ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
+  ConversationData(const ModelConfig& model_config,
+                   const InferenceArgs& inference_args);
   ConversationData(const ConversationData& other);
 
- private:
-  const ModelConfig& model_config_ref_;
-  size_t prefill_tbatch_size_;
-
- public:
   std::unique_ptr<KVCache> kv_cache;
   size_t abs_pos = 0;
 };
@@ -142,8 +135,7 @@ class GemmaContext {
     log_msg += "' to prewarmed_cache.";
     LogDebug(log_msg.c_str());
 
-    // Create a deep copy of the active_conversation.
-    // The ConversationData copy constructor handles the deep copy of KVCache.
+    // Create a deep copy of the active_conversation via copy ctor.
     auto conversation_copy =
         std::make_shared<ConversationData>(*active_conversation);
 
@@ -176,8 +168,7 @@ class GemmaContext {
       active_conversation->abs_pos = it->second->abs_pos;
       // Perform a deep copy of the KVCache from the prewarmed version.
       active_conversation->kv_cache =
-          std::make_unique<KVCache>(it->second->kv_cache->Copy(
-              model.GetModelConfig(), inference_args.prefill_tbatch_size));
+          std::make_unique<KVCache>(it->second->kv_cache->Copy());
       LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
                    .c_str());
       return;
@@ -187,8 +178,8 @@ class GemmaContext {
       // rewind to initial state.
       active_conversation->abs_pos = 0;
       // Replace the cache within the current ConversationData object
-      active_conversation->kv_cache = std::make_unique<KVCache>(
-          model.GetModelConfig(), inference_args.prefill_tbatch_size);
+      active_conversation->kv_cache =
+          std::make_unique<KVCache>(model.GetModelConfig(), inference_args);
 
       LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
     } else {
@@ -206,7 +197,7 @@ class GemmaContext {
     LogDebug("Creating new conversation");
     // Create a new ConversationData object using make_shared
     conversation_cache[name] = std::make_shared<ConversationData>(
-        model.GetModelConfig(), inference_args.prefill_tbatch_size);
+        model.GetModelConfig(), inference_args);
     return true;
   }
diff --git a/gemma/configs.cc b/gemma/configs.cc
index 7e06e8c..0ab67b2 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -27,12 +27,8 @@
 
 namespace gcpp {
 
-// Allow changing pre-allocated kv cache size as a compiler flag
-#ifndef GEMMA_MAX_SEQLEN
-#define GEMMA_MAX_SEQLEN 4096
-#endif  // !GEMMA_MAX_SEQLEN
-
 static constexpr size_t kVocabSize = 256000;
+static constexpr size_t kMaxSeqLen = 4096;
 
 static ModelConfig ConfigNoSSM() {
   ModelConfig config;
@@ -69,7 +65,7 @@ static ModelConfig ConfigGemma2_27B() {
   config.model = Model::GEMMA2_27B;
   config.model_dim = 4608;
   config.vocab_size = kVocabSize;
-  config.seq_len = 8192;
+  config.max_seq_len = 8192;
   LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
   config.num_layers = 46;
   config.layer_configs = {config.num_layers, layer_config};
@@ -97,7 +93,7 @@ static ModelConfig ConfigGemma2_9B() {
   config.model = Model::GEMMA2_9B;
   config.model_dim = 3584;
   config.vocab_size = kVocabSize;
-  config.seq_len = 8192;
+  config.max_seq_len = 8192;
   LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
   config.num_layers = 42;
   config.layer_configs = {config.num_layers, layer_config};
@@ -125,7 +121,7 @@ static ModelConfig ConfigGemma2_2B() {
   config.model = Model::GEMMA2_2B;
   config.model_dim = 2304;
   config.vocab_size = kVocabSize;
-  config.seq_len = 8192;
+  config.max_seq_len = 8192;
   LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
   config.num_layers = 26;
   config.layer_configs = {config.num_layers, layer_config};
@@ -152,7 +148,7 @@ static ModelConfig ConfigGemmaTiny() {
   config.wrapping = PromptWrapping::GEMMA_IT;
   config.model_dim = 32;
   config.vocab_size = 32;  // at least two f32 vectors
-  config.seq_len = 32;
+  config.max_seq_len = 32;
   LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
   config.num_layers = 2;
   config.layer_configs = {config.num_layers, layer_config};
@@ -188,11 +184,11 @@ static ModelConfig ConfigGriffin2B() {
   ModelConfig config = ConfigNoSSM();
   config.display_name = "Griffin2B";
   config.model = Model::GRIFFIN_2B;
-  // Griffin uses local attention, so GEMMA_MAX_SEQLEN is actually the local
+  // Griffin uses local attention, so max_seq_len is actually the local
   // attention window.
   config.model_dim = 2560;
   config.vocab_size = kVocabSize;
-  config.seq_len = 2048;
+  config.max_seq_len = 2048;
   LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
   config.num_layers = 26;
   config.layer_configs = {config.num_layers, layer_config};
@@ -200,7 +196,8 @@ static ModelConfig ConfigGriffin2B() {
     config.layer_configs[i].type = LayerAttentionType::kGemma;
     config.layer_configs[i].griffin_dim = 0;
   }
-  config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
+  config.attention_window_sizes =
+      FixedAttentionWindowSizes<26>(config.max_seq_len);
   config.use_local_attention = true;
   config.final_cap = 0.0f;
   return config;
@@ -238,7 +235,7 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
 
 ModelConfig GetVitConfig(const ModelConfig& config) {
   ModelConfig vit_config = ConfigNoSSM();
   vit_config.model_dim = config.vit_config.model_dim;
-  vit_config.seq_len = config.vit_config.seq_len;
+  vit_config.max_seq_len = config.vit_config.seq_len;
   vit_config.layer_configs = config.vit_config.layer_configs;
   vit_config.pool_dim = config.vit_config.pool_dim;
   vit_config.wrapping = config.wrapping;
@@ -313,14 +310,14 @@ static ModelConfig ConfigGemma3_1B() {
   config.wrapping = PromptWrapping::GEMMA_VLM;
   config.model_dim = 1152;
   config.vocab_size = 262144;  // new vocab size / tokenizer
-  config.seq_len = 32 * 1024;
+  config.max_seq_len = 32 * 1024;
   LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
   config.num_layers = 26;
   config.layer_configs = {config.num_layers, layer_config};
   config.query_scale = QueryScaleType::SqrtKeySize;
   // interleaved local / global attention
   config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
-      {512, 512, 512, 512, 512, config.seq_len});
+      {512, 512, 512, 512, 512, config.max_seq_len});
   return config;
 }
 
@@ -345,14 +342,14 @@ static ModelConfig ConfigGemma3_4B_LM() {
   config.wrapping = PromptWrapping::GEMMA_VLM;
   config.model_dim = 2560;
   config.vocab_size = 262144;  // new vocab size / tokenizer
-  config.seq_len = 32 * 1024;
+  config.max_seq_len = 32 * 1024;
   LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
   config.num_layers = 34;
   config.layer_configs = {config.num_layers, layer_config};
   config.query_scale = QueryScaleType::SqrtKeySize;
   // interleaved local / global attention
   config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
-      {1024, 1024, 1024, 1024, 1024, config.seq_len});
+      {1024, 1024, 1024, 1024, 1024, config.max_seq_len});
   return config;
 }
 
@@ -394,14 +391,14 @@ static ModelConfig ConfigGemma3_12B_LM() {
   config.wrapping = PromptWrapping::GEMMA_VLM;
   config.model_dim = 3840;
   config.vocab_size = 262144;  // new vocab size / tokenizer
-  config.seq_len = 32 * 1024;
+  config.max_seq_len = 32 * 1024;
   LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
   config.num_layers = 48;
   config.layer_configs = {config.num_layers, layer_config};
   config.query_scale = QueryScaleType::SqrtKeySize;
   // interleaved local / global attention
   config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
-      {1024, 1024, 1024, 1024, 1024, config.seq_len});
+      {1024, 1024, 1024, 1024, 1024, config.max_seq_len});
   return config;
 }
 
@@ -443,14 +440,14 @@ static ModelConfig ConfigGemma3_27B_LM() {
   config.wrapping = PromptWrapping::GEMMA_VLM;
   config.model_dim = 5376;
   config.vocab_size = 262144;  // new vocab size / tokenizer
-  config.seq_len = 32 * 1024;
+  config.max_seq_len = 32 * 1024;
   LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
   config.num_layers = 62;
   config.layer_configs = {config.num_layers, layer_config};
   config.query_scale = QueryScaleType::SqrtKeySize;
   // interleaved local / global attention
   config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
-      {1024, 1024, 1024, 1024, 1024, config.seq_len});
+      {1024, 1024, 1024, 1024, 1024, config.max_seq_len});
   return config;
 }
diff --git a/gemma/configs.h b/gemma/configs.h
index e4521d5..00b39cb 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -347,7 +347,7 @@ struct ModelConfig : public IFields {
     visitor(num_layers);
     visitor(model_dim);
     visitor(vocab_size);
-    visitor(seq_len);
+    visitor(max_seq_len);
 
     visitor(unused_num_tensor_scales);
 
@@ -413,7 +413,7 @@ struct ModelConfig : public IFields {
     return num_heads;
   }
 
-  size_t CachePosSize() const {
+  size_t KVCacheCols() const {
     size_t num_layers = layer_configs.size();
     return num_layers * layer_configs[0].CacheLayerSize();
   }
@@ -435,7 +435,7 @@ struct ModelConfig : public IFields {
   uint32_t num_layers = 0;
   uint32_t model_dim = 0;
   uint32_t vocab_size = 0;
-  uint32_t seq_len = 0;
+  uint32_t max_seq_len = 0;
 
   // We no longer set nor use this: config_converter is not able to set this,
   // and only pre-2025 format stores scales, and we do not require advance
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index b72e31f..2bcf1fa 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -64,13 +64,12 @@ namespace HWY_NAMESPACE {
 
 void Attention(LayerAttentionType type, size_t num_tokens,
                const QueriesPos& queries_pos,
-               const QueriesPos& queries_prefix_end,
-               const hwy::Divisor& div_seq_len, const size_t layer_idx,
+               const QueriesPos& queries_prefix_end, const size_t layer_idx,
                const LayerWeightsPtrs& layer, Activations& activations,
                const KVCaches& kv_caches, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
-    GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len,
-                   layer_idx, layer, activations, kv_caches, env,
+    GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, layer_idx,
+                   layer, activations, kv_caches, env,
                    /*flags=*/0);
   } else {
     HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
@@ -85,16 +84,16 @@ void Attention(LayerAttentionType type, size_t num_tokens,
 
 static HWY_NOINLINE void TransformerLayer(
     const size_t num_tokens, const QueriesPos& queries_pos,
-    const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
-    const size_t layer_idx, const LayerWeightsPtrs& layer,
-    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
+    const QueriesPos& queries_prefix_end, const size_t layer_idx,
+    const LayerWeightsPtrs& layer, Activations& activations,
+    const KVCaches& kv_caches, MatMulEnv& env) {
   const LayerConfig& layer_config = layer.layer_config;
 
   RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
                  activations.pre_att_rms_out);
 
   Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
-            div_seq_len, layer_idx, layer, activations, kv_caches, env);
+            layer_idx, layer, activations, kv_caches, env);
 
   PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
            activations.att_sums);
@@ -190,10 +189,9 @@ using QueriesMutablePos = hwy::Span<size_t>;
 static HWY_NOINLINE void PrefillTBatch(
     const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
-    const hwy::Divisor& div_seq_len, const ModelConfig& config,
-    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
-    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
-    hwy::BitSet4096<>& non_eos) {
+    const ModelConfig& config, const RuntimeConfig& runtime_config,
+    const ModelWeightsPtrs& weights, Activations& activations,
+    const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
   PROFILER_ZONE("Gen.PrefillT");
   const size_t num_queries = queries_prompt.size();
   HWY_DASSERT(num_queries == queries_pos.size());
@@ -265,8 +263,8 @@ static HWY_NOINLINE void PrefillTBatch(
     for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
          ++layer_idx) {
       TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
-                       div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
-                       activations, single_kv_cache, env);
+                       layer_idx, *weights.GetLayer(layer_idx), activations,
+                       single_kv_cache, env);
     }
 
     // NOTE: we unconditionally call StreamToken, even if EOS.
@@ -303,10 +301,9 @@ static HWY_NOINLINE void PrefillTBatch(
 // token-batched `PrefillTBatch`.
 static HWY_NOINLINE void Transformer(
     const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
-    const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
-    const ModelConfig& config, const RuntimeConfig& runtime_config,
-    const ModelWeightsPtrs& weights, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env) {
+    const QueriesPos& queries_prefix_end, const ModelConfig& config,
+    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
+    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
   const size_t num_queries = queries_token.size();
   HWY_DASSERT(num_queries == queries_pos.size());
   HWY_DASSERT(num_queries == queries_prefix_end.size());
@@ -326,8 +323,8 @@ static HWY_NOINLINE void Transformer(
   for (size_t layer_idx = 0; layer_idx < weights.c_layers.size();
        ++layer_idx) {
     TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
-                     div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
-                     activations, kv_caches, env);
+                     layer_idx, *weights.GetLayer(layer_idx), activations,
+                     kv_caches, env);
 
     if (HWY_UNLIKELY(runtime_config.activations_observer)) {
       runtime_config.activations_observer(queries_pos, layer_idx, activations);
@@ -340,10 +337,10 @@ static HWY_NOINLINE void Transformer(
 static HWY_NOINLINE void PrefillQBatch(
     const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
-    const size_t max_prompt_size, const hwy::Divisor& div_seq_len,
-    const ModelConfig& config, const RuntimeConfig& runtime_config,
-    const ModelWeightsPtrs& weights, Activations& activations,
-    const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
+    const size_t max_prompt_size, const ModelConfig& config,
+    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
+    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
+    hwy::BitSet4096<>& non_eos) {
   PROFILER_ZONE("Gen.Prefill");
   const size_t num_queries = queries_prompt.size();
   HWY_DASSERT(num_queries == queries_pos.size());
@@ -380,8 +377,8 @@ static HWY_NOINLINE void PrefillQBatch(
       // Do not call DecodeStepT because it computes logits for token
       // probabilities, which are not required for the prompt tokens.
       Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
-                  queries_pos, queries_prefix_end, div_seq_len, config,
-                  runtime_config, weights, activations, kv_caches, env);
+                  queries_pos, queries_prefix_end, config, runtime_config,
+                  weights, activations, kv_caches, env);
 
     prefill_active.Foreach([&](size_t qi) {
       const int token = queries_prompt[qi][pos_in_prompt];
@@ -393,19 +390,6 @@ static HWY_NOINLINE void PrefillQBatch(
   }  // pos_in_prompt
 }
 
-// TODO: inline.
-void RangeChecks(const ModelConfig& weights_config,
-                 size_t& max_generated_tokens, const size_t prompt_size) {
-  if (!weights_config.use_local_attention) {
-    if (max_generated_tokens > weights_config.seq_len) {
-      HWY_WARN("max_generated_tokens %zu > kSeqLen %u, truncating.",
-               max_generated_tokens, weights_config.seq_len);
-      max_generated_tokens = weights_config.seq_len;
-    }
-  }
-  HWY_ASSERT(prompt_size > 0);
-}
-
 // Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
 // and updates `non_eos` if the query is at the end of its sequence.
 static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
@@ -432,17 +416,17 @@
 static void DecodeStepT(
     const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
     const QueriesMutablePos& queries_mutable_pos,
-    const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
-    const ModelConfig& config, const RuntimeConfig& runtime_config,
-    const ModelWeightsPtrs& weights, const SampleFunc& sample_token,
-    Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
-    hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
+    const QueriesPos& queries_prefix_end, const ModelConfig& config,
+    const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
+    const SampleFunc& sample_token, Activations& activations,
+    const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos,
+    TimingInfo& timing_info) {
   const size_t num_queries = queries_prompt.size();
   HWY_DASSERT(num_queries == activations.x.Rows());
 
   Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
-              queries_mutable_pos, queries_prefix_end, div_seq_len, config,
-              runtime_config, weights, activations, kv_caches, env);
+              queries_mutable_pos, queries_prefix_end, config, runtime_config,
+              weights, activations, kv_caches, env);
 
   RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
 
@@ -530,6 +514,7 @@ static void GenerateT(
   size_t max_prompt_size = 0;
   bool all_prefix_end_are_zero = true;
   size_t prefill_tokens = 0;
+  const size_t seq_len = kv_caches[0].SeqLen();
   for (size_t qi = 0; qi < num_queries; ++qi) {
     const PromptTokens& prompt = queries_prompt[qi];
     max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
@@ -542,9 +527,12 @@ static void GenerateT(
     HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
 
     all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
-  }
-
-  const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
+
+    // We use a single divisor, so all sequence lengths must be the same.
+    HWY_ASSERT(kv_caches[qi].SeqLen() == seq_len);
+  }
+  HWY_ASSERT(prefill_tokens < seq_len);
+  activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
 
   // Lacks a constructor to bulk-set, hence initialized by Prefill* which have
   // qi loops anyway.
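The seq_len checks added above, with concrete (hypothetical) numbers: all
per-query caches must report the same SeqLen() because Activations now carries
a single shared divisor, and (in the next hunk) generation is truncated so
prefill plus decode never exceeds seq_len:

    // seq_len = 4096 for every query's cache; the prompt prefills 4000.
    size_t seq_len = 4096, prefill_tokens = 4000, max_gen_steps = 200;
    // HWY_ASSERT(prefill_tokens < seq_len) holds: 4000 < 4096.
    if (prefill_tokens + max_gen_steps > seq_len) {
      max_gen_steps = seq_len - prefill_tokens;  // truncated to 96 steps
    }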
@@ -555,13 +543,12 @@ static void GenerateT(
 
   if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
     activations.SetBatchSize(num_queries);  // required before PrefillQBatch
     PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
-                  queries_prefix_end, max_prompt_size, div_seq_len, config,
-                  runtime_config, weights, activations, kv_caches, env,
-                  non_eos);
+                  queries_prefix_end, max_prompt_size, config, runtime_config,
+                  weights, activations, kv_caches, env, non_eos);
   } else {
     PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
-                  queries_prefix_end, div_seq_len, config, runtime_config,
-                  weights, activations, kv_caches, env, non_eos);
+                  queries_prefix_end, config, runtime_config, weights,
+                  activations, kv_caches, env, non_eos);
     activations.SetBatchSize(num_queries);  // Restore after PrefillTBatch.
   }
   HWY_DASSERT(num_queries == non_eos.Count());
@@ -579,7 +566,11 @@ static void GenerateT(
   }
 
   size_t max_gen_steps = runtime_config.max_generated_tokens;
-  RangeChecks(config, max_gen_steps, max_prompt_size);
+  if (prefill_tokens + max_gen_steps > seq_len) {
+    HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
+             prefill_tokens, max_gen_steps, seq_len);
+    max_gen_steps = seq_len - prefill_tokens;
+  }
 
   const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
 
@@ -587,8 +578,8 @@ static void GenerateT(
   timing_info.generate_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
     DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
-                queries_prefix_end, div_seq_len, config, runtime_config,
-                weights, sample_token, activations, kv_caches, env, non_eos,
+                queries_prefix_end, config, runtime_config, weights,
+                sample_token, activations, kv_caches, env, non_eos,
                 timing_info);
   }
   timing_info.NotifyGenerateDone();
@@ -661,10 +652,11 @@ void GenerateImageTokensT(const ModelConfig& config,
     HWY_ABORT("Model does not support generating image tokens.");
   }
   RuntimeConfig prefill_runtime_config = runtime_config;
-  ModelConfig vit_config = GetVitConfig(config);
+  const ModelConfig vit_config = GetVitConfig(config);
+  const size_t num_tokens = vit_config.max_seq_len;
   prefill_runtime_config.prefill_tbatch_size =
-      vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
-  Activations prefill_activations(vit_config, vit_config.seq_len, env.row_ptrs);
+      num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
+  Activations prefill_activations(vit_config, num_tokens, env.row_ptrs);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
              prefill_activations, env);
@@ -692,7 +684,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
       reader_(loader.weights),
       model_(reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config()),
-      chat_template_(model_.Tokenizer(), model_.Config().model) {
+      chat_template_(model_.Tokenizer(), model_.Config().model),
+      inference_(inference) {
   weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
                          env.ctx.pools.Pool());
   reader_.CloseFile();
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 21e9619..b866c41 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -117,6 +117,7 @@ class Gemma {
   const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
   const ModelWeightsPtrs& Weights() const { return weights_; }
   const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
+  const InferenceArgs& Inference() const { return inference_; }
 
   void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
 
@@ -159,6 +160,7 @@ class Gemma {
   std::vector<MatOwner> mat_owners_;
   ModelWeightsPtrs weights_;
   GemmaChatTemplate chat_template_;
+  InferenceArgs inference_;
 };
 
 }  // namespace gcpp
diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h
index b842d44..e4c8a33 100644
--- a/gemma/gemma_args.h
+++ b/gemma/gemma_args.h
@@ -35,11 +35,6 @@
 
 namespace gcpp {
 
-// Allow changing k parameter of `SampleTopK` as a compiler flag
-#ifndef GEMMA_TOPK
-#define GEMMA_TOPK 1
-#endif  // !GEMMA_TOPK
-
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
   LoaderArgs(const std::string& tokenizer_path,
@@ -115,6 +110,7 @@ using ActivationsObserverFunc = std::function<void(
    const QueriesPos& queries_pos, int layer, const Activations& activations)>;
 
 // RuntimeConfig holds configuration for a single generation run.
+// TODO: move into InferenceArgs, use that directly.
 struct RuntimeConfig {
   // If not empty, batch_stream_token is called for each token in the batch,
   // instead of stream_token.
@@ -137,7 +133,7 @@ struct RuntimeConfig {
 
   // Sampling-related parameters.
   float temperature;          // Temperature for sampling.
-  size_t top_k = GEMMA_TOPK;  // Top-k for sampling.
+  size_t top_k = 1;           // Top-k for sampling.
   std::mt19937* gen;          // Random number generator used for sampling.
 
   int verbosity;  // Controls verbosity of printed messages.
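With the GEMMA_TOPK compile-time knob gone, top-k is an ordinary runtime field
on RuntimeConfig. A hedged sketch of overriding it; the values are
hypothetical and the surrounding generation setup is omitted:

    std::mt19937 gen(42);
    gcpp::RuntimeConfig runtime_config;
    runtime_config.temperature = 0.7f;  // hypothetical sampling settings
    runtime_config.top_k = 40;          // previously fixed at GEMMA_TOPK (1)
    runtime_config.gen = &gen;          // RNG consumed by the sampler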
@@ -170,6 +166,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
 
   int verbosity;
 
+  size_t seq_len;
   size_t max_generated_tokens;
   size_t prefill_tbatch_size;
 
@@ -192,6 +189,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             "developer/debug info).\n    Default = 1.",
             1);  // Changed verbosity level to 1 since it's user-facing
 
+    visitor(seq_len, "seq_len", size_t{2048},
+            "Sequence length, capped by ModelConfig.max_seq_len.");
     visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
             "Maximum number of tokens to generate.");
 
diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc
index 49dc31c..a2a577f 100644
--- a/gemma/kv_cache.cc
+++ b/gemma/kv_cache.cc
@@ -15,21 +15,25 @@
 
 #include "gemma/kv_cache.h"
 
-#include <algorithm>  // std::copy
+#include <stddef.h>
 
 #include "gemma/configs.h"
+#include "gemma/gemma_args.h"
 #include "util/mat.h"  // ZeroInit
-#include "hwy/aligned_allocator.h"
-#include "hwy/base.h"  // ZeroBytes
+#include "hwy/base.h"  // HWY_MAX
 
 namespace gcpp {
 
 void KVCache::ZeroGriffinCache() {
-  if (griffin_layers == 0) return;
+  if (conv1d_cache.Rows() == 0) return;
   ZeroInit(conv1d_cache);
   ZeroInit(rglru_cache);
 }
 
+static size_t GriffinLayers(const ModelConfig& config) {
+  return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock);
+}
+
 static size_t GriffinConv1dCols(const ModelConfig& config) {
   size_t conv1d_width = 0;
   for (const auto& layer_config : config.layer_configs) {
@@ -40,43 +44,41 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
   return conv1d_width * config.model_dim;
 }
 
-// prefill_tbatch_size is the maximum number of tokens from one query to
-// prefill at a time.
-KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
-    : griffin_layers(
-          config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
-      conv1d_cache("conv1d_cache",
-                   Extents2D(griffin_layers, GriffinConv1dCols(config)),
-                   MatPadding::kOdd),
-      rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
-                  MatPadding::kOdd) {
-  // TODO: move to MatStorageT.
-  const size_t size_cache_pos = config.CachePosSize();
-  if (size_cache_pos != 0) {
-    // Allocate more so that prefill can always access one batch, even if
-    // near the end of the sequence.
-    seq_len = config.seq_len + prefill_tbatch_size;
-    kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-  }
-}
+// Number of rows for KV cache. Note that both rows and cols are u32, and
+// the total number of elements can exceed 2^32.
+static size_t CappedSeqLen(const ModelConfig& config,
+                           const InferenceArgs& inference_args) {
+  if (inference_args.seq_len > config.max_seq_len) {
+    HWY_WARN("Capping seq_len %zu to config.max_seq_len %u.",
+             inference_args.seq_len, config.max_seq_len);
+    return config.max_seq_len;
+  }
+  return inference_args.seq_len;
+}
 
-KVCache KVCache::Copy(const ModelConfig& weights_config,
-                      size_t prefill_tbatch_size) {
-  KVCache copy(weights_config, prefill_tbatch_size);
+KVCache::KVCache(const Extents2D& conv1d_extents,
+                 const Extents2D& rglru_extents, const Extents2D& kv_extents)
+    : conv1d_cache("conv1d_cache", conv1d_extents, MatPadding::kOdd),
+      rglru_cache("rglru_cache", rglru_extents, MatPadding::kOdd),
+      kv_cache("kv", kv_extents, MatPadding::kOdd) {}
 
-  const size_t size_cache_pos = weights_config.CachePosSize();
-  if (size_cache_pos != 0) {
-    std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
-              copy.kv_cache.get());
-  }
+KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args)
+    : KVCache(Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
+              Extents2D(GriffinLayers(config), config.model_dim),
+              Extents2D(CappedSeqLen(config, inference_args),
+                        config.KVCacheCols())) {}
 
-  if (conv1d_cache.HasPtr()) {
+KVCache KVCache::Copy() {
+  KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
+               kv_cache.Extents());
+
+  if (conv1d_cache.Rows() != 0) {
     CopyMat(conv1d_cache, copy.conv1d_cache);
-  }
-  if (rglru_cache.HasPtr()) {
     CopyMat(rglru_cache, copy.rglru_cache);
   }
+  CopyMat(kv_cache, copy.kv_cache);
+
   return copy;
 }
diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h
index 014e75d..8c8d762 100644
--- a/gemma/kv_cache.h
+++ b/gemma/kv_cache.h
@@ -19,29 +19,34 @@
 #include <stddef.h>
 
 #include "gemma/configs.h"  // ModelConfig
+#include "gemma/gemma_args.h"
 #include "util/mat.h"
-#include "hwy/aligned_allocator.h"
 
 namespace gcpp {
 
 struct KVCache {
-  KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size);
+  KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
 
-  // Returns a deep copy of the KVCache.
-  KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
+  // Returns a deep copy of the KVCache. Use explicit function instead of
+  // copy ctor to make the cost explicit.
+  KVCache Copy();
 
-  size_t griffin_layers = 0;
-  // griffin_layers, griffin_conv1d_cols * config.model_dim
-  MatStorageT<float> conv1d_cache;
-  MatStorageT<float> rglru_cache;  // griffin_layers, config.model_dim
   // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
   // and rglru_cache.
   void ZeroGriffinCache();
 
-  size_t seq_len = 0;  // = kSeqLen + prefill_tbatch_size
+  size_t SeqLen() const { return kv_cache.Rows(); }
 
-  // seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
-  hwy::AlignedFreeUniquePtr<float[]> kv_cache;
+  // [griffin_layers, griffin_conv1d_cols * model_dim]
+  MatStorageT<float> conv1d_cache;
+  MatStorageT<float> rglru_cache;  // [griffin_layers, model_dim]
+
+  MatStorageT<float> kv_cache;  // [seq_len, layers * kv_heads * qkv_dim * 2]
+
+ private:
+  // For use by other ctor and Copy()
+  KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
+          const Extents2D& kv_extents);
 };
 
 }  // namespace gcpp
diff --git a/gemma/run.cc b/gemma/run.cc
index 49432dd..dbdeb61 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -256,7 +256,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
   MatMulEnv env(MakeMatMulEnv(threading));
   if (inference.verbosity >= 2) env.print_best = true;
   const Gemma gemma(loader, inference, env);
-  KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
+  KVCache kv_cache(gemma.GetModelConfig(), inference);
 
   if (inference.verbosity >= 1) {
     std::string instructions =
diff --git a/gemma/vit.cc b/gemma/vit.cc
index 14d9fd9..fe17c3f 100644
--- a/gemma/vit.cc
+++ b/gemma/vit.cc
@@ -68,7 +68,8 @@ class VitAttention {
     const size_t qkv_dim = layer_config_.qkv_dim;
     const size_t heads = layer_config_.heads;
     HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
-    const size_t seq_len = activations_.seq_len;
+    const size_t seq_len =
+        static_cast<size_t>(activations_.div_seq_len.GetDivisor());
     const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
 
@@ -124,7 +125,8 @@ class VitAttention {
     const size_t qkv_dim = layer_config_.qkv_dim;
     const size_t heads = layer_config_.heads;
     HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
-    const size_t seq_len = activations_.seq_len;
+    const size_t seq_len =
+        static_cast<size_t>(activations_.div_seq_len.GetDivisor());
     const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
     PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
 
@@ -138,7 +140,7 @@ class VitAttention {
           activations_.q.Row(token) + head * 3 * qkv_dim;
       MulByConst(query_scale, q, qkv_dim);
       float* HWY_RESTRICT head_att =
-          activations_.att.Row(token) + head * activations_.seq_len;
+          activations_.att.Row(token) + head * seq_len;
       for (size_t i = 0; i < seq_len; ++i) {
         float* HWY_RESTRICT k =
             activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
@@ -275,7 +277,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
                                            MatMulEnv& env) {
   const size_t model_dim = model_config.vit_config.model_dim;
   const size_t patch_width = model_config.vit_config.patch_width;
-  const size_t seq_len = model_config.vit_config.seq_len;
+  const size_t num_tokens = model_config.vit_config.seq_len;
   const size_t patch_size = patch_width * patch_width * 3;
   HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
   HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
@@ -285,9 +287,9 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
   // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
   // image_patches is (256, 14 * 14 * 3)
   // Must be padded, see `DoDecompressA`.
-  MatStorageT<float> image_patches("patches", Extents2D(seq_len, patch_size),
+  MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
                                    MatPadding::kOdd);
-  for (size_t i = 0; i < seq_len; ++i) {
+  for (size_t i = 0; i < num_tokens; ++i) {
     image.GetPatch(i, image_patches.Row(i));
   }
   CallMatMul(image_patches, weights.vit_img_embedding_kernel,
diff --git a/python/configs.cc b/python/configs.cc
index 2fa6252..b9a4bf6 100644
--- a/python/configs.cc
+++ b/python/configs.cc
@@ -161,7 +161,7 @@ PYBIND11_MODULE(configs, py_module) {
       .def_readwrite("num_layers", &ModelConfig::num_layers)
       .def_readwrite("model_dim", &ModelConfig::model_dim)
      .def_readwrite("vocab_size", &ModelConfig::vocab_size)
-      .def_readwrite("seq_len", &ModelConfig::seq_len)
+      .def_readwrite("max_seq_len", &ModelConfig::max_seq_len)
       // Skip `unused_num_tensor_scales`.
       .def_readwrite("att_cap", &ModelConfig::att_cap)
       .def_readwrite("final_cap", &ModelConfig::final_cap)
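Taken together: after this patch the KV cache is a single
[seq_len, KVCacheCols()] float matrix per query, where KVCacheCols() is
layers * kv_heads * qkv_dim * 2, and token positions wrap into rows via the
shared divisor. A self-contained toy of that indexing scheme, using plain %
in place of hwy::Divisor::Remainder and a bare vector in place of MatStorageT:

    #include <cstdio>
    #include <vector>

    struct ToyKVCache {
      size_t seq_len, cols;     // cols = layers * kv_heads * qkv_dim * 2
      std::vector<float> data;  // row-major [seq_len, cols]
      ToyKVCache(size_t s, size_t c) : seq_len(s), cols(c), data(s * c) {}
      float* Row(size_t r) { return data.data() + r * cols; }
    };

    int main() {
      ToyKVCache kv(/*seq_len=*/8, /*cols=*/16);
      for (size_t pos = 0; pos < 20; ++pos) {
        const size_t cache_pos = pos % kv.seq_len;  // div_seq_len.Remainder
        kv.Row(cache_pos)[0] = static_cast<float>(pos);  // write K/V here
      }
      // Position 19 wrapped into row 3 (19 % 8), overwriting 3 and 11.
      std::printf("row 3 holds pos %.0f\n", kv.Row(3)[0]);
      return 0;
    }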