From 384c3901818ccbd0877a62d7b1bc4922a70aa668 Mon Sep 17 00:00:00 2001
From: Balazs Racz
Date: Thu, 8 Jan 2026 04:28:32 -0800
Subject: [PATCH] Allow overriding hardcoded max_seq_len by cmdline argument
 seq_len.

Adds a SetMaxSeqLen method to ModelConfig to handle updating both max_seq_len
and global attention window sizes. The Gemma constructor now checks if the
provided inference seq_len exceeds the model's max_seq_len and, if so, emits a
warning and updates the config. This prevents clipping context to the
hard-coded maximum.

PiperOrigin-RevId: 853676074
---
 gemma/configs.h      | 14 +++++++++++++-
 gemma/gemma.cc       | 17 +++++++++++++----
 gemma/model_store.cc |  2 +-
 gemma/model_store.h  |  7 +++++++
 4 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/gemma/configs.h b/gemma/configs.h
index b727480..f1bd0c5 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -428,6 +428,18 @@ struct ModelConfig : public IFields {
   // The third ctor also expects a string returned by this.
   std::string Specifier() const;
 
+  // Overwrites `max_seq_len` with `new_max_seq_len` and updates all global
+  // layers' attention window sizes to `new_max_seq_len`. This function must be
+  // called before instantiating the KVCache object.
+  void SetMaxSeqLen(size_t new_max_seq_len) {
+    for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
+      if (attention_window_sizes[i] == max_seq_len) {
+        attention_window_sizes[i] = new_max_seq_len;
+      }
+    }
+    max_seq_len = new_max_seq_len;
+  }
+
   void AddLayerConfig(const LayerConfig& layer_config) {
     layer_configs.push_back(layer_config);
     HWY_ASSERT(layer_configs.size() <= num_layers);
@@ -516,7 +528,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);
 
 enum DeducedLayerTypes {
   kDeducedViT = 2,
-  kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
+  kDeduced448 = 4,  // For ViT, 448x448 resolution instead of 224x224.
   kDeducedKqNorm = 8,
 };
 
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 0ce6ab3..5a48d00 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -18,6 +18,9 @@
 
 #include "gemma/gemma.h"
 
+#include 
+#include 
+#include 
 #include 
 
 #include "compression/types.h"  // GEMMA_DISABLED_TARGETS
@@ -556,10 +559,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
 }
 
 static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
-                               const RuntimeConfig& runtime_config,
-                               QBatch& qbatch,
-                               hwy::BitSet4096<>& non_eos,
-                               size_t qi) {
+                                           const RuntimeConfig& runtime_config,
+                                           QBatch& qbatch,
+                                           hwy::BitSet4096<>& non_eos,
+                                           size_t qi) {
   const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
   const size_t pos = qbatch.Pos(qi);  // during prefill, pos is still correct.
 
@@ -745,6 +748,12 @@ Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
      chat_template_(model_.Tokenizer(), model_.Config().model),
      inference_(args.inference),
      aes_ctr_engine_(args.inference.deterministic) {
+  if (args.inference.seq_len > model_.Config().max_seq_len) {
+    HWY_WARN(
+        "Overriding model's max_seq_len=%u with user provided seq_len=%zu.",
+        model_.Config().max_seq_len, args.inference.seq_len);
+    model_.MutableConfig().SetMaxSeqLen(args.inference.seq_len);
+  }
   // Negligible CPU time in the ctor body (except ReadFromBlobs).
   weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
                                              args.inference, mat_owners_, ctx);
diff --git a/gemma/model_store.cc b/gemma/model_store.cc
index 204dee9..76f0c75 100644
--- a/gemma/model_store.cc
+++ b/gemma/model_store.cc
@@ -387,7 +387,7 @@ void ModelStore::CreateMatPtrs(BlobReader& reader) {
 }
 
 ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
-                      Tristate wrapping)
+                       Tristate wrapping)
     : config_(ReadOrDeduceConfig(reader, wrapping)),
       tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
   if (!ReadMatPtrs(reader)) {  // Pre-2025 format.
diff --git a/gemma/model_store.h b/gemma/model_store.h
index b4d63ad..506fb77 100644
--- a/gemma/model_store.h
+++ b/gemma/model_store.h
@@ -39,6 +39,8 @@
 
 namespace gcpp {
 
+class Gemma;
+
 // Reads and holds the model config, tokenizer and all `MatPtr`: everything
 // except the tensor data, which are read/written by `weights.cc`.
 //
@@ -60,6 +62,11 @@ class ModelStore {
     return config_;
   }
 
+  ModelConfig& MutableConfig() {
+    HWY_ASSERT(config_.model != Model::UNKNOWN);
+    return config_;
+  }
+
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
 
   // Returns nullptr if `name` is not available for loading, otherwise the
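A minimal standalone sketch of the SetMaxSeqLen() semantics added in gemma/configs.h above. ToyConfig, its field values, and main() are hypothetical stand-ins for the real ModelConfig, used only to show that window entries equal to the old max_seq_len (global-attention layers) are widened, while smaller local sliding windows are left untouched.

// Standalone sketch (not part of the patch): ToyConfig is a hypothetical
// stand-in for ModelConfig, illustrating the SetMaxSeqLen() logic only.
#include <cstddef>
#include <cstdio>
#include <vector>

struct ToyConfig {
  size_t max_seq_len = 4096;
  // Hypothetical values: global layers use max_seq_len as their window,
  // local layers use a smaller sliding window that must not change.
  std::vector<size_t> attention_window_sizes = {4096, 1024, 4096, 1024};

  // Mirrors the logic added to ModelConfig in this patch.
  void SetMaxSeqLen(size_t new_max_seq_len) {
    for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
      if (attention_window_sizes[i] == max_seq_len) {
        attention_window_sizes[i] = new_max_seq_len;  // widen global layers
      }
    }
    max_seq_len = new_max_seq_len;
  }
};

int main() {
  ToyConfig config;
  config.SetMaxSeqLen(8192);  // e.g. the user requested a longer seq_len
  for (size_t w : config.attention_window_sizes) {
    std::printf("%zu ", w);  // prints: 8192 1024 8192 1024
  }
  std::printf("\nmax_seq_len=%zu\n", config.max_seq_len);
  return 0;
}

Note that the loop compares against the old max_seq_len before it is overwritten, so only windows that were exactly the model-wide maximum are widened; the assignment order matters. The requirement to call SetMaxSeqLen() before instantiating the KVCache presumably exists because the cache sizes its per-layer storage from these window values.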