mirror of https://github.com/google/gemma.cpp.git
Allow overriding hardcoded max_seq_len by cmdline argument seq_len.
Adds a SetMaxSeqLen method to ModelConfig to handle updating both max_seq_len and global attention window sizes. The Gemma constructor now checks if the provided inference seq_len exceeds the model's max_seq_len and, if so, emits a warning and updates the config. This prevents clipping context to the hard-coded maximum. PiperOrigin-RevId: 853676074
This commit is contained in:
parent
aeade052c6
commit
384c390181
|
|
@ -428,6 +428,18 @@ struct ModelConfig : public IFields {
|
||||||
// The third ctor also expects a string returned by this.
|
// The third ctor also expects a string returned by this.
|
||||||
std::string Specifier() const;
|
std::string Specifier() const;
|
||||||
|
|
||||||
|
// Overwrites `max_seq_len` with `new_max_seq_len` and updates all global
|
||||||
|
// layers' attention window sizes to `new_max_seq_len`. This function must be
|
||||||
|
// called before instantiating the KVCache object.
|
||||||
|
void SetMaxSeqLen(size_t new_max_seq_len) {
|
||||||
|
for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
|
||||||
|
if (attention_window_sizes[i] == max_seq_len) {
|
||||||
|
attention_window_sizes[i] = new_max_seq_len;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
max_seq_len = new_max_seq_len;
|
||||||
|
}
|
||||||
|
|
||||||
void AddLayerConfig(const LayerConfig& layer_config) {
|
void AddLayerConfig(const LayerConfig& layer_config) {
|
||||||
layer_configs.push_back(layer_config);
|
layer_configs.push_back(layer_config);
|
||||||
HWY_ASSERT(layer_configs.size() <= num_layers);
|
HWY_ASSERT(layer_configs.size() <= num_layers);
|
||||||
|
|
@ -516,7 +528,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);
|
||||||
|
|
||||||
enum DeducedLayerTypes {
|
enum DeducedLayerTypes {
|
||||||
kDeducedViT = 2,
|
kDeducedViT = 2,
|
||||||
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
|
kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
|
||||||
kDeducedKqNorm = 8,
|
kDeducedKqNorm = 8,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,9 @@
|
||||||
|
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
||||||
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
|
||||||
|
|
@ -556,10 +559,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
|
static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
|
||||||
const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config,
|
||||||
QBatch& qbatch,
|
QBatch& qbatch,
|
||||||
hwy::BitSet4096<>& non_eos,
|
hwy::BitSet4096<>& non_eos,
|
||||||
size_t qi) {
|
size_t qi) {
|
||||||
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);
|
||||||
|
|
||||||
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
|
const size_t pos = qbatch.Pos(qi); // during prefill, pos is still correct.
|
||||||
|
|
@ -745,6 +748,12 @@ Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
|
||||||
chat_template_(model_.Tokenizer(), model_.Config().model),
|
chat_template_(model_.Tokenizer(), model_.Config().model),
|
||||||
inference_(args.inference),
|
inference_(args.inference),
|
||||||
aes_ctr_engine_(args.inference.deterministic) {
|
aes_ctr_engine_(args.inference.deterministic) {
|
||||||
|
if (args.inference.seq_len > model_.Config().max_seq_len) {
|
||||||
|
HWY_WARN(
|
||||||
|
"Overriding model's max_seq_len=%u with user provided seq_len=%zu.",
|
||||||
|
model_.Config().max_seq_len, args.inference.seq_len);
|
||||||
|
model_.MutableConfig().SetMaxSeqLen(args.inference.seq_len);
|
||||||
|
}
|
||||||
// Negligible CPU time in the ctor body (except ReadFromBlobs).
|
// Negligible CPU time in the ctor body (except ReadFromBlobs).
|
||||||
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
|
weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
|
||||||
args.inference, mat_owners_, ctx);
|
args.inference, mat_owners_, ctx);
|
||||||
|
|
|
||||||
|
|
@ -387,7 +387,7 @@ void ModelStore::CreateMatPtrs(BlobReader& reader) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
|
ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
|
||||||
Tristate wrapping)
|
Tristate wrapping)
|
||||||
: config_(ReadOrDeduceConfig(reader, wrapping)),
|
: config_(ReadOrDeduceConfig(reader, wrapping)),
|
||||||
tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
|
tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
|
||||||
if (!ReadMatPtrs(reader)) { // Pre-2025 format.
|
if (!ReadMatPtrs(reader)) { // Pre-2025 format.
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,8 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
class Gemma;
|
||||||
|
|
||||||
// Reads and holds the model config, tokenizer and all `MatPtr`: everything
|
// Reads and holds the model config, tokenizer and all `MatPtr`: everything
|
||||||
// except the tensor data, which are read/written by `weights.cc`.
|
// except the tensor data, which are read/written by `weights.cc`.
|
||||||
//
|
//
|
||||||
|
|
@ -60,6 +62,11 @@ class ModelStore {
|
||||||
return config_;
|
return config_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ModelConfig& MutableConfig() {
|
||||||
|
HWY_ASSERT(config_.model != Model::UNKNOWN);
|
||||||
|
return config_;
|
||||||
|
}
|
||||||
|
|
||||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||||
|
|
||||||
// Returns nullptr if `name` is not available for loading, otherwise the
|
// Returns nullptr if `name` is not available for loading, otherwise the
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue