Allow overriding the hardcoded max_seq_len via the command-line argument seq_len.

Adds a SetMaxSeqLen method to ModelConfig to handle updating both max_seq_len and global attention window sizes. The Gemma constructor now checks if the provided inference seq_len exceeds the model's max_seq_len and, if so, emits a warning and updates the config.

This prevents clipping context to the hard-coded maximum.

PiperOrigin-RevId: 853676074
This commit is contained in:
Balazs Racz 2026-01-08 04:28:32 -08:00 committed by Copybara-Service
parent aeade052c6
commit 384c390181
4 changed files with 34 additions and 6 deletions

View File

@ -428,6 +428,18 @@ struct ModelConfig : public IFields {
// The third ctor also expects a string returned by this.
std::string Specifier() const;
// Replaces `max_seq_len` with `new_max_seq_len`. Any global attention layer
// whose window size matches the old `max_seq_len` is rewritten in lockstep,
// so those layers keep tracking the model-wide maximum. Local (sliding-
// window) layers are left untouched. Must run before the KVCache is built,
// because the cache is sized from these values.
void SetMaxSeqLen(size_t new_max_seq_len) {
  for (auto& window : attention_window_sizes) {
    if (window == max_seq_len) window = new_max_seq_len;
  }
  max_seq_len = new_max_seq_len;
}
void AddLayerConfig(const LayerConfig& layer_config) {
layer_configs.push_back(layer_config);
HWY_ASSERT(layer_configs.size() <= num_layers);
@ -516,7 +528,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);
enum DeducedLayerTypes {
kDeducedViT = 2,
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
kDeducedKqNorm = 8,
};

View File

@ -18,6 +18,9 @@
#include "gemma/gemma.h"
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <optional>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
@ -556,10 +559,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
}
static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
const RuntimeConfig& runtime_config,
QBatch& qbatch,
hwy::BitSet4096<>& non_eos,
size_t qi) {
const RuntimeConfig& runtime_config,
QBatch& qbatch,
hwy::BitSet4096<>& non_eos,
size_t qi) {
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
@ -745,6 +748,12 @@ Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(args.inference),
aes_ctr_engine_(args.inference.deterministic) {
if (args.inference.seq_len > model_.Config().max_seq_len) {
HWY_WARN(
"Overriding model's max_seq_len=%u with user provided seq_len=%zu.",
model_.Config().max_seq_len, args.inference.seq_len);
model_.MutableConfig().SetMaxSeqLen(args.inference.seq_len);
}
// Negligible CPU time in the ctor body (except ReadFromBlobs).
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
args.inference, mat_owners_, ctx);

View File

@ -387,7 +387,7 @@ void ModelStore::CreateMatPtrs(BlobReader& reader) {
}
ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
Tristate wrapping)
Tristate wrapping)
: config_(ReadOrDeduceConfig(reader, wrapping)),
tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
if (!ReadMatPtrs(reader)) { // Pre-2025 format.

View File

@ -39,6 +39,8 @@
namespace gcpp {
class Gemma;
// Reads and holds the model config, tokenizer and all `MatPtr`: everything
// except the tensor data, which are read/written by `weights.cc`.
//
@ -60,6 +62,11 @@ class ModelStore {
return config_;
}
// Returns a mutable reference to the model config, e.g. so callers can
// override `max_seq_len` after loading (see ModelConfig::SetMaxSeqLen).
ModelConfig& MutableConfig() {
  // Mutating an unresolved config would be meaningless; the config must
  // already have been read or deduced before anyone may change it.
  HWY_ASSERT(config_.model != Model::UNKNOWN);
  return config_;
}
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
// Returns nullptr if `name` is not available for loading, otherwise the