mirror of https://github.com/google/gemma.cpp.git
Allow overriding the hardcoded max_seq_len via the seq_len command-line argument.
Adds a SetMaxSeqLen method to ModelConfig that updates both max_seq_len and the global layers' attention window sizes. The Gemma constructor now checks whether the provided inference seq_len exceeds the model's max_seq_len and, if so, emits a warning and updates the config. This prevents clipping the context to the hard-coded maximum.

PiperOrigin-RevId: 853676074
This commit is contained in:
parent
aeade052c6
commit
384c390181
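
The rule for which layers get widened is implicit in the description: a layer whose attention window equals the model's max_seq_len is a global-attention layer, and only those windows follow the override. Below is a minimal standalone sketch of that rule; ToyConfig is a simplified stand-in for illustration, not the real ModelConfig.

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Simplified stand-in for ModelConfig, with only the fields SetMaxSeqLen touches.
    struct ToyConfig {
      size_t max_seq_len = 4096;
      // Per-layer windows: layers whose window equals max_seq_len attend globally.
      std::vector<size_t> attention_window_sizes = {1024, 4096, 1024, 4096};

      void SetMaxSeqLen(size_t new_max_seq_len) {
        for (size_t& w : attention_window_sizes) {
          if (w == max_seq_len) w = new_max_seq_len;  // widen only the global layers
        }
        max_seq_len = new_max_seq_len;
      }
    };

    int main() {
      ToyConfig cfg;
      cfg.SetMaxSeqLen(8192);  // e.g. the user asked for a longer seq_len
      for (size_t w : cfg.attention_window_sizes) std::printf("%zu ", w);
      std::printf("\nmax_seq_len=%zu\n", cfg.max_seq_len);
      // Prints "1024 8192 1024 8192": local (sliding-window) layers keep their size.
      return 0;
    }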
@@ -428,6 +428,18 @@ struct ModelConfig : public IFields {
  // The third ctor also expects a string returned by this.
  std::string Specifier() const;

  // Overwrites `max_seq_len` with `new_max_seq_len` and updates all global
  // layers' attention window sizes to `new_max_seq_len`. This function must be
  // called before instantiating the KVCache object.
  void SetMaxSeqLen(size_t new_max_seq_len) {
    for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
      if (attention_window_sizes[i] == max_seq_len) {
        attention_window_sizes[i] = new_max_seq_len;
      }
    }
    max_seq_len = new_max_seq_len;
  }

  void AddLayerConfig(const LayerConfig& layer_config) {
    layer_configs.push_back(layer_config);
    HWY_ASSERT(layer_configs.size() <= num_layers);
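The comment in the hunk above pins down call order: anything sized from `max_seq_len`, notably the KV cache, must be created after the override, otherwise it is allocated for the old, shorter context. A simplified illustration of why follows; ToyCache is a hypothetical stand-in, not the real KVCache.

    #include <cstddef>
    #include <vector>

    struct ToyConfig {
      size_t max_seq_len = 4096;
    };

    // Hypothetical cache whose capacity is fixed at construction time from the
    // config, mirroring how a KV cache is sized by the sequence length.
    struct ToyCache {
      explicit ToyCache(const ToyConfig& config) : slots(config.max_seq_len) {}
      std::vector<float> slots;
    };

    int main() {
      ToyConfig config;
      config.max_seq_len = 8192;  // stands in for SetMaxSeqLen(8192)
      ToyCache cache(config);     // capacity now covers all 8192 positions
      // Constructing ToyCache before the override would have capped it at 4096.
      return cache.slots.size() == 8192 ? 0 : 1;
    }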
@@ -516,7 +528,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);

enum DeducedLayerTypes {
  kDeducedViT = 2,
  kDeduced448 = 4,  // For ViT, 448x448 resolution instead of 224x224.
  kDeducedKqNorm = 8,
};
@@ -18,6 +18,9 @@

#include "gemma/gemma.h"

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <optional>

#include "compression/types.h"  // GEMMA_DISABLED_TARGETS
@@ -556,10 +559,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
}

static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
                                           const RuntimeConfig& runtime_config,
                                           QBatch& qbatch,
                                           hwy::BitSet4096<>& non_eos,
                                           size_t qi) {
  const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);

  const size_t pos = qbatch.Pos(qi);  // during prefill, pos is still correct.
@@ -745,6 +748,12 @@ Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
      chat_template_(model_.Tokenizer(), model_.Config().model),
      inference_(args.inference),
      aes_ctr_engine_(args.inference.deterministic) {
  if (args.inference.seq_len > model_.Config().max_seq_len) {
    HWY_WARN(
        "Overriding model's max_seq_len=%u with user provided seq_len=%zu.",
        model_.Config().max_seq_len, args.inference.seq_len);
    model_.MutableConfig().SetMaxSeqLen(args.inference.seq_len);
  }
  // Negligible CPU time in the ctor body (except ReadFromBlobs).
  weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
                                             args.inference, mat_owners_, ctx);
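Note that the constructor only ever grows the limit: a seq_len at or below the model's max_seq_len leaves the config untouched, so attention windows are never shrunk. A toy version of that guard; MaybeOverrideSeqLen and ToyConfig are illustrative and not part of gemma.cpp.

    #include <cstddef>
    #include <cstdio>

    struct ToyConfig {
      size_t max_seq_len = 4096;
    };

    // Widen only when the request exceeds the model's built-in maximum; smaller
    // requests are served within the existing windows.
    void MaybeOverrideSeqLen(ToyConfig& config, size_t requested_seq_len) {
      if (requested_seq_len > config.max_seq_len) {
        std::fprintf(stderr, "Overriding max_seq_len=%zu with seq_len=%zu.\n",
                     config.max_seq_len, requested_seq_len);
        config.max_seq_len = requested_seq_len;
      }
    }

    int main() {
      ToyConfig config;
      MaybeOverrideSeqLen(config, 2048);   // no-op: the request fits the model
      MaybeOverrideSeqLen(config, 16384);  // warns and widens to 16384
      std::printf("max_seq_len=%zu\n", config.max_seq_len);
      return 0;
    }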
@@ -387,7 +387,7 @@ void ModelStore::CreateMatPtrs(BlobReader& reader) {
}

ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
                       Tristate wrapping)
    : config_(ReadOrDeduceConfig(reader, wrapping)),
      tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
  if (!ReadMatPtrs(reader)) {  // Pre-2025 format.
@@ -39,6 +39,8 @@

namespace gcpp {

class Gemma;

// Reads and holds the model config, tokenizer and all `MatPtr`: everything
// except the tensor data, which are read/written by `weights.cc`.
//
@@ -60,6 +62,11 @@ class ModelStore {
    return config_;
  }

  ModelConfig& MutableConfig() {
    HWY_ASSERT(config_.model != Model::UNKNOWN);
    return config_;
  }

  const GemmaTokenizer& Tokenizer() const { return tokenizer_; }

  // Returns nullptr if `name` is not available for loading, otherwise the