MatPtr-ify KV, shared div_seq_len, --seq_len flag

PiperOrigin-RevId: 770194455
Jan Wassenberg 2025-06-11 09:48:48 -07:00 committed by Copybara-Service
parent bd98b43cea
commit c027a45a2e
22 changed files with 226 additions and 259 deletions

View File

@ -447,6 +447,7 @@ cc_library(
hdrs = ["gemma/kv_cache.h"],
deps = [
":configs",
":gemma_args",
":mat",
"@highway//:hwy",
],

View File

@ -101,18 +101,6 @@ directly.
For other models, `gemma_export_main.py` is not yet open sourced.
## Compile-Time Flags (Advanced)
There are several compile-time flags to be aware of (note these may or may not
be exposed to the build system):
- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
Cache. The default is 4096 tokens but can be overridden. This is not exposed
through `CMakeLists.txt` yet.
In the medium term this will likely be deprecated in favor of handling options
at runtime - dynamically resizing the KV cache as needed.
## Using gemma.cpp as a Library (Advanced)
Unless you are doing lower level implementations or research, from an
@ -165,7 +153,7 @@ constrained decoding type of use cases where you want to force the generation to
fit a grammar. If you're not doing this, you can send an empty lambda or
`std::function` as a no-op which is what `run.cc` does.
### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
### `Transformer()` implements inference (i.e. `forward()` in PyTorch or Jax)
For high-level applications, you might only call `model.Generate()` and never
interact directly with the neural network, but if you're doing something a bit

View File

@ -322,9 +322,10 @@ model (any model with a `-pt` suffix).
**What sequence lengths are supported?**
See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
typically 32K but 128K would also work given enough RAM. Note that long
sequences will be slow due to the quadratic cost of attention.
See `max_seq_len` in `configs.cc` and `InferenceArgs.seq_len`. For the Gemma 3
models larger than 1B, this is typically 32K but 128K would also work given
enough RAM. Note that long sequences will be slow due to the quadratic cost of
attention.
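A minimal sketch of how the two values interact (mirroring `CappedSeqLen` in `kv_cache.cc` later in this change; the helper name here is hypothetical): the KV cache is allocated with the runtime `seq_len`, capped by the model's `max_seq_len`.

```cpp
#include <algorithm>
#include <cstddef>

#include "gemma/configs.h"     // ModelConfig
#include "gemma/gemma_args.h"  // InferenceArgs

// Hypothetical helper, not part of the library: number of KV cache rows
// (one per token position) actually allocated.
size_t EffectiveSeqLen(const gcpp::ModelConfig& config,
                       const gcpp::InferenceArgs& inference) {
  return std::min<size_t>(inference.seq_len, config.max_seq_len);
}
```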
**How do I convert my fine-tune to a `.sbs` compressed model file?**

View File

@ -62,6 +62,7 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t batch_tokens) {
const Gemma& gemma = *env.GetGemma();
std::string input = ReadFileToString(text);
std::vector<int> prompt = env.Tokenize(input);
std::cout << "Number of input tokens: " << prompt.size() << "\n";
@ -73,8 +74,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
KVCache kv_cache(env.GetGemma()->GetModelConfig(),
env.MutableConfig().prefill_tbatch_size);
KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
float entropy = ComputeCrossEntropy(
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy;

View File

@ -52,7 +52,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
: env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
const ModelConfig& config = gemma_.GetModelConfig();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
kv_caches_.push_back(KVCache(config, inference));
if (inference.verbosity >= 2) {
ShowConfig(loader, threading, inference, config);
@ -135,8 +135,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
// Ensure we have at least one KVCache per query.
while (kv_caches_.size() < num_queries) {
kv_caches_.push_back(
KVCache(gemma_.GetModelConfig(), runtime_config_.prefill_tbatch_size));
kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
}
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};

View File

@ -53,7 +53,7 @@ int main(int argc, char** argv) {
// Instantiate model and KV Cache
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
gcpp::Gemma gemma(loader, inference, env);
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
size_t generated = 0;
// Initialize random number generator

View File

@ -35,12 +35,9 @@ class SimplifiedGemma {
SimplifiedGemma(const gcpp::LoaderArgs& loader,
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
: loader_(loader),
threading_(threading),
inference_(inference),
env_(MakeMatMulEnv(threading_)),
gemma_(loader_, inference_, env_),
kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
: env_(MakeMatMulEnv(threading)),
gemma_(loader, inference, env_),
kv_cache_(gemma_.GetModelConfig(), inference) {
// Initialize random number generator
std::random_device rd;
gen_.seed(rd());
@ -91,9 +88,6 @@ class SimplifiedGemma {
~SimplifiedGemma() = default;
private:
gcpp::LoaderArgs loader_;
gcpp::ThreadingArgs threading_;
gcpp::InferenceArgs inference_;
gcpp::MatMulEnv env_;
gcpp::Gemma gemma_;
gcpp::KVCache kv_cache_;

View File

@ -46,8 +46,7 @@ struct Activations {
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: weights_config(config),
layer_config(config.layer_configs[0]),
seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()),
div_seq_len(static_cast<uint32_t>(config.max_seq_len)),
is_griffin(config.model == Model::GRIFFIN_2B),
query_scale(ChooseQueryScale(config)),
@ -64,7 +63,9 @@ struct Activations {
pre_att_rms_out("pre_att_rms_out",
Extents2D(batch_size, config.model_dim), pad_),
att("att", Extents2D(batch_size, layer_config.heads * config.seq_len),
att("att",
Extents2D(batch_size,
layer_config.heads * div_seq_len.GetDivisor()),
pad_),
att_out(
"att_out",
@ -141,10 +142,14 @@ struct Activations {
gen_tokens.resize(batch_size);
}
bool IsGlobalLayer(size_t layer_idx) const {
return weights_config.attention_window_sizes[layer_idx] ==
div_seq_len.GetDivisor();
}
const ModelConfig& weights_config;
const LayerConfig& layer_config;
size_t seq_len;
size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
hwy::Divisor div_seq_len;
bool is_griffin;
float query_scale;
const Extents2D none_ = Extents2D();
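A minimal sketch of how the shared `div_seq_len` member is used (only `hwy::Divisor` from Highway is assumed, whose `Remainder()` and `GetDivisor()` already appear in this change): attention wraps absolute token positions into the circular KV cache, and a layer is "global" when its attention window equals the full sequence length.

```cpp
#include <cstddef>
#include <cstdint>

#include "hwy/base.h"  // hwy::Divisor

// Hypothetical free-function versions of the member logic above, for
// illustration only.
size_t CachePos(const hwy::Divisor& div_seq_len, size_t pos) {
  // E.g. with a divisor of 4096, position 5000 maps to cache row 904.
  return div_seq_len.Remainder(static_cast<uint32_t>(pos));
}

bool IsGlobal(uint32_t attention_window_size, const hwy::Divisor& div_seq_len) {
  return attention_window_size == div_seq_len.GetDivisor();
}
```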

View File

@ -70,9 +70,7 @@ static void PositionalEncodingQK(U* qk, const size_t qkv_dim,
const PostQKType& post_qk = layer.layer_config.post_qk;
// qk is either q or k, so qkv_dim is the length we operate on.
const float* inv_timescale = activations.inv_timescale.PackedScale1();
bool is_global_layer =
activations.weights_config.attention_window_sizes[layer_idx] ==
activations.seq_len;
bool is_global_layer = activations.IsGlobalLayer(layer_idx);
// TODO: add a config flag instead of hardcoding the model.
if (is_global_layer && IsVLM(activations.weights_config.model)) {
inv_timescale = activations.inv_timescale_global.PackedScale1();
@ -116,13 +114,15 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
// Calculates the attention outputs for a single q.
void SingleDotSoftmaxWeightedSum(
const size_t pos, const size_t start_pos, const size_t last_pos,
const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q,
const MatPtrT<float>& k, const MatPtrT<float>& v, const size_t layer_idx,
const LayerWeightsPtrs& layer, const Activations& activations,
float* HWY_RESTRICT att, float* HWY_RESTRICT att_out) {
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
const size_t layer_idx, const LayerWeightsPtrs& layer,
const Activations& activations, float* HWY_RESTRICT att,
float* HWY_RESTRICT att_out) {
const size_t qkv_dim = layer.layer_config.qkv_dim;
const float att_cap = activations.weights_config.att_cap;
const float query_scale = activations.query_scale;
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
// Apply rope and scaling to Q.
if (layer.query_norm_scale.HasPtr()) {
@ -133,15 +133,14 @@ void SingleDotSoftmaxWeightedSum(
PositionalEncodingQK(q, qkv_dim, layer_idx, layer, activations, pos,
query_scale);
QDotK(start_pos, last_pos, div_seq_len, q, k, att);
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
// SoftMax with optional SoftCap yields "probabilities" in att.
const size_t att_len =
HWY_MIN(last_pos + 1, static_cast<size_t>(div_seq_len.GetDivisor()));
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
MaybeLogitsSoftCap(att_cap, att, att_len);
Softmax(att, att_len);
WeightedSumV(start_pos, last_pos, div_seq_len, att, v, att_out);
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out);
}
// The attention window usually starts at 0 unless `pos` is larger than
@ -152,11 +151,13 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
return pos - HWY_MIN(att_window_size - 1, pos);
}
void DotSoftmaxWeightedSum(
const size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches, NestedPools& pools) {
void DotSoftmaxWeightedSum(const size_t num_tokens,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const size_t layer_idx,
const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches,
NestedPools& pools) {
const size_t num_queries = queries_pos.size();
const LayerConfig& layer_config = layer.layer_config;
PROFILER_ZONE("Gen.Attention.DotSoftmax");
@ -166,7 +167,8 @@ void DotSoftmaxWeightedSum(
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t cache_pos_size = activations.cache_pos_size;
const size_t seq_len =
static_cast<size_t>(activations.div_seq_len.GetDivisor());
// For each head (token, query), compute Q.K, softmax, and weighted V.
// TODO: nested parallelism to use more threads.
@ -183,21 +185,19 @@ void DotSoftmaxWeightedSum(
float* HWY_RESTRICT q =
activations.q.Row(interleaved_idx) + head * qkv_dim;
float* HWY_RESTRICT att =
activations.att.Row(interleaved_idx) + head * activations.seq_len;
activations.att.Row(interleaved_idx) + head * seq_len;
float* HWY_RESTRICT att_out =
activations.att_out.Row(interleaved_idx) + head * qkv_dim;
// Make strided views into the kv cache entries for the current
// query and head.
KVCache& kv_cache = kv_caches[query_idx];
auto& kv_cache = kv_caches[query_idx].kv_cache;
const size_t kv_head_offset =
layer_idx * cache_layer_size + head_offset;
MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
/*stride=*/cache_pos_size);
MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
/*stride=*/cache_pos_size);
MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
// Find the token position in the query and calculate the range
// of cache positions to attend to.
@ -211,16 +211,15 @@ void DotSoftmaxWeightedSum(
last_pos = prefix_end - 1;
}
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, div_seq_len, q, k,
v, layer_idx, layer, activations, att,
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
layer_idx, layer, activations, att,
att_out);
});
}
// Fills activations.q and writes to KV cache.
static HWY_INLINE void ComputeQKV(
size_t num_tokens, const QueriesPos& queries_pos,
const hwy::Divisor& div_seq_len, const size_t layer_idx,
size_t num_tokens, const QueriesPos& queries_pos, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
PROFILER_ZONE("Gen.Attention.QKV");
@ -230,7 +229,6 @@ static HWY_INLINE void ComputeQKV(
const size_t qkv_dim = layer_config.qkv_dim;
const size_t kv_heads = layer_config.kv_heads;
const size_t cache_layer_size = layer_config.CacheLayerSize();
const size_t cache_pos_size = activations.cache_pos_size;
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
@ -247,11 +245,10 @@ static HWY_INLINE void ComputeQKV(
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t cache_pos =
div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
const size_t kv_offset =
cache_pos * cache_pos_size + layer_idx * cache_layer_size;
activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
kv_caches[query_idx].kv_cache.get() + kv_offset);
kv_caches[query_idx].kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size);
}
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
@ -267,12 +264,11 @@ static HWY_INLINE void ComputeQKV(
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
const size_t pos = queries_pos[query_idx] + batch_idx;
const size_t cache_pos = div_seq_len.Remainder(pos);
const size_t kv_offset = cache_pos * cache_pos_size +
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
auto& kv_cache = kv_caches[query_idx].kv_cache;
float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
KVCache& kv_cache = kv_caches[query_idx];
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
// Apply further processing to K.
if (layer.key_norm_scale.HasPtr()) {
@ -309,9 +305,9 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
// causal attention, and must be non-null for prefix-LM style attention.
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos* queries_prefix_end,
const hwy::Divisor& div_seq_len, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env, int flags) {
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches,
MatMulEnv& env, int flags) {
const size_t num_queries = queries_pos.size();
HWY_DASSERT(num_queries <= kv_caches.size());
@ -330,11 +326,10 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
queries_prefix_end = &queries_prefix_end_span;
}
ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer,
activations, kv_caches, flags, env);
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end,
div_seq_len, layer_idx, layer, activations, kv_caches,
env.ctx.pools);
ComputeQKV(num_tokens, queries_pos, layer_idx, layer, activations, kv_caches,
flags, env);
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, layer_idx,
layer, activations, kv_caches, env.ctx.pools);
SumHeads(layer, activations, env);
}
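The strided K/V views above rely on the kv_cache row layout (one row per cache position, all layers and KV heads concatenated). A sketch of the assumed offset arithmetic, using a hypothetical helper name:

```cpp
#include <cstddef>

// Hypothetical helper, not part of this change: offset (in floats) of one KV
// head within a kv_cache row, assuming the row layout
// [layer][kv_head][K(qkv_dim) | V(qkv_dim)].
size_t KVHeadOffset(size_t layer_idx, size_t cache_layer_size, size_t kv_head,
                    size_t qkv_dim) {
  return layer_idx * cache_layer_size + kv_head * qkv_dim * 2;
}
// K for this head starts at kv_cache.Row(pos) + KVHeadOffset(...); V follows
// qkv_dim floats later. Views over all positions use stride =
// kv_cache.Stride(), i.e. one full row per cache position.
```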

View File

@ -30,24 +30,23 @@ namespace gcpp {
namespace NAMESPACE { \
void SingleDotSoftmaxWeightedSum( \
const size_t pos, const size_t start_pos, const size_t last_pos, \
const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q, \
const MatPtrT<float>& k, const MatPtrT<float>& v, size_t layer_idx, \
const LayerWeightsPtrs& layer, const Activations& activations, \
float* HWY_RESTRICT att, float* HWY_RESTRICT att_out); \
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
const Activations& activations, float* HWY_RESTRICT att, \
float* HWY_RESTRICT att_out); \
\
void DotSoftmaxWeightedSum(const size_t num_tokens, \
const QueriesPos& queries_pos, \
const QueriesPos& queries_prefix_end, \
const hwy::Divisor& div_seq_len, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
Activations& activations, \
const KVCaches& kv_caches, NestedPools& pools); \
\
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, \
const QueriesPos* queries_prefix_end, \
const hwy::Divisor& div_seq_len, const size_t layer_idx, \
const LayerWeightsPtrs& layer, Activations& activations, \
const KVCaches& kv_caches, MatMulEnv& env, int flags); \
const size_t layer_idx, const LayerWeightsPtrs& layer, \
Activations& activations, const KVCaches& kv_caches, \
MatMulEnv& env, int flags); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

View File

@ -43,21 +43,15 @@ namespace gcpp {
// ConversationData constructor implementation
ConversationData::ConversationData(const ModelConfig& model_config,
size_t prefill_tbatch_size)
: model_config_ref_(model_config),
prefill_tbatch_size_(prefill_tbatch_size),
kv_cache(std::make_unique<KVCache>(model_config, prefill_tbatch_size)),
const InferenceArgs& inference_args)
: kv_cache(std::make_unique<KVCache>(model_config, inference_args)),
abs_pos(0) {}
// ConversationData copy constructor implementation
ConversationData::ConversationData(const ConversationData& other)
: model_config_ref_(other.model_config_ref_),
prefill_tbatch_size_(other.prefill_tbatch_size_),
kv_cache(nullptr),
abs_pos(other.abs_pos) {
: kv_cache(nullptr), abs_pos(other.abs_pos) {
if (other.kv_cache) {
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy(
other.model_config_ref_, other.prefill_tbatch_size_));
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy());
}
}
@ -115,7 +109,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
LogDebug("Creating initial ConversationData");
// Create the initial ConversationData object using make_shared
active_conversation = std::make_shared<ConversationData>(
model.GetModelConfig(), inference_args.prefill_tbatch_size);
model.GetModelConfig(), inference_args);
LogDebug(
"Storing initial ConversationData in conversation_cache[\"default\"]");

View File

@ -31,26 +31,19 @@
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "ops/matmul.h" // MatMulEnv
#include "hwy/base.h"
#include "hwy/highway.h"
namespace gcpp {
// Forward declaration - use 'struct' to match definition tag
struct KVCache;
// Struct to hold data for a single conversation thread
struct ConversationData {
public:
ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
ConversationData(const ModelConfig& model_config,
const InferenceArgs& inference_args);
ConversationData(const ConversationData& other);
private:
const ModelConfig& model_config_ref_;
size_t prefill_tbatch_size_;
public:
std::unique_ptr<KVCache> kv_cache;
size_t abs_pos = 0;
};
@ -142,8 +135,7 @@ class GemmaContext {
log_msg += "' to prewarmed_cache.";
LogDebug(log_msg.c_str());
// Create a deep copy of the active_conversation.
// The ConversationData copy constructor handles the deep copy of KVCache.
// Create a deep copy of the active_conversation via copy ctor.
auto conversation_copy =
std::make_shared<ConversationData>(*active_conversation);
@ -176,8 +168,7 @@ class GemmaContext {
active_conversation->abs_pos = it->second->abs_pos;
// Perform a deep copy of the KVCache from the prewarmed version.
active_conversation->kv_cache =
std::make_unique<KVCache>(it->second->kv_cache->Copy(
model.GetModelConfig(), inference_args.prefill_tbatch_size));
std::make_unique<KVCache>(it->second->kv_cache->Copy());
LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
.c_str());
return;
@ -187,8 +178,8 @@ class GemmaContext {
// rewind to initial state.
active_conversation->abs_pos = 0;
// Replace the cache within the current ConversationData object
active_conversation->kv_cache = std::make_unique<KVCache>(
model.GetModelConfig(), inference_args.prefill_tbatch_size);
active_conversation->kv_cache =
std::make_unique<KVCache>(model.GetModelConfig(), inference_args);
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
} else {
@ -206,7 +197,7 @@ class GemmaContext {
LogDebug("Creating new conversation");
// Create a new ConversationData object using make_shared
conversation_cache[name] = std::make_shared<ConversationData>(
model.GetModelConfig(), inference_args.prefill_tbatch_size);
model.GetModelConfig(), inference_args);
return true;
}

View File

@ -27,12 +27,8 @@
namespace gcpp {
// Allow changing pre-allocated kv cache size as a compiler flag
#ifndef GEMMA_MAX_SEQLEN
#define GEMMA_MAX_SEQLEN 4096
#endif // !GEMMA_MAX_SEQLEN
static constexpr size_t kVocabSize = 256000;
static constexpr size_t kMaxSeqLen = 4096;
static ModelConfig ConfigNoSSM() {
ModelConfig config;
@ -69,7 +65,7 @@ static ModelConfig ConfigGemma2_27B() {
config.model = Model::GEMMA2_27B;
config.model_dim = 4608;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
config.max_seq_len = 8192;
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
config.num_layers = 46;
config.layer_configs = {config.num_layers, layer_config};
@ -97,7 +93,7 @@ static ModelConfig ConfigGemma2_9B() {
config.model = Model::GEMMA2_9B;
config.model_dim = 3584;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
config.max_seq_len = 8192;
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
config.num_layers = 42;
config.layer_configs = {config.num_layers, layer_config};
@ -125,7 +121,7 @@ static ModelConfig ConfigGemma2_2B() {
config.model = Model::GEMMA2_2B;
config.model_dim = 2304;
config.vocab_size = kVocabSize;
config.seq_len = 8192;
config.max_seq_len = 8192;
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
config.num_layers = 26;
config.layer_configs = {config.num_layers, layer_config};
@ -152,7 +148,7 @@ static ModelConfig ConfigGemmaTiny() {
config.wrapping = PromptWrapping::GEMMA_IT;
config.model_dim = 32;
config.vocab_size = 32; // at least two f32 vectors
config.seq_len = 32;
config.max_seq_len = 32;
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
config.num_layers = 2;
config.layer_configs = {config.num_layers, layer_config};
@ -188,11 +184,11 @@ static ModelConfig ConfigGriffin2B() {
ModelConfig config = ConfigNoSSM();
config.display_name = "Griffin2B";
config.model = Model::GRIFFIN_2B;
// Griffin uses local attention, so GEMMA_MAX_SEQLEN is actually the local
// Griffin uses local attention, so max_seq_len is actually the local
// attention window.
config.model_dim = 2560;
config.vocab_size = kVocabSize;
config.seq_len = 2048;
config.max_seq_len = 2048;
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
config.num_layers = 26;
config.layer_configs = {config.num_layers, layer_config};
@ -200,7 +196,8 @@ static ModelConfig ConfigGriffin2B() {
config.layer_configs[i].type = LayerAttentionType::kGemma;
config.layer_configs[i].griffin_dim = 0;
}
config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
config.attention_window_sizes =
FixedAttentionWindowSizes<26>(config.max_seq_len);
config.use_local_attention = true;
config.final_cap = 0.0f;
return config;
@ -238,7 +235,7 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
ModelConfig GetVitConfig(const ModelConfig& config) {
ModelConfig vit_config = ConfigNoSSM();
vit_config.model_dim = config.vit_config.model_dim;
vit_config.seq_len = config.vit_config.seq_len;
vit_config.max_seq_len = config.vit_config.seq_len;
vit_config.layer_configs = config.vit_config.layer_configs;
vit_config.pool_dim = config.vit_config.pool_dim;
vit_config.wrapping = config.wrapping;
@ -313,14 +310,14 @@ static ModelConfig ConfigGemma3_1B() {
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 1152;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
config.num_layers = 26;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
{512, 512, 512, 512, 512, config.seq_len});
{512, 512, 512, 512, 512, config.max_seq_len});
return config;
}
@ -345,14 +342,14 @@ static ModelConfig ConfigGemma3_4B_LM() {
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 2560;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
config.num_layers = 34;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
{1024, 1024, 1024, 1024, 1024, config.seq_len});
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
return config;
}
@ -394,14 +391,14 @@ static ModelConfig ConfigGemma3_12B_LM() {
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 3840;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
config.num_layers = 48;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
{1024, 1024, 1024, 1024, 1024, config.seq_len});
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
return config;
}
@ -443,14 +440,14 @@ static ModelConfig ConfigGemma3_27B_LM() {
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 5376;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
config.num_layers = 62;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
{1024, 1024, 1024, 1024, 1024, config.seq_len});
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
return config;
}

View File

@ -347,7 +347,7 @@ struct ModelConfig : public IFields {
visitor(num_layers);
visitor(model_dim);
visitor(vocab_size);
visitor(seq_len);
visitor(max_seq_len);
visitor(unused_num_tensor_scales);
@ -413,7 +413,7 @@ struct ModelConfig : public IFields {
return num_heads;
}
size_t CachePosSize() const {
size_t KVCacheCols() const {
size_t num_layers = layer_configs.size();
return num_layers * layer_configs[0].CacheLayerSize();
}
@ -435,7 +435,7 @@ struct ModelConfig : public IFields {
uint32_t num_layers = 0;
uint32_t model_dim = 0;
uint32_t vocab_size = 0;
uint32_t seq_len = 0;
uint32_t max_seq_len = 0;
// We no longer set nor use this: config_converter is not able to set this,
// and only pre-2025 format stores scales, and we do not require advance

View File

@ -64,13 +64,12 @@ namespace HWY_NAMESPACE {
void Attention(LayerAttentionType type, size_t num_tokens,
const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end,
const hwy::Divisor& div_seq_len, const size_t layer_idx,
const QueriesPos& queries_prefix_end, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env) {
if (type == LayerAttentionType::kGemma) {
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len,
layer_idx, layer, activations, kv_caches, env,
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, layer_idx,
layer, activations, kv_caches, env,
/*flags=*/0);
} else {
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
@ -85,16 +84,16 @@ void Attention(LayerAttentionType type, size_t num_tokens,
static HWY_NOINLINE void TransformerLayer(
const size_t num_tokens, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
const size_t layer_idx, const LayerWeightsPtrs& layer,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
const QueriesPos& queries_prefix_end, const size_t layer_idx,
const LayerWeightsPtrs& layer, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env) {
const LayerConfig& layer_config = layer.layer_config;
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
activations.pre_att_rms_out);
Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
div_seq_len, layer_idx, layer, activations, kv_caches, env);
layer_idx, layer, activations, kv_caches, env);
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
activations.att_sums);
@ -190,10 +189,9 @@ using QueriesMutablePos = hwy::Span<size_t>;
static HWY_NOINLINE void PrefillTBatch(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const hwy::Divisor& div_seq_len, const ModelConfig& config,
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
hwy::BitSet4096<>& non_eos) {
const ModelConfig& config, const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
PROFILER_ZONE("Gen.PrefillT");
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(num_queries == queries_pos.size());
@ -265,8 +263,8 @@ static HWY_NOINLINE void PrefillTBatch(
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
++layer_idx) {
TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
activations, single_kv_cache, env);
layer_idx, *weights.GetLayer(layer_idx), activations,
single_kv_cache, env);
}
// NOTE: we unconditionally call StreamToken, even if EOS.
@ -303,10 +301,9 @@ static HWY_NOINLINE void PrefillTBatch(
// token-batched `PrefillTBatch`.
static HWY_NOINLINE void Transformer(
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
const ModelConfig& config, const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env) {
const QueriesPos& queries_prefix_end, const ModelConfig& config,
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
const size_t num_queries = queries_token.size();
HWY_DASSERT(num_queries == queries_pos.size());
HWY_DASSERT(num_queries == queries_prefix_end.size());
@ -326,8 +323,8 @@ static HWY_NOINLINE void Transformer(
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
activations, kv_caches, env);
layer_idx, *weights.GetLayer(layer_idx), activations,
kv_caches, env);
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
runtime_config.activations_observer(queries_pos, layer_idx, activations);
@ -340,10 +337,10 @@ static HWY_NOINLINE void Transformer(
static HWY_NOINLINE void PrefillQBatch(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
const size_t max_prompt_size, const hwy::Divisor& div_seq_len,
const ModelConfig& config, const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
const size_t max_prompt_size, const ModelConfig& config,
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
hwy::BitSet4096<>& non_eos) {
PROFILER_ZONE("Gen.Prefill");
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(num_queries == queries_pos.size());
@ -380,8 +377,8 @@ static HWY_NOINLINE void PrefillQBatch(
// Do not call DecodeStepT because it computes logits for token
// probabilities, which are not required for the prompt tokens.
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
queries_pos, queries_prefix_end, div_seq_len, config,
runtime_config, weights, activations, kv_caches, env);
queries_pos, queries_prefix_end, config, runtime_config,
weights, activations, kv_caches, env);
prefill_active.Foreach([&](size_t qi) {
const int token = queries_prompt[qi][pos_in_prompt];
@ -393,19 +390,6 @@ static HWY_NOINLINE void PrefillQBatch(
} // pos_in_prompt
}
// TODO: inline.
void RangeChecks(const ModelConfig& weights_config,
size_t& max_generated_tokens, const size_t prompt_size) {
if (!weights_config.use_local_attention) {
if (max_generated_tokens > weights_config.seq_len) {
HWY_WARN("max_generated_tokens %zu > kSeqLen %u, truncating.",
max_generated_tokens, weights_config.seq_len);
max_generated_tokens = weights_config.seq_len;
}
}
HWY_ASSERT(prompt_size > 0);
}
// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
// and updates `non_eos` if the query is at the end of its sequence.
static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
@ -432,17 +416,17 @@ static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
static void DecodeStepT(
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
const QueriesMutablePos& queries_mutable_pos,
const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
const ModelConfig& config, const RuntimeConfig& runtime_config,
const ModelWeightsPtrs& weights, const SampleFunc& sample_token,
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
const QueriesPos& queries_prefix_end, const ModelConfig& config,
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
const SampleFunc& sample_token, Activations& activations,
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos,
TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_DASSERT(num_queries == activations.x.Rows());
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
queries_mutable_pos, queries_prefix_end, div_seq_len, config,
runtime_config, weights, activations, kv_caches, env);
queries_mutable_pos, queries_prefix_end, config, runtime_config,
weights, activations, kv_caches, env);
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
@ -530,6 +514,7 @@ static void GenerateT(
size_t max_prompt_size = 0;
bool all_prefix_end_are_zero = true;
size_t prefill_tokens = 0;
const size_t seq_len = kv_caches[0].SeqLen();
for (size_t qi = 0; qi < num_queries; ++qi) {
const PromptTokens& prompt = queries_prompt[qi];
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
@ -542,9 +527,12 @@ static void GenerateT(
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
}
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
// We use a single divisor, so all sequence lengths must be the same.
HWY_ASSERT(kv_caches[qi].SeqLen() == seq_len);
}
HWY_ASSERT(prefill_tokens < seq_len);
activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
// qi loops anyway.
@ -555,13 +543,12 @@ static void GenerateT(
if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
activations.SetBatchSize(num_queries); // required before PrefillQBatch
PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, max_prompt_size, div_seq_len, config,
runtime_config, weights, activations, kv_caches, env,
non_eos);
queries_prefix_end, max_prompt_size, config, runtime_config,
weights, activations, kv_caches, env, non_eos);
} else {
PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, div_seq_len, config, runtime_config,
weights, activations, kv_caches, env, non_eos);
queries_prefix_end, config, runtime_config, weights,
activations, kv_caches, env, non_eos);
activations.SetBatchSize(num_queries); // Restore after PrefillTBatch.
}
HWY_DASSERT(num_queries == non_eos.Count());
@ -579,7 +566,11 @@ static void GenerateT(
}
size_t max_gen_steps = runtime_config.max_generated_tokens;
RangeChecks(config, max_gen_steps, max_prompt_size);
if (prefill_tokens + max_gen_steps > seq_len) {
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
prefill_tokens, max_gen_steps, seq_len);
max_gen_steps = seq_len - prefill_tokens;
}
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
@ -587,8 +578,8 @@ static void GenerateT(
timing_info.generate_start = hwy::platform::Now();
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
queries_prefix_end, div_seq_len, config, runtime_config,
weights, sample_token, activations, kv_caches, env, non_eos,
queries_prefix_end, config, runtime_config, weights,
sample_token, activations, kv_caches, env, non_eos,
timing_info);
}
timing_info.NotifyGenerateDone();
@ -661,10 +652,11 @@ void GenerateImageTokensT(const ModelConfig& config,
HWY_ABORT("Model does not support generating image tokens.");
}
RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = GetVitConfig(config);
const ModelConfig vit_config = GetVitConfig(config);
const size_t num_tokens = vit_config.max_seq_len;
prefill_runtime_config.prefill_tbatch_size =
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, vit_config.seq_len, env.row_ptrs);
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config, num_tokens, env.row_ptrs);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
prefill_activations, env);
@ -692,7 +684,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
reader_(loader.weights),
model_(reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config()),
chat_template_(model_.Tokenizer(), model_.Config().model) {
chat_template_(model_.Tokenizer(), model_.Config().model),
inference_(inference) {
weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
env.ctx.pools.Pool());
reader_.CloseFile();

View File

@ -117,6 +117,7 @@ class Gemma {
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
const ModelWeightsPtrs& Weights() const { return weights_; }
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
const InferenceArgs& Inference() const { return inference_; }
void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
@ -159,6 +160,7 @@ class Gemma {
std::vector<MatOwner> mat_owners_;
ModelWeightsPtrs weights_;
GemmaChatTemplate chat_template_;
InferenceArgs inference_;
};
} // namespace gcpp

View File

@ -35,11 +35,6 @@
namespace gcpp {
// Allow changing k parameter of `SampleTopK` as a compiler flag
#ifndef GEMMA_TOPK
#define GEMMA_TOPK 1
#endif // !GEMMA_TOPK
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
LoaderArgs(const std::string& tokenizer_path,
@ -115,6 +110,7 @@ using ActivationsObserverFunc =
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
// RuntimeConfig holds configuration for a single generation run.
// TODO: move into InferenceArgs, use that directly.
struct RuntimeConfig {
// If not empty, batch_stream_token is called for each token in the batch,
// instead of stream_token.
@ -137,7 +133,7 @@ struct RuntimeConfig {
// Sampling-related parameters.
float temperature; // Temperature for sampling.
size_t top_k = GEMMA_TOPK; // Top-k for sampling.
size_t top_k = 1; // Top-k for sampling.
std::mt19937* gen; // Random number generator used for sampling.
int verbosity; // Controls verbosity of printed messages.
@ -170,6 +166,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
int verbosity;
size_t seq_len;
size_t max_generated_tokens;
size_t prefill_tbatch_size;
@ -192,6 +189,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
"developer/debug info).\n Default = 1.",
1); // Changed verbosity level to 1 since it's user-facing
visitor(seq_len, "seq_len", size_t{2048},
"Sequence length, capped by ModelConfig.max_seq_len.");
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
"Maximum number of tokens to generate.");

View File

@ -15,21 +15,25 @@
#include "gemma/kv_cache.h"
#include <algorithm> // std::copy
#include <stddef.h>
#include "gemma/configs.h"
#include "gemma/gemma_args.h"
#include "util/mat.h" // ZeroInit
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ZeroBytes
#include "hwy/base.h" // HWY_MAX
namespace gcpp {
void KVCache::ZeroGriffinCache() {
if (griffin_layers == 0) return;
if (conv1d_cache.Rows() == 0) return;
ZeroInit(conv1d_cache);
ZeroInit(rglru_cache);
}
static size_t GriffinLayers(const ModelConfig& config) {
return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock);
}
static size_t GriffinConv1dCols(const ModelConfig& config) {
size_t conv1d_width = 0;
for (const auto& layer_config : config.layer_configs) {
@ -40,43 +44,41 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
return conv1d_width * config.model_dim;
}
// prefill_tbatch_size is the maximum number of tokens from one query to
// prefill at a time.
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
: griffin_layers(
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
conv1d_cache("conv1d_cache",
Extents2D(griffin_layers, GriffinConv1dCols(config)),
MatPadding::kOdd),
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
MatPadding::kOdd) {
// TODO: move to MatStorageT.
const size_t size_cache_pos = config.CachePosSize();
if (size_cache_pos != 0) {
// Allocate more so that prefill can always access one batch, even if
// near the end of the sequence.
seq_len = config.seq_len + prefill_tbatch_size;
kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
// Number of rows for KV cache. Note that both rows and cols are u32, and
// the total number of elements can exceed 2^32.
static size_t CappedSeqLen(const ModelConfig& config,
const InferenceArgs& inference_args) {
if (inference_args.seq_len > config.max_seq_len) {
HWY_WARN("Capping seq_len %zu to config.max_seq_len %u.",
inference_args.seq_len, config.max_seq_len);
return config.max_seq_len;
}
return inference_args.seq_len;
}
KVCache KVCache::Copy(const ModelConfig& weights_config,
size_t prefill_tbatch_size) {
KVCache copy(weights_config, prefill_tbatch_size);
KVCache::KVCache(const Extents2D& conv1d_extents,
const Extents2D& rglru_extents, const Extents2D& kv_extents)
: conv1d_cache("conv1d_cache", conv1d_extents, MatPadding::kOdd),
rglru_cache("rglru_cache", rglru_extents, MatPadding::kOdd),
kv_cache("kv", kv_extents, MatPadding::kOdd) {}
const size_t size_cache_pos = weights_config.CachePosSize();
if (size_cache_pos != 0) {
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
copy.kv_cache.get());
}
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args)
: KVCache(Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
Extents2D(GriffinLayers(config), config.model_dim),
Extents2D(CappedSeqLen(config, inference_args),
config.KVCacheCols())) {}
if (conv1d_cache.HasPtr()) {
KVCache KVCache::Copy() {
KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
kv_cache.Extents());
if (conv1d_cache.Rows() != 0) {
CopyMat(conv1d_cache, copy.conv1d_cache);
}
if (rglru_cache.HasPtr()) {
CopyMat(rglru_cache, copy.rglru_cache);
}
CopyMat(kv_cache, copy.kv_cache);
return copy;
}
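A usage sketch of the new constructor and the explicit deep copy (function and variable names are illustrative):

```cpp
#include "gemma/kv_cache.h"

void Example(const gcpp::ModelConfig& config,
             const gcpp::InferenceArgs& inference) {
  // Rows = seq_len capped by config.max_seq_len, cols = config.KVCacheCols().
  gcpp::KVCache cache(config, inference);
  // Deep copy via an explicit function rather than a copy ctor, so the cost
  // is visible at the call site.
  gcpp::KVCache snapshot = cache.Copy();
  (void)snapshot;
}
```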

View File

@ -19,29 +19,34 @@
#include <stddef.h>
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
namespace gcpp {
struct KVCache {
KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size);
KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
// Returns a deep copy of the KVCache.
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
// Returns a deep copy of the KVCache. Use explicit function instead of
// copy ctor to make the cost explicit.
KVCache Copy();
size_t griffin_layers = 0;
// griffin_layers, griffin_conv1d_cols * config.model_dim
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
// and rglru_cache.
void ZeroGriffinCache();
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
size_t SeqLen() const { return kv_cache.Rows(); }
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
// [griffin_layers, griffin_conv1d_cols * model_dim]
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
MatStorageT<float> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
private:
// For use by other ctor and Copy()
KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
const Extents2D& kv_extents);
};
} // namespace gcpp

View File

@ -256,7 +256,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
MatMulEnv env(MakeMatMulEnv(threading));
if (inference.verbosity >= 2) env.print_best = true;
const Gemma gemma(loader, inference, env);
KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
KVCache kv_cache(gemma.GetModelConfig(), inference);
if (inference.verbosity >= 1) {
std::string instructions =

View File

@ -68,7 +68,8 @@ class VitAttention {
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len = activations_.seq_len;
const size_t seq_len =
static_cast<size_t>(activations_.div_seq_len.GetDivisor());
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
@ -124,7 +125,8 @@ class VitAttention {
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads;
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
const size_t seq_len = activations_.seq_len;
const size_t seq_len =
static_cast<size_t>(activations_.div_seq_len.GetDivisor());
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
@ -138,7 +140,7 @@ class VitAttention {
activations_.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att =
activations_.att.Row(token) + head * activations_.seq_len;
activations_.att.Row(token) + head * seq_len;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT k =
activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
@ -275,7 +277,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
MatMulEnv& env) {
const size_t model_dim = model_config.vit_config.model_dim;
const size_t patch_width = model_config.vit_config.patch_width;
const size_t seq_len = model_config.vit_config.seq_len;
const size_t num_tokens = model_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
@ -285,9 +287,9 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
// image_patches is (256, 14 * 14 * 3)
// Must be padded, see `DoDecompressA`.
MatStorageT<float> image_patches("patches", Extents2D(seq_len, patch_size),
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
MatPadding::kOdd);
for (size_t i = 0; i < seq_len; ++i) {
for (size_t i = 0; i < num_tokens; ++i) {
image.GetPatch(i, image_patches.Row(i));
}
CallMatMul(image_patches, weights.vit_img_embedding_kernel,

View File

@ -161,7 +161,7 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("num_layers", &ModelConfig::num_layers)
.def_readwrite("model_dim", &ModelConfig::model_dim)
.def_readwrite("vocab_size", &ModelConfig::vocab_size)
.def_readwrite("seq_len", &ModelConfig::seq_len)
.def_readwrite("max_seq_len", &ModelConfig::max_seq_len)
// Skip `unused_num_tensor_scales`.
.def_readwrite("att_cap", &ModelConfig::att_cap)
.def_readwrite("final_cap", &ModelConfig::final_cap)