mirror of https://github.com/google/gemma.cpp.git
MatPtr-ify KV, shared div_seq_len, --seq_len flag
PiperOrigin-RevId: 770194455
This commit is contained in:
parent
bd98b43cea
commit
c027a45a2e
|
|
@ -447,6 +447,7 @@ cc_library(
|
|||
hdrs = ["gemma/kv_cache.h"],
|
||||
deps = [
|
||||
":configs",
|
||||
":gemma_args",
|
||||
":mat",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -101,18 +101,6 @@ directly.
|
|||
|
||||
For other models, `gemma_export_main.py` is not yet open sourced.
|
||||
|
||||
## Compile-Time Flags (Advanced)
|
||||
|
||||
There are several compile-time flags to be aware of (note these may or may not
|
||||
be exposed to the build system):
|
||||
|
||||
- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
|
||||
Cache. The default is 4096 tokens but can be overridden. This is not exposed
|
||||
through `CMakeLists.txt` yet.
|
||||
|
||||
In the medium term this will likely be deprecated in favor of handling options
|
||||
at runtime - dynamically resizing the KV cache as needed.
|
||||
|
||||
## Using gemma.cpp as a Library (Advanced)
|
||||
|
||||
Unless you are doing lower level implementations or research, from an
|
||||
|
|
@ -165,7 +153,7 @@ constrained decoding type of use cases where you want to force the generation to
|
|||
fit a grammar. If you're not doing this, you can send an empty lambda or
|
||||
`std::function` as a no-op which is what `run.cc` does.
|
||||
|
||||
### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
|
||||
### `Transformer()` implements inference (i.e. `forward()` in PyTorch or Jax)
|
||||
|
||||
For high-level applications, you might only call `model.Generate()` and never
|
||||
interact directly with the neural network, but if you're doing something a bit
|
||||
|
|
|
|||
|
|
@ -322,9 +322,10 @@ model (any model with a `-pt` suffix).
|
|||
|
||||
**What sequence lengths are supported?**
|
||||
|
||||
See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
|
||||
typically 32K but 128K would also work given enough RAM. Note that long
|
||||
sequences will be slow due to the quadratic cost of attention.
|
||||
See `max_seq_len` in `configs.cc` and `InferenceArgs.seq_len`. For the Gemma 3
|
||||
models larger than 1B, this is typically 32K but 128K would also work given
|
||||
enough RAM. Note that long sequences will be slow due to the quadratic cost of
|
||||
attention.
|
||||
|
||||
**How do I convert my fine-tune to a `.sbs` compressed model file?**
|
||||
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
|
|||
|
||||
int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
||||
size_t batch_tokens) {
|
||||
const Gemma& gemma = *env.GetGemma();
|
||||
std::string input = ReadFileToString(text);
|
||||
std::vector<int> prompt = env.Tokenize(input);
|
||||
std::cout << "Number of input tokens: " << prompt.size() << "\n";
|
||||
|
|
@ -73,8 +74,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
|||
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
||||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||
prompt.begin() + pos + num_tokens);
|
||||
KVCache kv_cache(env.GetGemma()->GetModelConfig(),
|
||||
env.MutableConfig().prefill_tbatch_size);
|
||||
KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
|
||||
float entropy = ComputeCrossEntropy(
|
||||
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
|
||||
total_entropy += entropy;
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
|||
: env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
|
||||
const ModelConfig& config = gemma_.GetModelConfig();
|
||||
// Only allocate one for starters because GenerateBatch might not be called.
|
||||
kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
|
||||
kv_caches_.push_back(KVCache(config, inference));
|
||||
|
||||
if (inference.verbosity >= 2) {
|
||||
ShowConfig(loader, threading, inference, config);
|
||||
|
|
@ -135,8 +135,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
|
||||
// Ensure we have at least one KVCache per query.
|
||||
while (kv_caches_.size() < num_queries) {
|
||||
kv_caches_.push_back(
|
||||
KVCache(gemma_.GetModelConfig(), runtime_config_.prefill_tbatch_size));
|
||||
kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
|
||||
}
|
||||
|
||||
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ int main(int argc, char** argv) {
|
|||
// Instantiate model and KV Cache
|
||||
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
|
||||
gcpp::Gemma gemma(loader, inference, env);
|
||||
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
|
||||
size_t generated = 0;
|
||||
|
||||
// Initialize random number generator
|
||||
|
|
|
|||
|
|
@ -35,12 +35,9 @@ class SimplifiedGemma {
|
|||
SimplifiedGemma(const gcpp::LoaderArgs& loader,
|
||||
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
|
||||
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
|
||||
: loader_(loader),
|
||||
threading_(threading),
|
||||
inference_(inference),
|
||||
env_(MakeMatMulEnv(threading_)),
|
||||
gemma_(loader_, inference_, env_),
|
||||
kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
|
||||
: env_(MakeMatMulEnv(threading)),
|
||||
gemma_(loader, inference, env_),
|
||||
kv_cache_(gemma_.GetModelConfig(), inference) {
|
||||
// Initialize random number generator
|
||||
std::random_device rd;
|
||||
gen_.seed(rd());
|
||||
|
|
@ -91,9 +88,6 @@ class SimplifiedGemma {
|
|||
~SimplifiedGemma() = default;
|
||||
|
||||
private:
|
||||
gcpp::LoaderArgs loader_;
|
||||
gcpp::ThreadingArgs threading_;
|
||||
gcpp::InferenceArgs inference_;
|
||||
gcpp::MatMulEnv env_;
|
||||
gcpp::Gemma gemma_;
|
||||
gcpp::KVCache kv_cache_;
|
||||
|
|
|
|||
|
|
@ -46,8 +46,7 @@ struct Activations {
|
|||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
||||
: weights_config(config),
|
||||
layer_config(config.layer_configs[0]),
|
||||
seq_len(config.seq_len),
|
||||
cache_pos_size(config.CachePosSize()),
|
||||
div_seq_len(static_cast<uint32_t>(config.max_seq_len)),
|
||||
is_griffin(config.model == Model::GRIFFIN_2B),
|
||||
query_scale(ChooseQueryScale(config)),
|
||||
|
||||
|
|
@ -64,7 +63,9 @@ struct Activations {
|
|||
|
||||
pre_att_rms_out("pre_att_rms_out",
|
||||
Extents2D(batch_size, config.model_dim), pad_),
|
||||
att("att", Extents2D(batch_size, layer_config.heads * config.seq_len),
|
||||
att("att",
|
||||
Extents2D(batch_size,
|
||||
layer_config.heads * div_seq_len.GetDivisor()),
|
||||
pad_),
|
||||
att_out(
|
||||
"att_out",
|
||||
|
|
@ -141,10 +142,14 @@ struct Activations {
|
|||
gen_tokens.resize(batch_size);
|
||||
}
|
||||
|
||||
bool IsGlobalLayer(size_t layer_idx) const {
|
||||
return weights_config.attention_window_sizes[layer_idx] ==
|
||||
div_seq_len.GetDivisor();
|
||||
}
|
||||
|
||||
const ModelConfig& weights_config;
|
||||
const LayerConfig& layer_config;
|
||||
size_t seq_len;
|
||||
size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
|
||||
hwy::Divisor div_seq_len;
|
||||
bool is_griffin;
|
||||
float query_scale;
|
||||
const Extents2D none_ = Extents2D();
|
||||
|
|
|
|||
|
|
@ -70,9 +70,7 @@ static void PositionalEncodingQK(U* qk, const size_t qkv_dim,
|
|||
const PostQKType& post_qk = layer.layer_config.post_qk;
|
||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||
const float* inv_timescale = activations.inv_timescale.PackedScale1();
|
||||
bool is_global_layer =
|
||||
activations.weights_config.attention_window_sizes[layer_idx] ==
|
||||
activations.seq_len;
|
||||
bool is_global_layer = activations.IsGlobalLayer(layer_idx);
|
||||
// TODO: add a config flag instead of hardcoding the model.
|
||||
if (is_global_layer && IsVLM(activations.weights_config.model)) {
|
||||
inv_timescale = activations.inv_timescale_global.PackedScale1();
|
||||
|
|
@ -116,13 +114,15 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
|
|||
// Calculates the attention outputs for a single q.
|
||||
void SingleDotSoftmaxWeightedSum(
|
||||
const size_t pos, const size_t start_pos, const size_t last_pos,
|
||||
const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q,
|
||||
const MatPtrT<float>& k, const MatPtrT<float>& v, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, const Activations& activations,
|
||||
float* HWY_RESTRICT att, float* HWY_RESTRICT att_out) {
|
||||
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
const Activations& activations, float* HWY_RESTRICT att,
|
||||
float* HWY_RESTRICT att_out) {
|
||||
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
||||
const float att_cap = activations.weights_config.att_cap;
|
||||
const float query_scale = activations.query_scale;
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(activations.div_seq_len.GetDivisor());
|
||||
|
||||
// Apply rope and scaling to Q.
|
||||
if (layer.query_norm_scale.HasPtr()) {
|
||||
|
|
@ -133,15 +133,14 @@ void SingleDotSoftmaxWeightedSum(
|
|||
PositionalEncodingQK(q, qkv_dim, layer_idx, layer, activations, pos,
|
||||
query_scale);
|
||||
|
||||
QDotK(start_pos, last_pos, div_seq_len, q, k, att);
|
||||
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
|
||||
|
||||
// SoftMax with optional SoftCap yields "probabilities" in att.
|
||||
const size_t att_len =
|
||||
HWY_MIN(last_pos + 1, static_cast<size_t>(div_seq_len.GetDivisor()));
|
||||
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
|
||||
MaybeLogitsSoftCap(att_cap, att, att_len);
|
||||
Softmax(att, att_len);
|
||||
|
||||
WeightedSumV(start_pos, last_pos, div_seq_len, att, v, att_out);
|
||||
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out);
|
||||
}
|
||||
|
||||
// The attention window usually starts at 0 unless `pos` is larger than
|
||||
|
|
@ -152,11 +151,13 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
|
|||
return pos - HWY_MIN(att_window_size - 1, pos);
|
||||
}
|
||||
|
||||
void DotSoftmaxWeightedSum(
|
||||
const size_t num_tokens, const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
Activations& activations, const KVCaches& kv_caches, NestedPools& pools) {
|
||||
void DotSoftmaxWeightedSum(const size_t num_tokens,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer,
|
||||
Activations& activations, const KVCaches& kv_caches,
|
||||
NestedPools& pools) {
|
||||
const size_t num_queries = queries_pos.size();
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
||||
|
|
@ -166,7 +167,8 @@ void DotSoftmaxWeightedSum(
|
|||
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
|
||||
|
||||
const size_t cache_layer_size = layer_config.CacheLayerSize();
|
||||
const size_t cache_pos_size = activations.cache_pos_size;
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(activations.div_seq_len.GetDivisor());
|
||||
|
||||
// For each head (token, query), compute Q.K, softmax, and weighted V.
|
||||
// TODO: nested parallelism to use more threads.
|
||||
|
|
@ -183,21 +185,19 @@ void DotSoftmaxWeightedSum(
|
|||
float* HWY_RESTRICT q =
|
||||
activations.q.Row(interleaved_idx) + head * qkv_dim;
|
||||
float* HWY_RESTRICT att =
|
||||
activations.att.Row(interleaved_idx) + head * activations.seq_len;
|
||||
activations.att.Row(interleaved_idx) + head * seq_len;
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations.att_out.Row(interleaved_idx) + head * qkv_dim;
|
||||
|
||||
// Make strided views into the kv cache entries for the current
|
||||
// query and head.
|
||||
KVCache& kv_cache = kv_caches[query_idx];
|
||||
auto& kv_cache = kv_caches[query_idx].kv_cache;
|
||||
const size_t kv_head_offset =
|
||||
layer_idx * cache_layer_size + head_offset;
|
||||
MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
|
||||
k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
|
||||
/*stride=*/cache_pos_size);
|
||||
MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
|
||||
v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
|
||||
/*stride=*/cache_pos_size);
|
||||
MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
|
||||
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
|
||||
MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
|
||||
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
||||
|
||||
// Find the token position in the query and calculate the range
|
||||
// of cache positions to attend to.
|
||||
|
|
@ -211,16 +211,15 @@ void DotSoftmaxWeightedSum(
|
|||
last_pos = prefix_end - 1;
|
||||
}
|
||||
|
||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, div_seq_len, q, k,
|
||||
v, layer_idx, layer, activations, att,
|
||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
|
||||
layer_idx, layer, activations, att,
|
||||
att_out);
|
||||
});
|
||||
}
|
||||
|
||||
// Fills activations.q and writes to KV cache.
|
||||
static HWY_INLINE void ComputeQKV(
|
||||
size_t num_tokens, const QueriesPos& queries_pos,
|
||||
const hwy::Divisor& div_seq_len, const size_t layer_idx,
|
||||
size_t num_tokens, const QueriesPos& queries_pos, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, Activations& activations,
|
||||
const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
|
||||
PROFILER_ZONE("Gen.Attention.QKV");
|
||||
|
|
@ -230,7 +229,6 @@ static HWY_INLINE void ComputeQKV(
|
|||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
const size_t kv_heads = layer_config.kv_heads;
|
||||
const size_t cache_layer_size = layer_config.CacheLayerSize();
|
||||
const size_t cache_pos_size = activations.cache_pos_size;
|
||||
|
||||
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
|
||||
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
|
||||
|
|
@ -247,11 +245,10 @@ static HWY_INLINE void ComputeQKV(
|
|||
const size_t query_idx = interleaved_idx % num_queries;
|
||||
const size_t batch_idx = interleaved_idx / num_queries;
|
||||
const size_t cache_pos =
|
||||
div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
|
||||
const size_t kv_offset =
|
||||
cache_pos * cache_pos_size + layer_idx * cache_layer_size;
|
||||
activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
|
||||
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
|
||||
kv_caches[query_idx].kv_cache.get() + kv_offset);
|
||||
kv_caches[query_idx].kv_cache.Row(cache_pos) +
|
||||
layer_idx * cache_layer_size);
|
||||
}
|
||||
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
|
||||
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
|
||||
|
|
@ -267,12 +264,11 @@ static HWY_INLINE void ComputeQKV(
|
|||
const size_t query_idx = interleaved_idx % num_queries;
|
||||
const size_t batch_idx = interleaved_idx / num_queries;
|
||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
||||
const size_t cache_pos = div_seq_len.Remainder(pos);
|
||||
const size_t kv_offset = cache_pos * cache_pos_size +
|
||||
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
|
||||
auto& kv_cache = kv_caches[query_idx].kv_cache;
|
||||
float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
|
||||
layer_idx * cache_layer_size +
|
||||
head * qkv_dim * 2;
|
||||
KVCache& kv_cache = kv_caches[query_idx];
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
|
||||
// Apply further processing to K.
|
||||
if (layer.key_norm_scale.HasPtr()) {
|
||||
|
|
@ -309,9 +305,9 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
|||
// causal attention, and must be non-null for prefix-LM style attention.
|
||||
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
|
||||
const QueriesPos* queries_prefix_end,
|
||||
const hwy::Divisor& div_seq_len, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env, int flags) {
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
Activations& activations, const KVCaches& kv_caches,
|
||||
MatMulEnv& env, int flags) {
|
||||
const size_t num_queries = queries_pos.size();
|
||||
HWY_DASSERT(num_queries <= kv_caches.size());
|
||||
|
||||
|
|
@ -330,11 +326,10 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
|
|||
queries_prefix_end = &queries_prefix_end_span;
|
||||
}
|
||||
|
||||
ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer,
|
||||
activations, kv_caches, flags, env);
|
||||
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end,
|
||||
div_seq_len, layer_idx, layer, activations, kv_caches,
|
||||
env.ctx.pools);
|
||||
ComputeQKV(num_tokens, queries_pos, layer_idx, layer, activations, kv_caches,
|
||||
flags, env);
|
||||
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, layer_idx,
|
||||
layer, activations, kv_caches, env.ctx.pools);
|
||||
SumHeads(layer, activations, env);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,24 +30,23 @@ namespace gcpp {
|
|||
namespace NAMESPACE { \
|
||||
void SingleDotSoftmaxWeightedSum( \
|
||||
const size_t pos, const size_t start_pos, const size_t last_pos, \
|
||||
const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q, \
|
||||
const MatPtrT<float>& k, const MatPtrT<float>& v, size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, const Activations& activations, \
|
||||
float* HWY_RESTRICT att, float* HWY_RESTRICT att_out); \
|
||||
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
const Activations& activations, float* HWY_RESTRICT att, \
|
||||
float* HWY_RESTRICT att_out); \
|
||||
\
|
||||
void DotSoftmaxWeightedSum(const size_t num_tokens, \
|
||||
const QueriesPos& queries_pos, \
|
||||
const QueriesPos& queries_prefix_end, \
|
||||
const hwy::Divisor& div_seq_len, \
|
||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
Activations& activations, \
|
||||
const KVCaches& kv_caches, NestedPools& pools); \
|
||||
\
|
||||
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, \
|
||||
const QueriesPos* queries_prefix_end, \
|
||||
const hwy::Divisor& div_seq_len, const size_t layer_idx, \
|
||||
const LayerWeightsPtrs& layer, Activations& activations, \
|
||||
const KVCaches& kv_caches, MatMulEnv& env, int flags); \
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||
Activations& activations, const KVCaches& kv_caches, \
|
||||
MatMulEnv& env, int flags); \
|
||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||
} // namespace NAMESPACE
|
||||
|
||||
|
|
|
|||
|
|
@ -43,21 +43,15 @@ namespace gcpp {
|
|||
|
||||
// ConversationData constructor implementation
|
||||
ConversationData::ConversationData(const ModelConfig& model_config,
|
||||
size_t prefill_tbatch_size)
|
||||
: model_config_ref_(model_config),
|
||||
prefill_tbatch_size_(prefill_tbatch_size),
|
||||
kv_cache(std::make_unique<KVCache>(model_config, prefill_tbatch_size)),
|
||||
const InferenceArgs& inference_args)
|
||||
: kv_cache(std::make_unique<KVCache>(model_config, inference_args)),
|
||||
abs_pos(0) {}
|
||||
|
||||
// ConversationData copy constructor implementation
|
||||
ConversationData::ConversationData(const ConversationData& other)
|
||||
: model_config_ref_(other.model_config_ref_),
|
||||
prefill_tbatch_size_(other.prefill_tbatch_size_),
|
||||
kv_cache(nullptr),
|
||||
abs_pos(other.abs_pos) {
|
||||
: kv_cache(nullptr), abs_pos(other.abs_pos) {
|
||||
if (other.kv_cache) {
|
||||
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy(
|
||||
other.model_config_ref_, other.prefill_tbatch_size_));
|
||||
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -115,7 +109,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
|
|||
LogDebug("Creating initial ConversationData");
|
||||
// Create the initial ConversationData object using make_shared
|
||||
active_conversation = std::make_shared<ConversationData>(
|
||||
model.GetModelConfig(), inference_args.prefill_tbatch_size);
|
||||
model.GetModelConfig(), inference_args);
|
||||
|
||||
LogDebug(
|
||||
"Storing initial ConversationData in conversation_cache[\"default\"]");
|
||||
|
|
|
|||
|
|
@ -31,26 +31,19 @@
|
|||
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Forward declaration - use 'struct' to match definition tag
|
||||
struct KVCache;
|
||||
|
||||
// Struct to hold data for a single conversation thread
|
||||
struct ConversationData {
|
||||
public:
|
||||
ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
|
||||
ConversationData(const ModelConfig& model_config,
|
||||
const InferenceArgs& inference_args);
|
||||
ConversationData(const ConversationData& other);
|
||||
|
||||
private:
|
||||
const ModelConfig& model_config_ref_;
|
||||
size_t prefill_tbatch_size_;
|
||||
|
||||
public:
|
||||
std::unique_ptr<KVCache> kv_cache;
|
||||
size_t abs_pos = 0;
|
||||
};
|
||||
|
|
@ -142,8 +135,7 @@ class GemmaContext {
|
|||
log_msg += "' to prewarmed_cache.";
|
||||
LogDebug(log_msg.c_str());
|
||||
|
||||
// Create a deep copy of the active_conversation.
|
||||
// The ConversationData copy constructor handles the deep copy of KVCache.
|
||||
// Create a deep copy of the active_conversation via copy ctor.
|
||||
auto conversation_copy =
|
||||
std::make_shared<ConversationData>(*active_conversation);
|
||||
|
||||
|
|
@ -176,8 +168,7 @@ class GemmaContext {
|
|||
active_conversation->abs_pos = it->second->abs_pos;
|
||||
// Perform a deep copy of the KVCache from the prewarmed version.
|
||||
active_conversation->kv_cache =
|
||||
std::make_unique<KVCache>(it->second->kv_cache->Copy(
|
||||
model.GetModelConfig(), inference_args.prefill_tbatch_size));
|
||||
std::make_unique<KVCache>(it->second->kv_cache->Copy());
|
||||
LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
|
||||
.c_str());
|
||||
return;
|
||||
|
|
@ -187,8 +178,8 @@ class GemmaContext {
|
|||
// rewind to initial state.
|
||||
active_conversation->abs_pos = 0;
|
||||
// Replace the cache within the current ConversationData object
|
||||
active_conversation->kv_cache = std::make_unique<KVCache>(
|
||||
model.GetModelConfig(), inference_args.prefill_tbatch_size);
|
||||
active_conversation->kv_cache =
|
||||
std::make_unique<KVCache>(model.GetModelConfig(), inference_args);
|
||||
|
||||
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
|
||||
} else {
|
||||
|
|
@ -206,7 +197,7 @@ class GemmaContext {
|
|||
LogDebug("Creating new conversation");
|
||||
// Create a new ConversationData object using make_shared
|
||||
conversation_cache[name] = std::make_shared<ConversationData>(
|
||||
model.GetModelConfig(), inference_args.prefill_tbatch_size);
|
||||
model.GetModelConfig(), inference_args);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,12 +27,8 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
||||
#ifndef GEMMA_MAX_SEQLEN
|
||||
#define GEMMA_MAX_SEQLEN 4096
|
||||
#endif // !GEMMA_MAX_SEQLEN
|
||||
|
||||
static constexpr size_t kVocabSize = 256000;
|
||||
static constexpr size_t kMaxSeqLen = 4096;
|
||||
|
||||
static ModelConfig ConfigNoSSM() {
|
||||
ModelConfig config;
|
||||
|
|
@ -69,7 +65,7 @@ static ModelConfig ConfigGemma2_27B() {
|
|||
config.model = Model::GEMMA2_27B;
|
||||
config.model_dim = 4608;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
config.max_seq_len = 8192;
|
||||
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
|
||||
config.num_layers = 46;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
|
|
@ -97,7 +93,7 @@ static ModelConfig ConfigGemma2_9B() {
|
|||
config.model = Model::GEMMA2_9B;
|
||||
config.model_dim = 3584;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
config.max_seq_len = 8192;
|
||||
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
|
||||
config.num_layers = 42;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
|
|
@ -125,7 +121,7 @@ static ModelConfig ConfigGemma2_2B() {
|
|||
config.model = Model::GEMMA2_2B;
|
||||
config.model_dim = 2304;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 8192;
|
||||
config.max_seq_len = 8192;
|
||||
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
|
|
@ -152,7 +148,7 @@ static ModelConfig ConfigGemmaTiny() {
|
|||
config.wrapping = PromptWrapping::GEMMA_IT;
|
||||
config.model_dim = 32;
|
||||
config.vocab_size = 32; // at least two f32 vectors
|
||||
config.seq_len = 32;
|
||||
config.max_seq_len = 32;
|
||||
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
|
||||
config.num_layers = 2;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
|
|
@ -188,11 +184,11 @@ static ModelConfig ConfigGriffin2B() {
|
|||
ModelConfig config = ConfigNoSSM();
|
||||
config.display_name = "Griffin2B";
|
||||
config.model = Model::GRIFFIN_2B;
|
||||
// Griffin uses local attention, so GEMMA_MAX_SEQLEN is actually the local
|
||||
// Griffin uses local attention, so max_seq_len is actually the local
|
||||
// attention window.
|
||||
config.model_dim = 2560;
|
||||
config.vocab_size = kVocabSize;
|
||||
config.seq_len = 2048;
|
||||
config.max_seq_len = 2048;
|
||||
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
|
|
@ -200,7 +196,8 @@ static ModelConfig ConfigGriffin2B() {
|
|||
config.layer_configs[i].type = LayerAttentionType::kGemma;
|
||||
config.layer_configs[i].griffin_dim = 0;
|
||||
}
|
||||
config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
|
||||
config.attention_window_sizes =
|
||||
FixedAttentionWindowSizes<26>(config.max_seq_len);
|
||||
config.use_local_attention = true;
|
||||
config.final_cap = 0.0f;
|
||||
return config;
|
||||
|
|
@ -238,7 +235,7 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
|
|||
ModelConfig GetVitConfig(const ModelConfig& config) {
|
||||
ModelConfig vit_config = ConfigNoSSM();
|
||||
vit_config.model_dim = config.vit_config.model_dim;
|
||||
vit_config.seq_len = config.vit_config.seq_len;
|
||||
vit_config.max_seq_len = config.vit_config.seq_len;
|
||||
vit_config.layer_configs = config.vit_config.layer_configs;
|
||||
vit_config.pool_dim = config.vit_config.pool_dim;
|
||||
vit_config.wrapping = config.wrapping;
|
||||
|
|
@ -313,14 +310,14 @@ static ModelConfig ConfigGemma3_1B() {
|
|||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 1152;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
config.max_seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
|
||||
config.num_layers = 26;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
|
||||
{512, 512, 512, 512, 512, config.seq_len});
|
||||
{512, 512, 512, 512, 512, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -345,14 +342,14 @@ static ModelConfig ConfigGemma3_4B_LM() {
|
|||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 2560;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
config.max_seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
|
||||
config.num_layers = 34;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
|
||||
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
||||
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -394,14 +391,14 @@ static ModelConfig ConfigGemma3_12B_LM() {
|
|||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 3840;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
config.max_seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
|
||||
config.num_layers = 48;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
|
||||
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
||||
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
@ -443,14 +440,14 @@ static ModelConfig ConfigGemma3_27B_LM() {
|
|||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 5376;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
config.max_seq_len = 32 * 1024;
|
||||
LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
|
||||
config.num_layers = 62;
|
||||
config.layer_configs = {config.num_layers, layer_config};
|
||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||
// interleaved local / global attention
|
||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
|
||||
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
||||
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
|
||||
return config;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -347,7 +347,7 @@ struct ModelConfig : public IFields {
|
|||
visitor(num_layers);
|
||||
visitor(model_dim);
|
||||
visitor(vocab_size);
|
||||
visitor(seq_len);
|
||||
visitor(max_seq_len);
|
||||
|
||||
visitor(unused_num_tensor_scales);
|
||||
|
||||
|
|
@ -413,7 +413,7 @@ struct ModelConfig : public IFields {
|
|||
return num_heads;
|
||||
}
|
||||
|
||||
size_t CachePosSize() const {
|
||||
size_t KVCacheCols() const {
|
||||
size_t num_layers = layer_configs.size();
|
||||
return num_layers * layer_configs[0].CacheLayerSize();
|
||||
}
|
||||
|
|
@ -435,7 +435,7 @@ struct ModelConfig : public IFields {
|
|||
uint32_t num_layers = 0;
|
||||
uint32_t model_dim = 0;
|
||||
uint32_t vocab_size = 0;
|
||||
uint32_t seq_len = 0;
|
||||
uint32_t max_seq_len = 0;
|
||||
|
||||
// We no longer set nor use this: config_converter is not able to set this,
|
||||
// and only pre-2025 format stores scales, and we do not require advance
|
||||
|
|
|
|||
113
gemma/gemma.cc
113
gemma/gemma.cc
|
|
@ -64,13 +64,12 @@ namespace HWY_NAMESPACE {
|
|||
|
||||
void Attention(LayerAttentionType type, size_t num_tokens,
|
||||
const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end,
|
||||
const hwy::Divisor& div_seq_len, const size_t layer_idx,
|
||||
const QueriesPos& queries_prefix_end, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len,
|
||||
layer_idx, layer, activations, kv_caches, env,
|
||||
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, layer_idx,
|
||||
layer, activations, kv_caches, env,
|
||||
/*flags=*/0);
|
||||
} else {
|
||||
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
|
||||
|
|
@ -85,16 +84,16 @@ void Attention(LayerAttentionType type, size_t num_tokens,
|
|||
|
||||
static HWY_NOINLINE void TransformerLayer(
|
||||
const size_t num_tokens, const QueriesPos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
|
||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
|
||||
const QueriesPos& queries_prefix_end, const size_t layer_idx,
|
||||
const LayerWeightsPtrs& layer, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
||||
const LayerConfig& layer_config = layer.layer_config;
|
||||
|
||||
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
|
||||
activations.pre_att_rms_out);
|
||||
|
||||
Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
|
||||
div_seq_len, layer_idx, layer, activations, kv_caches, env);
|
||||
layer_idx, layer, activations, kv_caches, env);
|
||||
|
||||
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
|
||||
activations.att_sums);
|
||||
|
|
@ -190,10 +189,9 @@ using QueriesMutablePos = hwy::Span<size_t>;
|
|||
static HWY_NOINLINE void PrefillTBatch(
|
||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||
const hwy::Divisor& div_seq_len, const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
|
||||
PROFILER_ZONE("Gen.PrefillT");
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_DASSERT(num_queries == queries_pos.size());
|
||||
|
|
@ -265,8 +263,8 @@ static HWY_NOINLINE void PrefillTBatch(
|
|||
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
|
||||
++layer_idx) {
|
||||
TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
|
||||
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
|
||||
activations, single_kv_cache, env);
|
||||
layer_idx, *weights.GetLayer(layer_idx), activations,
|
||||
single_kv_cache, env);
|
||||
}
|
||||
|
||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||
|
|
@ -303,10 +301,9 @@ static HWY_NOINLINE void PrefillTBatch(
|
|||
// token-batched `PrefillTBatch`.
|
||||
static HWY_NOINLINE void Transformer(
|
||||
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
||||
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
|
||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
||||
const QueriesPos& queries_prefix_end, const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
|
||||
const size_t num_queries = queries_token.size();
|
||||
HWY_DASSERT(num_queries == queries_pos.size());
|
||||
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
||||
|
|
@ -326,8 +323,8 @@ static HWY_NOINLINE void Transformer(
|
|||
|
||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||
TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
|
||||
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
|
||||
activations, kv_caches, env);
|
||||
layer_idx, *weights.GetLayer(layer_idx), activations,
|
||||
kv_caches, env);
|
||||
|
||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||
runtime_config.activations_observer(queries_pos, layer_idx, activations);
|
||||
|
|
@ -340,10 +337,10 @@ static HWY_NOINLINE void Transformer(
|
|||
static HWY_NOINLINE void PrefillQBatch(
|
||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||
const size_t max_prompt_size, const hwy::Divisor& div_seq_len,
|
||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
|
||||
const size_t max_prompt_size, const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos) {
|
||||
PROFILER_ZONE("Gen.Prefill");
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_DASSERT(num_queries == queries_pos.size());
|
||||
|
|
@ -380,8 +377,8 @@ static HWY_NOINLINE void PrefillQBatch(
|
|||
// Do not call DecodeStepT because it computes logits for token
|
||||
// probabilities, which are not required for the prompt tokens.
|
||||
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
||||
queries_pos, queries_prefix_end, div_seq_len, config,
|
||||
runtime_config, weights, activations, kv_caches, env);
|
||||
queries_pos, queries_prefix_end, config, runtime_config,
|
||||
weights, activations, kv_caches, env);
|
||||
|
||||
prefill_active.Foreach([&](size_t qi) {
|
||||
const int token = queries_prompt[qi][pos_in_prompt];
|
||||
|
|
@ -393,19 +390,6 @@ static HWY_NOINLINE void PrefillQBatch(
|
|||
} // pos_in_prompt
|
||||
}
|
||||
|
||||
// TODO: inline.
|
||||
void RangeChecks(const ModelConfig& weights_config,
|
||||
size_t& max_generated_tokens, const size_t prompt_size) {
|
||||
if (!weights_config.use_local_attention) {
|
||||
if (max_generated_tokens > weights_config.seq_len) {
|
||||
HWY_WARN("max_generated_tokens %zu > kSeqLen %u, truncating.",
|
||||
max_generated_tokens, weights_config.seq_len);
|
||||
max_generated_tokens = weights_config.seq_len;
|
||||
}
|
||||
}
|
||||
HWY_ASSERT(prompt_size > 0);
|
||||
}
|
||||
|
||||
// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
|
||||
// and updates `non_eos` if the query is at the end of its sequence.
|
||||
static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
||||
|
|
@ -432,17 +416,17 @@ static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
|||
static void DecodeStepT(
|
||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||
const QueriesMutablePos& queries_mutable_pos,
|
||||
const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
|
||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||
const ModelWeightsPtrs& weights, const SampleFunc& sample_token,
|
||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
||||
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
|
||||
const QueriesPos& queries_prefix_end, const ModelConfig& config,
|
||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||
const SampleFunc& sample_token, Activations& activations,
|
||||
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
||||
TimingInfo& timing_info) {
|
||||
const size_t num_queries = queries_prompt.size();
|
||||
HWY_DASSERT(num_queries == activations.x.Rows());
|
||||
|
||||
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
||||
queries_mutable_pos, queries_prefix_end, div_seq_len, config,
|
||||
runtime_config, weights, activations, kv_caches, env);
|
||||
queries_mutable_pos, queries_prefix_end, config, runtime_config,
|
||||
weights, activations, kv_caches, env);
|
||||
|
||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
|
||||
|
||||
|
|
@ -530,6 +514,7 @@ static void GenerateT(
|
|||
size_t max_prompt_size = 0;
|
||||
bool all_prefix_end_are_zero = true;
|
||||
size_t prefill_tokens = 0;
|
||||
const size_t seq_len = kv_caches[0].SeqLen();
|
||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||
const PromptTokens& prompt = queries_prompt[qi];
|
||||
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
|
||||
|
|
@ -542,9 +527,12 @@ static void GenerateT(
|
|||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
|
||||
|
||||
all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
|
||||
}
|
||||
|
||||
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
||||
// We use a single divisor, so all sequence lengths must be the same.
|
||||
HWY_ASSERT(kv_caches[qi].SeqLen() == seq_len);
|
||||
}
|
||||
HWY_ASSERT(prefill_tokens < seq_len);
|
||||
activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
|
||||
|
||||
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
||||
// qi loops anyway.
|
||||
|
|
@ -555,13 +543,12 @@ static void GenerateT(
|
|||
if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
|
||||
activations.SetBatchSize(num_queries); // required before PrefillQBatch
|
||||
PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||
queries_prefix_end, max_prompt_size, div_seq_len, config,
|
||||
runtime_config, weights, activations, kv_caches, env,
|
||||
non_eos);
|
||||
queries_prefix_end, max_prompt_size, config, runtime_config,
|
||||
weights, activations, kv_caches, env, non_eos);
|
||||
} else {
|
||||
PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||
queries_prefix_end, div_seq_len, config, runtime_config,
|
||||
weights, activations, kv_caches, env, non_eos);
|
||||
queries_prefix_end, config, runtime_config, weights,
|
||||
activations, kv_caches, env, non_eos);
|
||||
activations.SetBatchSize(num_queries); // Restore after PrefillTBatch.
|
||||
}
|
||||
HWY_DASSERT(num_queries == non_eos.Count());
|
||||
|
|
@ -579,7 +566,11 @@ static void GenerateT(
|
|||
}
|
||||
|
||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||
RangeChecks(config, max_gen_steps, max_prompt_size);
|
||||
if (prefill_tokens + max_gen_steps > seq_len) {
|
||||
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
|
||||
prefill_tokens, max_gen_steps, seq_len);
|
||||
max_gen_steps = seq_len - prefill_tokens;
|
||||
}
|
||||
|
||||
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
|
||||
|
||||
|
|
@ -587,8 +578,8 @@ static void GenerateT(
|
|||
timing_info.generate_start = hwy::platform::Now();
|
||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||
DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||
queries_prefix_end, div_seq_len, config, runtime_config,
|
||||
weights, sample_token, activations, kv_caches, env, non_eos,
|
||||
queries_prefix_end, config, runtime_config, weights,
|
||||
sample_token, activations, kv_caches, env, non_eos,
|
||||
timing_info);
|
||||
}
|
||||
timing_info.NotifyGenerateDone();
|
||||
|
|
@ -661,10 +652,11 @@ void GenerateImageTokensT(const ModelConfig& config,
|
|||
HWY_ABORT("Model does not support generating image tokens.");
|
||||
}
|
||||
RuntimeConfig prefill_runtime_config = runtime_config;
|
||||
ModelConfig vit_config = GetVitConfig(config);
|
||||
const ModelConfig vit_config = GetVitConfig(config);
|
||||
const size_t num_tokens = vit_config.max_seq_len;
|
||||
prefill_runtime_config.prefill_tbatch_size =
|
||||
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
|
||||
Activations prefill_activations(vit_config, vit_config.seq_len, env.row_ptrs);
|
||||
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
|
||||
Activations prefill_activations(vit_config, num_tokens, env.row_ptrs);
|
||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
||||
prefill_activations, env);
|
||||
|
|
@ -692,7 +684,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
|||
reader_(loader.weights),
|
||||
model_(reader_, loader.tokenizer, loader.wrapping),
|
||||
weights_(model_.Config()),
|
||||
chat_template_(model_.Tokenizer(), model_.Config().model) {
|
||||
chat_template_(model_.Tokenizer(), model_.Config().model),
|
||||
inference_(inference) {
|
||||
weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
|
||||
env.ctx.pools.Pool());
|
||||
reader_.CloseFile();
|
||||
|
|
|
|||
|
|
@ -117,6 +117,7 @@ class Gemma {
|
|||
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
|
||||
const ModelWeightsPtrs& Weights() const { return weights_; }
|
||||
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
|
||||
const InferenceArgs& Inference() const { return inference_; }
|
||||
|
||||
void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
|
||||
|
||||
|
|
@ -159,6 +160,7 @@ class Gemma {
|
|||
std::vector<MatOwner> mat_owners_;
|
||||
ModelWeightsPtrs weights_;
|
||||
GemmaChatTemplate chat_template_;
|
||||
InferenceArgs inference_;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -35,11 +35,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// Allow changing k parameter of `SampleTopK` as a compiler flag
|
||||
#ifndef GEMMA_TOPK
|
||||
#define GEMMA_TOPK 1
|
||||
#endif // !GEMMA_TOPK
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
LoaderArgs(const std::string& tokenizer_path,
|
||||
|
|
@ -115,6 +110,7 @@ using ActivationsObserverFunc =
|
|||
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
||||
|
||||
// RuntimeConfig holds configuration for a single generation run.
|
||||
// TODO: move into InferenceArgs, use that directly.
|
||||
struct RuntimeConfig {
|
||||
// If not empty, batch_stream_token is called for each token in the batch,
|
||||
// instead of stream_token.
|
||||
|
|
@ -137,7 +133,7 @@ struct RuntimeConfig {
|
|||
// Sampling-related parameters.
|
||||
float temperature; // Temperature for sampling.
|
||||
|
||||
size_t top_k = GEMMA_TOPK; // Top-k for sampling.
|
||||
size_t top_k = 1; // Top-k for sampling.
|
||||
std::mt19937* gen; // Random number generator used for sampling.
|
||||
|
||||
int verbosity; // Controls verbosity of printed messages.
|
||||
|
|
@ -170,6 +166,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
|
||||
int verbosity;
|
||||
|
||||
size_t seq_len;
|
||||
size_t max_generated_tokens;
|
||||
|
||||
size_t prefill_tbatch_size;
|
||||
|
|
@ -192,6 +189,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
"developer/debug info).\n Default = 1.",
|
||||
1); // Changed verbosity level to 1 since it's user-facing
|
||||
|
||||
visitor(seq_len, "seq_len", size_t{2048},
|
||||
"Sequence length, capped by ModelConfig.max_seq_len.");
|
||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||
"Maximum number of tokens to generate.");
|
||||
|
||||
|
|
|
|||
|
|
@ -15,21 +15,25 @@
|
|||
|
||||
#include "gemma/kv_cache.h"
|
||||
|
||||
#include <algorithm> // std::copy
|
||||
#include <stddef.h>
|
||||
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "util/mat.h" // ZeroInit
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // ZeroBytes
|
||||
#include "hwy/base.h" // HWY_MAX
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void KVCache::ZeroGriffinCache() {
|
||||
if (griffin_layers == 0) return;
|
||||
if (conv1d_cache.Rows() == 0) return;
|
||||
ZeroInit(conv1d_cache);
|
||||
ZeroInit(rglru_cache);
|
||||
}
|
||||
|
||||
static size_t GriffinLayers(const ModelConfig& config) {
|
||||
return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock);
|
||||
}
|
||||
|
||||
static size_t GriffinConv1dCols(const ModelConfig& config) {
|
||||
size_t conv1d_width = 0;
|
||||
for (const auto& layer_config : config.layer_configs) {
|
||||
|
|
@ -40,43 +44,41 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
|
|||
return conv1d_width * config.model_dim;
|
||||
}
|
||||
|
||||
// prefill_tbatch_size is the maximum number of tokens from one query to
|
||||
// prefill at a time.
|
||||
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
|
||||
: griffin_layers(
|
||||
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
|
||||
conv1d_cache("conv1d_cache",
|
||||
Extents2D(griffin_layers, GriffinConv1dCols(config)),
|
||||
MatPadding::kOdd),
|
||||
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
|
||||
MatPadding::kOdd) {
|
||||
// TODO: move to MatStorageT.
|
||||
const size_t size_cache_pos = config.CachePosSize();
|
||||
if (size_cache_pos != 0) {
|
||||
// Allocate more so that prefill can always access one batch, even if
|
||||
// near the end of the sequence.
|
||||
seq_len = config.seq_len + prefill_tbatch_size;
|
||||
kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
||||
// Number of rows for KV cache. Note that both rows and cols are u32, and
|
||||
// the total number of elements can exceed 2^32.
|
||||
static size_t CappedSeqLen(const ModelConfig& config,
|
||||
const InferenceArgs& inference_args) {
|
||||
if (inference_args.seq_len > config.max_seq_len) {
|
||||
HWY_WARN("Capping seq_len %zu to config.max_seq_len %u.",
|
||||
inference_args.seq_len, config.max_seq_len);
|
||||
return config.max_seq_len;
|
||||
}
|
||||
return inference_args.seq_len;
|
||||
}
|
||||
|
||||
KVCache KVCache::Copy(const ModelConfig& weights_config,
|
||||
size_t prefill_tbatch_size) {
|
||||
KVCache copy(weights_config, prefill_tbatch_size);
|
||||
KVCache::KVCache(const Extents2D& conv1d_extents,
|
||||
const Extents2D& rglru_extents, const Extents2D& kv_extents)
|
||||
: conv1d_cache("conv1d_cache", conv1d_extents, MatPadding::kOdd),
|
||||
rglru_cache("rglru_cache", rglru_extents, MatPadding::kOdd),
|
||||
kv_cache("kv", kv_extents, MatPadding::kOdd) {}
|
||||
|
||||
const size_t size_cache_pos = weights_config.CachePosSize();
|
||||
if (size_cache_pos != 0) {
|
||||
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
|
||||
copy.kv_cache.get());
|
||||
}
|
||||
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args)
|
||||
: KVCache(Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
|
||||
Extents2D(GriffinLayers(config), config.model_dim),
|
||||
Extents2D(CappedSeqLen(config, inference_args),
|
||||
config.KVCacheCols())) {}
|
||||
|
||||
if (conv1d_cache.HasPtr()) {
|
||||
KVCache KVCache::Copy() {
|
||||
KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
|
||||
kv_cache.Extents());
|
||||
|
||||
if (conv1d_cache.Rows() != 0) {
|
||||
CopyMat(conv1d_cache, copy.conv1d_cache);
|
||||
}
|
||||
if (rglru_cache.HasPtr()) {
|
||||
CopyMat(rglru_cache, copy.rglru_cache);
|
||||
}
|
||||
|
||||
CopyMat(kv_cache, copy.kv_cache);
|
||||
|
||||
return copy;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,29 +19,34 @@
|
|||
#include <stddef.h>
|
||||
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
struct KVCache {
|
||||
KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
||||
KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
|
||||
|
||||
// Returns a deep copy of the KVCache.
|
||||
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
||||
// Returns a deep copy of the KVCache. Use explicit function instead of
|
||||
// copy ctor to make the cost explicit.
|
||||
KVCache Copy();
|
||||
|
||||
size_t griffin_layers = 0;
|
||||
// griffin_layers, griffin_conv1d_cols * config.model_dim
|
||||
MatStorageT<float> conv1d_cache;
|
||||
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim
|
||||
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
|
||||
// and rglru_cache.
|
||||
void ZeroGriffinCache();
|
||||
|
||||
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
|
||||
size_t SeqLen() const { return kv_cache.Rows(); }
|
||||
|
||||
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
|
||||
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
|
||||
// [griffin_layers, griffin_conv1d_cols * model_dim]
|
||||
MatStorageT<float> conv1d_cache;
|
||||
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
|
||||
|
||||
MatStorageT<float> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
||||
|
||||
private:
|
||||
// For use by other ctor and Copy()
|
||||
KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
|
||||
const Extents2D& kv_extents);
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -256,7 +256,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
|
|||
MatMulEnv env(MakeMatMulEnv(threading));
|
||||
if (inference.verbosity >= 2) env.print_best = true;
|
||||
const Gemma gemma(loader, inference, env);
|
||||
KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
|
||||
KVCache kv_cache(gemma.GetModelConfig(), inference);
|
||||
|
||||
if (inference.verbosity >= 1) {
|
||||
std::string instructions =
|
||||
|
|
|
|||
14
gemma/vit.cc
14
gemma/vit.cc
|
|
@ -68,7 +68,8 @@ class VitAttention {
|
|||
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||
const size_t heads = layer_config_.heads;
|
||||
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
|
||||
const size_t seq_len = activations_.seq_len;
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(activations_.div_seq_len.GetDivisor());
|
||||
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
||||
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
||||
|
||||
|
|
@ -124,7 +125,8 @@ class VitAttention {
|
|||
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||
const size_t heads = layer_config_.heads;
|
||||
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
|
||||
const size_t seq_len = activations_.seq_len;
|
||||
const size_t seq_len =
|
||||
static_cast<size_t>(activations_.div_seq_len.GetDivisor());
|
||||
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
||||
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
||||
|
||||
|
|
@ -138,7 +140,7 @@ class VitAttention {
|
|||
activations_.q.Row(token) + head * 3 * qkv_dim;
|
||||
MulByConst(query_scale, q, qkv_dim);
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations_.att.Row(token) + head * activations_.seq_len;
|
||||
activations_.att.Row(token) + head * seq_len;
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
float* HWY_RESTRICT k =
|
||||
activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
|
||||
|
|
@ -275,7 +277,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
MatMulEnv& env) {
|
||||
const size_t model_dim = model_config.vit_config.model_dim;
|
||||
const size_t patch_width = model_config.vit_config.patch_width;
|
||||
const size_t seq_len = model_config.vit_config.seq_len;
|
||||
const size_t num_tokens = model_config.vit_config.seq_len;
|
||||
const size_t patch_size = patch_width * patch_width * 3;
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
|
||||
|
|
@ -285,9 +287,9 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
|
||||
// image_patches is (256, 14 * 14 * 3)
|
||||
// Must be padded, see `DoDecompressA`.
|
||||
MatStorageT<float> image_patches("patches", Extents2D(seq_len, patch_size),
|
||||
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
|
||||
MatPadding::kOdd);
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
image.GetPatch(i, image_patches.Row(i));
|
||||
}
|
||||
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ PYBIND11_MODULE(configs, py_module) {
|
|||
.def_readwrite("num_layers", &ModelConfig::num_layers)
|
||||
.def_readwrite("model_dim", &ModelConfig::model_dim)
|
||||
.def_readwrite("vocab_size", &ModelConfig::vocab_size)
|
||||
.def_readwrite("seq_len", &ModelConfig::seq_len)
|
||||
.def_readwrite("max_seq_len", &ModelConfig::max_seq_len)
|
||||
// Skip `unused_num_tensor_scales`.
|
||||
.def_readwrite("att_cap", &ModelConfig::att_cap)
|
||||
.def_readwrite("final_cap", &ModelConfig::final_cap)
|
||||
|
|
|
|||
Loading…
Reference in New Issue