mirror of https://github.com/google/gemma.cpp.git
MatPtr-ify KV, shared div_seq_len, --seq_len flag
PiperOrigin-RevId: 770194455
This commit is contained in:
parent
bd98b43cea
commit
c027a45a2e
|
|
@ -447,6 +447,7 @@ cc_library(
|
||||||
hdrs = ["gemma/kv_cache.h"],
|
hdrs = ["gemma/kv_cache.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":configs",
|
":configs",
|
||||||
|
":gemma_args",
|
||||||
":mat",
|
":mat",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -101,18 +101,6 @@ directly.
|
||||||
|
|
||||||
For other models, `gemma_export_main.py` is not yet open sourced.
|
For other models, `gemma_export_main.py` is not yet open sourced.
|
||||||
|
|
||||||
## Compile-Time Flags (Advanced)
|
|
||||||
|
|
||||||
There are several compile-time flags to be aware of (note these may or may not
|
|
||||||
be exposed to the build system):
|
|
||||||
|
|
||||||
- `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV
|
|
||||||
Cache. The default is 4096 tokens but can be overridden. This is not exposed
|
|
||||||
through `CMakeLists.txt` yet.
|
|
||||||
|
|
||||||
In the medium term this will likely be deprecated in favor of handling options
|
|
||||||
at runtime - dynamically resizing the KV cache as needed.
|
|
||||||
|
|
||||||
## Using gemma.cpp as a Library (Advanced)
|
## Using gemma.cpp as a Library (Advanced)
|
||||||
|
|
||||||
Unless you are doing lower level implementations or research, from an
|
Unless you are doing lower level implementations or research, from an
|
||||||
|
|
@ -165,7 +153,7 @@ constrained decoding type of use cases where you want to force the generation to
|
||||||
fit a grammar. If you're not doing this, you can send an empty lambda or
|
fit a grammar. If you're not doing this, you can send an empty lambda or
|
||||||
`std::function` as a no-op which is what `run.cc` does.
|
`std::function` as a no-op which is what `run.cc` does.
|
||||||
|
|
||||||
### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
|
### `Transformer()` implements inference (i.e. `forward()` in PyTorch or Jax)
|
||||||
|
|
||||||
For high-level applications, you might only call `model.Generate()` and never
|
For high-level applications, you might only call `model.Generate()` and never
|
||||||
interact directly with the neural network, but if you're doing something a bit
|
interact directly with the neural network, but if you're doing something a bit
|
||||||
|
|
|
||||||
|
|
@ -322,9 +322,10 @@ model (any model with a `-pt` suffix).
|
||||||
|
|
||||||
**What sequence lengths are supported?**
|
**What sequence lengths are supported?**
|
||||||
|
|
||||||
See `seq_len` in `configs.cc`. For the Gemma 3 models larger than 1B, this is
|
See `max_seq_len` in `configs.cc` and `InferenceArgs.seq_len`. For the Gemma 3
|
||||||
typically 32K but 128K would also work given enough RAM. Note that long
|
models larger than 1B, this is typically 32K but 128K would also work given
|
||||||
sequences will be slow due to the quadratic cost of attention.
|
enough RAM. Note that long sequences will be slow due to the quadratic cost of
|
||||||
|
attention.
|
||||||
|
|
||||||
**How do I convert my fine-tune to a `.sbs` compressed model file?**
|
**How do I convert my fine-tune to a `.sbs` compressed model file?**
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,7 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
|
||||||
|
|
||||||
int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
||||||
size_t batch_tokens) {
|
size_t batch_tokens) {
|
||||||
|
const Gemma& gemma = *env.GetGemma();
|
||||||
std::string input = ReadFileToString(text);
|
std::string input = ReadFileToString(text);
|
||||||
std::vector<int> prompt = env.Tokenize(input);
|
std::vector<int> prompt = env.Tokenize(input);
|
||||||
std::cout << "Number of input tokens: " << prompt.size() << "\n";
|
std::cout << "Number of input tokens: " << prompt.size() << "\n";
|
||||||
|
|
@ -73,8 +74,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
||||||
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
||||||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||||
prompt.begin() + pos + num_tokens);
|
prompt.begin() + pos + num_tokens);
|
||||||
KVCache kv_cache(env.GetGemma()->GetModelConfig(),
|
KVCache kv_cache(gemma.GetModelConfig(), gemma.Inference());
|
||||||
env.MutableConfig().prefill_tbatch_size);
|
|
||||||
float entropy = ComputeCrossEntropy(
|
float entropy = ComputeCrossEntropy(
|
||||||
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
|
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
|
||||||
total_entropy += entropy;
|
total_entropy += entropy;
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||||
: env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
|
: env_(MakeMatMulEnv(threading)), gemma_(loader, inference, env_) {
|
||||||
const ModelConfig& config = gemma_.GetModelConfig();
|
const ModelConfig& config = gemma_.GetModelConfig();
|
||||||
// Only allocate one for starters because GenerateBatch might not be called.
|
// Only allocate one for starters because GenerateBatch might not be called.
|
||||||
kv_caches_.push_back(KVCache(config, inference.prefill_tbatch_size));
|
kv_caches_.push_back(KVCache(config, inference));
|
||||||
|
|
||||||
if (inference.verbosity >= 2) {
|
if (inference.verbosity >= 2) {
|
||||||
ShowConfig(loader, threading, inference, config);
|
ShowConfig(loader, threading, inference, config);
|
||||||
|
|
@ -135,8 +135,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
|
|
||||||
// Ensure we have at least one KVCache per query.
|
// Ensure we have at least one KVCache per query.
|
||||||
while (kv_caches_.size() < num_queries) {
|
while (kv_caches_.size() < num_queries) {
|
||||||
kv_caches_.push_back(
|
kv_caches_.push_back(KVCache(gemma_.GetModelConfig(), gemma_.Inference()));
|
||||||
KVCache(gemma_.GetModelConfig(), runtime_config_.prefill_tbatch_size));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ int main(int argc, char** argv) {
|
||||||
// Instantiate model and KV Cache
|
// Instantiate model and KV Cache
|
||||||
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
|
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
|
||||||
gcpp::Gemma gemma(loader, inference, env);
|
gcpp::Gemma gemma(loader, inference, env);
|
||||||
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
|
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference);
|
||||||
size_t generated = 0;
|
size_t generated = 0;
|
||||||
|
|
||||||
// Initialize random number generator
|
// Initialize random number generator
|
||||||
|
|
|
||||||
|
|
@ -35,12 +35,9 @@ class SimplifiedGemma {
|
||||||
SimplifiedGemma(const gcpp::LoaderArgs& loader,
|
SimplifiedGemma(const gcpp::LoaderArgs& loader,
|
||||||
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
|
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
|
||||||
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
|
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
|
||||||
: loader_(loader),
|
: env_(MakeMatMulEnv(threading)),
|
||||||
threading_(threading),
|
gemma_(loader, inference, env_),
|
||||||
inference_(inference),
|
kv_cache_(gemma_.GetModelConfig(), inference) {
|
||||||
env_(MakeMatMulEnv(threading_)),
|
|
||||||
gemma_(loader_, inference_, env_),
|
|
||||||
kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
|
|
||||||
// Initialize random number generator
|
// Initialize random number generator
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
gen_.seed(rd());
|
gen_.seed(rd());
|
||||||
|
|
@ -91,9 +88,6 @@ class SimplifiedGemma {
|
||||||
~SimplifiedGemma() = default;
|
~SimplifiedGemma() = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
gcpp::LoaderArgs loader_;
|
|
||||||
gcpp::ThreadingArgs threading_;
|
|
||||||
gcpp::InferenceArgs inference_;
|
|
||||||
gcpp::MatMulEnv env_;
|
gcpp::MatMulEnv env_;
|
||||||
gcpp::Gemma gemma_;
|
gcpp::Gemma gemma_;
|
||||||
gcpp::KVCache kv_cache_;
|
gcpp::KVCache kv_cache_;
|
||||||
|
|
|
||||||
|
|
@ -46,8 +46,7 @@ struct Activations {
|
||||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
||||||
: weights_config(config),
|
: weights_config(config),
|
||||||
layer_config(config.layer_configs[0]),
|
layer_config(config.layer_configs[0]),
|
||||||
seq_len(config.seq_len),
|
div_seq_len(static_cast<uint32_t>(config.max_seq_len)),
|
||||||
cache_pos_size(config.CachePosSize()),
|
|
||||||
is_griffin(config.model == Model::GRIFFIN_2B),
|
is_griffin(config.model == Model::GRIFFIN_2B),
|
||||||
query_scale(ChooseQueryScale(config)),
|
query_scale(ChooseQueryScale(config)),
|
||||||
|
|
||||||
|
|
@ -64,7 +63,9 @@ struct Activations {
|
||||||
|
|
||||||
pre_att_rms_out("pre_att_rms_out",
|
pre_att_rms_out("pre_att_rms_out",
|
||||||
Extents2D(batch_size, config.model_dim), pad_),
|
Extents2D(batch_size, config.model_dim), pad_),
|
||||||
att("att", Extents2D(batch_size, layer_config.heads * config.seq_len),
|
att("att",
|
||||||
|
Extents2D(batch_size,
|
||||||
|
layer_config.heads * div_seq_len.GetDivisor()),
|
||||||
pad_),
|
pad_),
|
||||||
att_out(
|
att_out(
|
||||||
"att_out",
|
"att_out",
|
||||||
|
|
@ -141,10 +142,14 @@ struct Activations {
|
||||||
gen_tokens.resize(batch_size);
|
gen_tokens.resize(batch_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsGlobalLayer(size_t layer_idx) const {
|
||||||
|
return weights_config.attention_window_sizes[layer_idx] ==
|
||||||
|
div_seq_len.GetDivisor();
|
||||||
|
}
|
||||||
|
|
||||||
const ModelConfig& weights_config;
|
const ModelConfig& weights_config;
|
||||||
const LayerConfig& layer_config;
|
const LayerConfig& layer_config;
|
||||||
size_t seq_len;
|
hwy::Divisor div_seq_len;
|
||||||
size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
|
|
||||||
bool is_griffin;
|
bool is_griffin;
|
||||||
float query_scale;
|
float query_scale;
|
||||||
const Extents2D none_ = Extents2D();
|
const Extents2D none_ = Extents2D();
|
||||||
|
|
|
||||||
|
|
@ -70,9 +70,7 @@ static void PositionalEncodingQK(U* qk, const size_t qkv_dim,
|
||||||
const PostQKType& post_qk = layer.layer_config.post_qk;
|
const PostQKType& post_qk = layer.layer_config.post_qk;
|
||||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||||
const float* inv_timescale = activations.inv_timescale.PackedScale1();
|
const float* inv_timescale = activations.inv_timescale.PackedScale1();
|
||||||
bool is_global_layer =
|
bool is_global_layer = activations.IsGlobalLayer(layer_idx);
|
||||||
activations.weights_config.attention_window_sizes[layer_idx] ==
|
|
||||||
activations.seq_len;
|
|
||||||
// TODO: add a config flag instead of hardcoding the model.
|
// TODO: add a config flag instead of hardcoding the model.
|
||||||
if (is_global_layer && IsVLM(activations.weights_config.model)) {
|
if (is_global_layer && IsVLM(activations.weights_config.model)) {
|
||||||
inv_timescale = activations.inv_timescale_global.PackedScale1();
|
inv_timescale = activations.inv_timescale_global.PackedScale1();
|
||||||
|
|
@ -116,13 +114,15 @@ static HWY_INLINE void WeightedSumV(const size_t start_pos,
|
||||||
// Calculates the attention outputs for a single q.
|
// Calculates the attention outputs for a single q.
|
||||||
void SingleDotSoftmaxWeightedSum(
|
void SingleDotSoftmaxWeightedSum(
|
||||||
const size_t pos, const size_t start_pos, const size_t last_pos,
|
const size_t pos, const size_t start_pos, const size_t last_pos,
|
||||||
const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q,
|
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v,
|
||||||
const MatPtrT<float>& k, const MatPtrT<float>& v, const size_t layer_idx,
|
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||||
const LayerWeightsPtrs& layer, const Activations& activations,
|
const Activations& activations, float* HWY_RESTRICT att,
|
||||||
float* HWY_RESTRICT att, float* HWY_RESTRICT att_out) {
|
float* HWY_RESTRICT att_out) {
|
||||||
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
const size_t qkv_dim = layer.layer_config.qkv_dim;
|
||||||
const float att_cap = activations.weights_config.att_cap;
|
const float att_cap = activations.weights_config.att_cap;
|
||||||
const float query_scale = activations.query_scale;
|
const float query_scale = activations.query_scale;
|
||||||
|
const size_t seq_len =
|
||||||
|
static_cast<size_t>(activations.div_seq_len.GetDivisor());
|
||||||
|
|
||||||
// Apply rope and scaling to Q.
|
// Apply rope and scaling to Q.
|
||||||
if (layer.query_norm_scale.HasPtr()) {
|
if (layer.query_norm_scale.HasPtr()) {
|
||||||
|
|
@ -133,15 +133,14 @@ void SingleDotSoftmaxWeightedSum(
|
||||||
PositionalEncodingQK(q, qkv_dim, layer_idx, layer, activations, pos,
|
PositionalEncodingQK(q, qkv_dim, layer_idx, layer, activations, pos,
|
||||||
query_scale);
|
query_scale);
|
||||||
|
|
||||||
QDotK(start_pos, last_pos, div_seq_len, q, k, att);
|
QDotK(start_pos, last_pos, activations.div_seq_len, q, k, att);
|
||||||
|
|
||||||
// SoftMax with optional SoftCap yields "probabilities" in att.
|
// SoftMax with optional SoftCap yields "probabilities" in att.
|
||||||
const size_t att_len =
|
const size_t att_len = HWY_MIN(last_pos + 1, seq_len);
|
||||||
HWY_MIN(last_pos + 1, static_cast<size_t>(div_seq_len.GetDivisor()));
|
|
||||||
MaybeLogitsSoftCap(att_cap, att, att_len);
|
MaybeLogitsSoftCap(att_cap, att, att_len);
|
||||||
Softmax(att, att_len);
|
Softmax(att, att_len);
|
||||||
|
|
||||||
WeightedSumV(start_pos, last_pos, div_seq_len, att, v, att_out);
|
WeightedSumV(start_pos, last_pos, activations.div_seq_len, att, v, att_out);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The attention window usually starts at 0 unless `pos` is larger than
|
// The attention window usually starts at 0 unless `pos` is larger than
|
||||||
|
|
@ -152,11 +151,13 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config,
|
||||||
return pos - HWY_MIN(att_window_size - 1, pos);
|
return pos - HWY_MIN(att_window_size - 1, pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DotSoftmaxWeightedSum(
|
void DotSoftmaxWeightedSum(const size_t num_tokens,
|
||||||
const size_t num_tokens, const QueriesPos& queries_pos,
|
const QueriesPos& queries_pos,
|
||||||
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
|
const QueriesPos& queries_prefix_end,
|
||||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
const size_t layer_idx,
|
||||||
Activations& activations, const KVCaches& kv_caches, NestedPools& pools) {
|
const LayerWeightsPtrs& layer,
|
||||||
|
Activations& activations, const KVCaches& kv_caches,
|
||||||
|
NestedPools& pools) {
|
||||||
const size_t num_queries = queries_pos.size();
|
const size_t num_queries = queries_pos.size();
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
PROFILER_ZONE("Gen.Attention.DotSoftmax");
|
||||||
|
|
@ -166,7 +167,8 @@ void DotSoftmaxWeightedSum(
|
||||||
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
|
const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads;
|
||||||
|
|
||||||
const size_t cache_layer_size = layer_config.CacheLayerSize();
|
const size_t cache_layer_size = layer_config.CacheLayerSize();
|
||||||
const size_t cache_pos_size = activations.cache_pos_size;
|
const size_t seq_len =
|
||||||
|
static_cast<size_t>(activations.div_seq_len.GetDivisor());
|
||||||
|
|
||||||
// For each head (token, query), compute Q.K, softmax, and weighted V.
|
// For each head (token, query), compute Q.K, softmax, and weighted V.
|
||||||
// TODO: nested parallelism to use more threads.
|
// TODO: nested parallelism to use more threads.
|
||||||
|
|
@ -183,21 +185,19 @@ void DotSoftmaxWeightedSum(
|
||||||
float* HWY_RESTRICT q =
|
float* HWY_RESTRICT q =
|
||||||
activations.q.Row(interleaved_idx) + head * qkv_dim;
|
activations.q.Row(interleaved_idx) + head * qkv_dim;
|
||||||
float* HWY_RESTRICT att =
|
float* HWY_RESTRICT att =
|
||||||
activations.att.Row(interleaved_idx) + head * activations.seq_len;
|
activations.att.Row(interleaved_idx) + head * seq_len;
|
||||||
float* HWY_RESTRICT att_out =
|
float* HWY_RESTRICT att_out =
|
||||||
activations.att_out.Row(interleaved_idx) + head * qkv_dim;
|
activations.att_out.Row(interleaved_idx) + head * qkv_dim;
|
||||||
|
|
||||||
// Make strided views into the kv cache entries for the current
|
// Make strided views into the kv cache entries for the current
|
||||||
// query and head.
|
// query and head.
|
||||||
KVCache& kv_cache = kv_caches[query_idx];
|
auto& kv_cache = kv_caches[query_idx].kv_cache;
|
||||||
const size_t kv_head_offset =
|
const size_t kv_head_offset =
|
||||||
layer_idx * cache_layer_size + head_offset;
|
layer_idx * cache_layer_size + head_offset;
|
||||||
MatPtrT<float> k("k_view", Extents2D(kv_cache.seq_len, qkv_dim));
|
MatPtrT<float> k("k_view", Extents2D(seq_len, qkv_dim));
|
||||||
k.SetPtr(kv_cache.kv_cache.get() + kv_head_offset,
|
k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride());
|
||||||
/*stride=*/cache_pos_size);
|
MatPtrT<float> v("v_view", Extents2D(seq_len, qkv_dim));
|
||||||
MatPtrT<float> v("v_view", Extents2D(kv_cache.seq_len, qkv_dim));
|
v.SetPtr(kv_cache.Row(0) + kv_head_offset + qkv_dim, kv_cache.Stride());
|
||||||
v.SetPtr(kv_cache.kv_cache.get() + kv_head_offset + qkv_dim,
|
|
||||||
/*stride=*/cache_pos_size);
|
|
||||||
|
|
||||||
// Find the token position in the query and calculate the range
|
// Find the token position in the query and calculate the range
|
||||||
// of cache positions to attend to.
|
// of cache positions to attend to.
|
||||||
|
|
@ -211,16 +211,15 @@ void DotSoftmaxWeightedSum(
|
||||||
last_pos = prefix_end - 1;
|
last_pos = prefix_end - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, div_seq_len, q, k,
|
SingleDotSoftmaxWeightedSum(pos, start_pos, last_pos, q, k, v,
|
||||||
v, layer_idx, layer, activations, att,
|
layer_idx, layer, activations, att,
|
||||||
att_out);
|
att_out);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fills activations.q and writes to KV cache.
|
// Fills activations.q and writes to KV cache.
|
||||||
static HWY_INLINE void ComputeQKV(
|
static HWY_INLINE void ComputeQKV(
|
||||||
size_t num_tokens, const QueriesPos& queries_pos,
|
size_t num_tokens, const QueriesPos& queries_pos, const size_t layer_idx,
|
||||||
const hwy::Divisor& div_seq_len, const size_t layer_idx,
|
|
||||||
const LayerWeightsPtrs& layer, Activations& activations,
|
const LayerWeightsPtrs& layer, Activations& activations,
|
||||||
const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
|
const KVCaches& kv_caches, const int flags, MatMulEnv& env) {
|
||||||
PROFILER_ZONE("Gen.Attention.QKV");
|
PROFILER_ZONE("Gen.Attention.QKV");
|
||||||
|
|
@ -230,7 +229,6 @@ static HWY_INLINE void ComputeQKV(
|
||||||
const size_t qkv_dim = layer_config.qkv_dim;
|
const size_t qkv_dim = layer_config.qkv_dim;
|
||||||
const size_t kv_heads = layer_config.kv_heads;
|
const size_t kv_heads = layer_config.kv_heads;
|
||||||
const size_t cache_layer_size = layer_config.CacheLayerSize();
|
const size_t cache_layer_size = layer_config.CacheLayerSize();
|
||||||
const size_t cache_pos_size = activations.cache_pos_size;
|
|
||||||
|
|
||||||
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
|
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim,
|
||||||
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
|
// model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows.
|
||||||
|
|
@ -247,11 +245,10 @@ static HWY_INLINE void ComputeQKV(
|
||||||
const size_t query_idx = interleaved_idx % num_queries;
|
const size_t query_idx = interleaved_idx % num_queries;
|
||||||
const size_t batch_idx = interleaved_idx / num_queries;
|
const size_t batch_idx = interleaved_idx / num_queries;
|
||||||
const size_t cache_pos =
|
const size_t cache_pos =
|
||||||
div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
|
activations.div_seq_len.Remainder(queries_pos[query_idx] + batch_idx);
|
||||||
const size_t kv_offset =
|
|
||||||
cache_pos * cache_pos_size + layer_idx * cache_layer_size;
|
|
||||||
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
|
env.row_ptrs[0][interleaved_idx] = reinterpret_cast<uint8_t*>(
|
||||||
kv_caches[query_idx].kv_cache.get() + kv_offset);
|
kv_caches[query_idx].kv_cache.Row(cache_pos) +
|
||||||
|
layer_idx * cache_layer_size);
|
||||||
}
|
}
|
||||||
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
|
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
|
||||||
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
|
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
|
||||||
|
|
@ -267,12 +264,11 @@ static HWY_INLINE void ComputeQKV(
|
||||||
const size_t query_idx = interleaved_idx % num_queries;
|
const size_t query_idx = interleaved_idx % num_queries;
|
||||||
const size_t batch_idx = interleaved_idx / num_queries;
|
const size_t batch_idx = interleaved_idx / num_queries;
|
||||||
const size_t pos = queries_pos[query_idx] + batch_idx;
|
const size_t pos = queries_pos[query_idx] + batch_idx;
|
||||||
const size_t cache_pos = div_seq_len.Remainder(pos);
|
const size_t cache_pos = activations.div_seq_len.Remainder(pos);
|
||||||
const size_t kv_offset = cache_pos * cache_pos_size +
|
auto& kv_cache = kv_caches[query_idx].kv_cache;
|
||||||
|
float* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
|
||||||
layer_idx * cache_layer_size +
|
layer_idx * cache_layer_size +
|
||||||
head * qkv_dim * 2;
|
head * qkv_dim * 2;
|
||||||
KVCache& kv_cache = kv_caches[query_idx];
|
|
||||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
|
||||||
|
|
||||||
// Apply further processing to K.
|
// Apply further processing to K.
|
||||||
if (layer.key_norm_scale.HasPtr()) {
|
if (layer.key_norm_scale.HasPtr()) {
|
||||||
|
|
@ -309,9 +305,9 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
|
||||||
// causal attention, and must be non-null for prefix-LM style attention.
|
// causal attention, and must be non-null for prefix-LM style attention.
|
||||||
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
|
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
|
||||||
const QueriesPos* queries_prefix_end,
|
const QueriesPos* queries_prefix_end,
|
||||||
const hwy::Divisor& div_seq_len, const size_t layer_idx,
|
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
||||||
const LayerWeightsPtrs& layer, Activations& activations,
|
Activations& activations, const KVCaches& kv_caches,
|
||||||
const KVCaches& kv_caches, MatMulEnv& env, int flags) {
|
MatMulEnv& env, int flags) {
|
||||||
const size_t num_queries = queries_pos.size();
|
const size_t num_queries = queries_pos.size();
|
||||||
HWY_DASSERT(num_queries <= kv_caches.size());
|
HWY_DASSERT(num_queries <= kv_caches.size());
|
||||||
|
|
||||||
|
|
@ -330,11 +326,10 @@ void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos,
|
||||||
queries_prefix_end = &queries_prefix_end_span;
|
queries_prefix_end = &queries_prefix_end_span;
|
||||||
}
|
}
|
||||||
|
|
||||||
ComputeQKV(num_tokens, queries_pos, div_seq_len, layer_idx, layer,
|
ComputeQKV(num_tokens, queries_pos, layer_idx, layer, activations, kv_caches,
|
||||||
activations, kv_caches, flags, env);
|
flags, env);
|
||||||
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end,
|
DotSoftmaxWeightedSum(num_tokens, queries_pos, *queries_prefix_end, layer_idx,
|
||||||
div_seq_len, layer_idx, layer, activations, kv_caches,
|
layer, activations, kv_caches, env.ctx.pools);
|
||||||
env.ctx.pools);
|
|
||||||
SumHeads(layer, activations, env);
|
SumHeads(layer, activations, env);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -30,24 +30,23 @@ namespace gcpp {
|
||||||
namespace NAMESPACE { \
|
namespace NAMESPACE { \
|
||||||
void SingleDotSoftmaxWeightedSum( \
|
void SingleDotSoftmaxWeightedSum( \
|
||||||
const size_t pos, const size_t start_pos, const size_t last_pos, \
|
const size_t pos, const size_t start_pos, const size_t last_pos, \
|
||||||
const hwy::Divisor& div_seq_len, float* HWY_RESTRICT q, \
|
float* HWY_RESTRICT q, const MatPtrT<float>& k, const MatPtrT<float>& v, \
|
||||||
const MatPtrT<float>& k, const MatPtrT<float>& v, size_t layer_idx, \
|
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||||
const LayerWeightsPtrs& layer, const Activations& activations, \
|
const Activations& activations, float* HWY_RESTRICT att, \
|
||||||
float* HWY_RESTRICT att, float* HWY_RESTRICT att_out); \
|
float* HWY_RESTRICT att_out); \
|
||||||
\
|
\
|
||||||
void DotSoftmaxWeightedSum(const size_t num_tokens, \
|
void DotSoftmaxWeightedSum(const size_t num_tokens, \
|
||||||
const QueriesPos& queries_pos, \
|
const QueriesPos& queries_pos, \
|
||||||
const QueriesPos& queries_prefix_end, \
|
const QueriesPos& queries_prefix_end, \
|
||||||
const hwy::Divisor& div_seq_len, \
|
|
||||||
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||||
Activations& activations, \
|
Activations& activations, \
|
||||||
const KVCaches& kv_caches, NestedPools& pools); \
|
const KVCaches& kv_caches, NestedPools& pools); \
|
||||||
\
|
\
|
||||||
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, \
|
void GemmaAttention(size_t num_tokens, const QueriesPos& queries_pos, \
|
||||||
const QueriesPos* queries_prefix_end, \
|
const QueriesPos* queries_prefix_end, \
|
||||||
const hwy::Divisor& div_seq_len, const size_t layer_idx, \
|
const size_t layer_idx, const LayerWeightsPtrs& layer, \
|
||||||
const LayerWeightsPtrs& layer, Activations& activations, \
|
Activations& activations, const KVCaches& kv_caches, \
|
||||||
const KVCaches& kv_caches, MatMulEnv& env, int flags); \
|
MatMulEnv& env, int flags); \
|
||||||
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
|
||||||
} // namespace NAMESPACE
|
} // namespace NAMESPACE
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,21 +43,15 @@ namespace gcpp {
|
||||||
|
|
||||||
// ConversationData constructor implementation
|
// ConversationData constructor implementation
|
||||||
ConversationData::ConversationData(const ModelConfig& model_config,
|
ConversationData::ConversationData(const ModelConfig& model_config,
|
||||||
size_t prefill_tbatch_size)
|
const InferenceArgs& inference_args)
|
||||||
: model_config_ref_(model_config),
|
: kv_cache(std::make_unique<KVCache>(model_config, inference_args)),
|
||||||
prefill_tbatch_size_(prefill_tbatch_size),
|
|
||||||
kv_cache(std::make_unique<KVCache>(model_config, prefill_tbatch_size)),
|
|
||||||
abs_pos(0) {}
|
abs_pos(0) {}
|
||||||
|
|
||||||
// ConversationData copy constructor implementation
|
// ConversationData copy constructor implementation
|
||||||
ConversationData::ConversationData(const ConversationData& other)
|
ConversationData::ConversationData(const ConversationData& other)
|
||||||
: model_config_ref_(other.model_config_ref_),
|
: kv_cache(nullptr), abs_pos(other.abs_pos) {
|
||||||
prefill_tbatch_size_(other.prefill_tbatch_size_),
|
|
||||||
kv_cache(nullptr),
|
|
||||||
abs_pos(other.abs_pos) {
|
|
||||||
if (other.kv_cache) {
|
if (other.kv_cache) {
|
||||||
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy(
|
kv_cache = std::make_unique<KVCache>(other.kv_cache->Copy());
|
||||||
other.model_config_ref_, other.prefill_tbatch_size_));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -115,7 +109,7 @@ GemmaContext::GemmaContext(const LoaderArgs& loader,
|
||||||
LogDebug("Creating initial ConversationData");
|
LogDebug("Creating initial ConversationData");
|
||||||
// Create the initial ConversationData object using make_shared
|
// Create the initial ConversationData object using make_shared
|
||||||
active_conversation = std::make_shared<ConversationData>(
|
active_conversation = std::make_shared<ConversationData>(
|
||||||
model.GetModelConfig(), inference_args.prefill_tbatch_size);
|
model.GetModelConfig(), inference_args);
|
||||||
|
|
||||||
LogDebug(
|
LogDebug(
|
||||||
"Storing initial ConversationData in conversation_cache[\"default\"]");
|
"Storing initial ConversationData in conversation_cache[\"default\"]");
|
||||||
|
|
|
||||||
|
|
@ -31,26 +31,19 @@
|
||||||
|
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "gemma/gemma_args.h"
|
#include "gemma/gemma_args.h"
|
||||||
|
#include "gemma/kv_cache.h"
|
||||||
#include "ops/matmul.h" // MatMulEnv
|
#include "ops/matmul.h" // MatMulEnv
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// Forward declaration - use 'struct' to match definition tag
|
|
||||||
struct KVCache;
|
|
||||||
|
|
||||||
// Struct to hold data for a single conversation thread
|
// Struct to hold data for a single conversation thread
|
||||||
struct ConversationData {
|
struct ConversationData {
|
||||||
public:
|
ConversationData(const ModelConfig& model_config,
|
||||||
ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size);
|
const InferenceArgs& inference_args);
|
||||||
ConversationData(const ConversationData& other);
|
ConversationData(const ConversationData& other);
|
||||||
|
|
||||||
private:
|
|
||||||
const ModelConfig& model_config_ref_;
|
|
||||||
size_t prefill_tbatch_size_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
std::unique_ptr<KVCache> kv_cache;
|
std::unique_ptr<KVCache> kv_cache;
|
||||||
size_t abs_pos = 0;
|
size_t abs_pos = 0;
|
||||||
};
|
};
|
||||||
|
|
@ -142,8 +135,7 @@ class GemmaContext {
|
||||||
log_msg += "' to prewarmed_cache.";
|
log_msg += "' to prewarmed_cache.";
|
||||||
LogDebug(log_msg.c_str());
|
LogDebug(log_msg.c_str());
|
||||||
|
|
||||||
// Create a deep copy of the active_conversation.
|
// Create a deep copy of the active_conversation via copy ctor.
|
||||||
// The ConversationData copy constructor handles the deep copy of KVCache.
|
|
||||||
auto conversation_copy =
|
auto conversation_copy =
|
||||||
std::make_shared<ConversationData>(*active_conversation);
|
std::make_shared<ConversationData>(*active_conversation);
|
||||||
|
|
||||||
|
|
@ -176,8 +168,7 @@ class GemmaContext {
|
||||||
active_conversation->abs_pos = it->second->abs_pos;
|
active_conversation->abs_pos = it->second->abs_pos;
|
||||||
// Perform a deep copy of the KVCache from the prewarmed version.
|
// Perform a deep copy of the KVCache from the prewarmed version.
|
||||||
active_conversation->kv_cache =
|
active_conversation->kv_cache =
|
||||||
std::make_unique<KVCache>(it->second->kv_cache->Copy(
|
std::make_unique<KVCache>(it->second->kv_cache->Copy());
|
||||||
model.GetModelConfig(), inference_args.prefill_tbatch_size));
|
|
||||||
LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
|
LogDebug((log_prefix + "Successfully restored from prewarmed_cache.")
|
||||||
.c_str());
|
.c_str());
|
||||||
return;
|
return;
|
||||||
|
|
@ -187,8 +178,8 @@ class GemmaContext {
|
||||||
// rewind to initial state.
|
// rewind to initial state.
|
||||||
active_conversation->abs_pos = 0;
|
active_conversation->abs_pos = 0;
|
||||||
// Replace the cache within the current ConversationData object
|
// Replace the cache within the current ConversationData object
|
||||||
active_conversation->kv_cache = std::make_unique<KVCache>(
|
active_conversation->kv_cache =
|
||||||
model.GetModelConfig(), inference_args.prefill_tbatch_size);
|
std::make_unique<KVCache>(model.GetModelConfig(), inference_args);
|
||||||
|
|
||||||
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
|
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -206,7 +197,7 @@ class GemmaContext {
|
||||||
LogDebug("Creating new conversation");
|
LogDebug("Creating new conversation");
|
||||||
// Create a new ConversationData object using make_shared
|
// Create a new ConversationData object using make_shared
|
||||||
conversation_cache[name] = std::make_shared<ConversationData>(
|
conversation_cache[name] = std::make_shared<ConversationData>(
|
||||||
model.GetModelConfig(), inference_args.prefill_tbatch_size);
|
model.GetModelConfig(), inference_args);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,12 +27,8 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
|
||||||
#ifndef GEMMA_MAX_SEQLEN
|
|
||||||
#define GEMMA_MAX_SEQLEN 4096
|
|
||||||
#endif // !GEMMA_MAX_SEQLEN
|
|
||||||
|
|
||||||
static constexpr size_t kVocabSize = 256000;
|
static constexpr size_t kVocabSize = 256000;
|
||||||
|
static constexpr size_t kMaxSeqLen = 4096;
|
||||||
|
|
||||||
static ModelConfig ConfigNoSSM() {
|
static ModelConfig ConfigNoSSM() {
|
||||||
ModelConfig config;
|
ModelConfig config;
|
||||||
|
|
@ -69,7 +65,7 @@ static ModelConfig ConfigGemma2_27B() {
|
||||||
config.model = Model::GEMMA2_27B;
|
config.model = Model::GEMMA2_27B;
|
||||||
config.model_dim = 4608;
|
config.model_dim = 4608;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = 8192;
|
config.max_seq_len = 8192;
|
||||||
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
|
LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim);
|
||||||
config.num_layers = 46;
|
config.num_layers = 46;
|
||||||
config.layer_configs = {config.num_layers, layer_config};
|
config.layer_configs = {config.num_layers, layer_config};
|
||||||
|
|
@ -97,7 +93,7 @@ static ModelConfig ConfigGemma2_9B() {
|
||||||
config.model = Model::GEMMA2_9B;
|
config.model = Model::GEMMA2_9B;
|
||||||
config.model_dim = 3584;
|
config.model_dim = 3584;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = 8192;
|
config.max_seq_len = 8192;
|
||||||
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
|
LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim);
|
||||||
config.num_layers = 42;
|
config.num_layers = 42;
|
||||||
config.layer_configs = {config.num_layers, layer_config};
|
config.layer_configs = {config.num_layers, layer_config};
|
||||||
|
|
@ -125,7 +121,7 @@ static ModelConfig ConfigGemma2_2B() {
|
||||||
config.model = Model::GEMMA2_2B;
|
config.model = Model::GEMMA2_2B;
|
||||||
config.model_dim = 2304;
|
config.model_dim = 2304;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = 8192;
|
config.max_seq_len = 8192;
|
||||||
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
|
LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim);
|
||||||
config.num_layers = 26;
|
config.num_layers = 26;
|
||||||
config.layer_configs = {config.num_layers, layer_config};
|
config.layer_configs = {config.num_layers, layer_config};
|
||||||
|
|
@ -152,7 +148,7 @@ static ModelConfig ConfigGemmaTiny() {
|
||||||
config.wrapping = PromptWrapping::GEMMA_IT;
|
config.wrapping = PromptWrapping::GEMMA_IT;
|
||||||
config.model_dim = 32;
|
config.model_dim = 32;
|
||||||
config.vocab_size = 32; // at least two f32 vectors
|
config.vocab_size = 32; // at least two f32 vectors
|
||||||
config.seq_len = 32;
|
config.max_seq_len = 32;
|
||||||
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
|
LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim);
|
||||||
config.num_layers = 2;
|
config.num_layers = 2;
|
||||||
config.layer_configs = {config.num_layers, layer_config};
|
config.layer_configs = {config.num_layers, layer_config};
|
||||||
|
|
@ -188,11 +184,11 @@ static ModelConfig ConfigGriffin2B() {
|
||||||
ModelConfig config = ConfigNoSSM();
|
ModelConfig config = ConfigNoSSM();
|
||||||
config.display_name = "Griffin2B";
|
config.display_name = "Griffin2B";
|
||||||
config.model = Model::GRIFFIN_2B;
|
config.model = Model::GRIFFIN_2B;
|
||||||
// Griffin uses local attention, so GEMMA_MAX_SEQLEN is actually the local
|
// Griffin uses local attention, so max_seq_len is actually the local
|
||||||
// attention window.
|
// attention window.
|
||||||
config.model_dim = 2560;
|
config.model_dim = 2560;
|
||||||
config.vocab_size = kVocabSize;
|
config.vocab_size = kVocabSize;
|
||||||
config.seq_len = 2048;
|
config.max_seq_len = 2048;
|
||||||
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
|
LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim);
|
||||||
config.num_layers = 26;
|
config.num_layers = 26;
|
||||||
config.layer_configs = {config.num_layers, layer_config};
|
config.layer_configs = {config.num_layers, layer_config};
|
||||||
|
|
@ -200,7 +196,8 @@ static ModelConfig ConfigGriffin2B() {
|
||||||
config.layer_configs[i].type = LayerAttentionType::kGemma;
|
config.layer_configs[i].type = LayerAttentionType::kGemma;
|
||||||
config.layer_configs[i].griffin_dim = 0;
|
config.layer_configs[i].griffin_dim = 0;
|
||||||
}
|
}
|
||||||
config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.seq_len);
|
config.attention_window_sizes =
|
||||||
|
FixedAttentionWindowSizes<26>(config.max_seq_len);
|
||||||
config.use_local_attention = true;
|
config.use_local_attention = true;
|
||||||
config.final_cap = 0.0f;
|
config.final_cap = 0.0f;
|
||||||
return config;
|
return config;
|
||||||
|
|
@ -238,7 +235,7 @@ static void AddVitConfig(ModelConfig& config, size_t image_size = 224) {
|
||||||
ModelConfig GetVitConfig(const ModelConfig& config) {
|
ModelConfig GetVitConfig(const ModelConfig& config) {
|
||||||
ModelConfig vit_config = ConfigNoSSM();
|
ModelConfig vit_config = ConfigNoSSM();
|
||||||
vit_config.model_dim = config.vit_config.model_dim;
|
vit_config.model_dim = config.vit_config.model_dim;
|
||||||
vit_config.seq_len = config.vit_config.seq_len;
|
vit_config.max_seq_len = config.vit_config.seq_len;
|
||||||
vit_config.layer_configs = config.vit_config.layer_configs;
|
vit_config.layer_configs = config.vit_config.layer_configs;
|
||||||
vit_config.pool_dim = config.vit_config.pool_dim;
|
vit_config.pool_dim = config.vit_config.pool_dim;
|
||||||
vit_config.wrapping = config.wrapping;
|
vit_config.wrapping = config.wrapping;
|
||||||
|
|
@ -313,14 +310,14 @@ static ModelConfig ConfigGemma3_1B() {
|
||||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
config.model_dim = 1152;
|
config.model_dim = 1152;
|
||||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||||
config.seq_len = 32 * 1024;
|
config.max_seq_len = 32 * 1024;
|
||||||
LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
|
LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
|
||||||
config.num_layers = 26;
|
config.num_layers = 26;
|
||||||
config.layer_configs = {config.num_layers, layer_config};
|
config.layer_configs = {config.num_layers, layer_config};
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
// interleaved local / global attention
|
// interleaved local / global attention
|
||||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
|
config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
|
||||||
{512, 512, 512, 512, 512, config.seq_len});
|
{512, 512, 512, 512, 512, config.max_seq_len});
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -345,14 +342,14 @@ static ModelConfig ConfigGemma3_4B_LM() {
|
||||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
config.model_dim = 2560;
|
config.model_dim = 2560;
|
||||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||||
config.seq_len = 32 * 1024;
|
config.max_seq_len = 32 * 1024;
|
||||||
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
|
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
|
||||||
config.num_layers = 34;
|
config.num_layers = 34;
|
||||||
config.layer_configs = {config.num_layers, layer_config};
|
config.layer_configs = {config.num_layers, layer_config};
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
// interleaved local / global attention
|
// interleaved local / global attention
|
||||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
|
config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
|
||||||
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -394,14 +391,14 @@ static ModelConfig ConfigGemma3_12B_LM() {
|
||||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
config.model_dim = 3840;
|
config.model_dim = 3840;
|
||||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||||
config.seq_len = 32 * 1024;
|
config.max_seq_len = 32 * 1024;
|
||||||
LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
|
LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
|
||||||
config.num_layers = 48;
|
config.num_layers = 48;
|
||||||
config.layer_configs = {config.num_layers, layer_config};
|
config.layer_configs = {config.num_layers, layer_config};
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
// interleaved local / global attention
|
// interleaved local / global attention
|
||||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
|
config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
|
||||||
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -443,14 +440,14 @@ static ModelConfig ConfigGemma3_27B_LM() {
|
||||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
config.model_dim = 5376;
|
config.model_dim = 5376;
|
||||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||||
config.seq_len = 32 * 1024;
|
config.max_seq_len = 32 * 1024;
|
||||||
LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
|
LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
|
||||||
config.num_layers = 62;
|
config.num_layers = 62;
|
||||||
config.layer_configs = {config.num_layers, layer_config};
|
config.layer_configs = {config.num_layers, layer_config};
|
||||||
config.query_scale = QueryScaleType::SqrtKeySize;
|
config.query_scale = QueryScaleType::SqrtKeySize;
|
||||||
// interleaved local / global attention
|
// interleaved local / global attention
|
||||||
config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
|
config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
|
||||||
{1024, 1024, 1024, 1024, 1024, config.seq_len});
|
{1024, 1024, 1024, 1024, 1024, config.max_seq_len});
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -347,7 +347,7 @@ struct ModelConfig : public IFields {
|
||||||
visitor(num_layers);
|
visitor(num_layers);
|
||||||
visitor(model_dim);
|
visitor(model_dim);
|
||||||
visitor(vocab_size);
|
visitor(vocab_size);
|
||||||
visitor(seq_len);
|
visitor(max_seq_len);
|
||||||
|
|
||||||
visitor(unused_num_tensor_scales);
|
visitor(unused_num_tensor_scales);
|
||||||
|
|
||||||
|
|
@ -413,7 +413,7 @@ struct ModelConfig : public IFields {
|
||||||
return num_heads;
|
return num_heads;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t CachePosSize() const {
|
size_t KVCacheCols() const {
|
||||||
size_t num_layers = layer_configs.size();
|
size_t num_layers = layer_configs.size();
|
||||||
return num_layers * layer_configs[0].CacheLayerSize();
|
return num_layers * layer_configs[0].CacheLayerSize();
|
||||||
}
|
}
|
||||||
|
|
@ -435,7 +435,7 @@ struct ModelConfig : public IFields {
|
||||||
uint32_t num_layers = 0;
|
uint32_t num_layers = 0;
|
||||||
uint32_t model_dim = 0;
|
uint32_t model_dim = 0;
|
||||||
uint32_t vocab_size = 0;
|
uint32_t vocab_size = 0;
|
||||||
uint32_t seq_len = 0;
|
uint32_t max_seq_len = 0;
|
||||||
|
|
||||||
// We no longer set nor use this: config_converter is not able to set this,
|
// We no longer set nor use this: config_converter is not able to set this,
|
||||||
// and only pre-2025 format stores scales, and we do not require advance
|
// and only pre-2025 format stores scales, and we do not require advance
|
||||||
|
|
|
||||||
113
gemma/gemma.cc
113
gemma/gemma.cc
|
|
@ -64,13 +64,12 @@ namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
void Attention(LayerAttentionType type, size_t num_tokens,
|
void Attention(LayerAttentionType type, size_t num_tokens,
|
||||||
const QueriesPos& queries_pos,
|
const QueriesPos& queries_pos,
|
||||||
const QueriesPos& queries_prefix_end,
|
const QueriesPos& queries_prefix_end, const size_t layer_idx,
|
||||||
const hwy::Divisor& div_seq_len, const size_t layer_idx,
|
|
||||||
const LayerWeightsPtrs& layer, Activations& activations,
|
const LayerWeightsPtrs& layer, Activations& activations,
|
||||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
const KVCaches& kv_caches, MatMulEnv& env) {
|
||||||
if (type == LayerAttentionType::kGemma) {
|
if (type == LayerAttentionType::kGemma) {
|
||||||
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, div_seq_len,
|
GemmaAttention(num_tokens, queries_pos, &queries_prefix_end, layer_idx,
|
||||||
layer_idx, layer, activations, kv_caches, env,
|
layer, activations, kv_caches, env,
|
||||||
/*flags=*/0);
|
/*flags=*/0);
|
||||||
} else {
|
} else {
|
||||||
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
|
HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock);
|
||||||
|
|
@ -85,16 +84,16 @@ void Attention(LayerAttentionType type, size_t num_tokens,
|
||||||
|
|
||||||
static HWY_NOINLINE void TransformerLayer(
|
static HWY_NOINLINE void TransformerLayer(
|
||||||
const size_t num_tokens, const QueriesPos& queries_pos,
|
const size_t num_tokens, const QueriesPos& queries_pos,
|
||||||
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
|
const QueriesPos& queries_prefix_end, const size_t layer_idx,
|
||||||
const size_t layer_idx, const LayerWeightsPtrs& layer,
|
const LayerWeightsPtrs& layer, Activations& activations,
|
||||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
|
const KVCaches& kv_caches, MatMulEnv& env) {
|
||||||
const LayerConfig& layer_config = layer.layer_config;
|
const LayerConfig& layer_config = layer.layer_config;
|
||||||
|
|
||||||
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
|
RMSNormBatched(activations.x, layer.pre_attention_norm_scale,
|
||||||
activations.pre_att_rms_out);
|
activations.pre_att_rms_out);
|
||||||
|
|
||||||
Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
|
Attention(layer_config.type, num_tokens, queries_pos, queries_prefix_end,
|
||||||
div_seq_len, layer_idx, layer, activations, kv_caches, env);
|
layer_idx, layer, activations, kv_caches, env);
|
||||||
|
|
||||||
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
|
PostNorm(layer_config.post_norm, layer.post_attention_norm_scale,
|
||||||
activations.att_sums);
|
activations.att_sums);
|
||||||
|
|
@ -190,10 +189,9 @@ using QueriesMutablePos = hwy::Span<size_t>;
|
||||||
static HWY_NOINLINE void PrefillTBatch(
|
static HWY_NOINLINE void PrefillTBatch(
|
||||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||||
const hwy::Divisor& div_seq_len, const ModelConfig& config,
|
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
||||||
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
const ModelWeightsPtrs& weights, Activations& activations,
|
||||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
|
||||||
hwy::BitSet4096<>& non_eos) {
|
|
||||||
PROFILER_ZONE("Gen.PrefillT");
|
PROFILER_ZONE("Gen.PrefillT");
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t num_queries = queries_prompt.size();
|
||||||
HWY_DASSERT(num_queries == queries_pos.size());
|
HWY_DASSERT(num_queries == queries_pos.size());
|
||||||
|
|
@ -265,8 +263,8 @@ static HWY_NOINLINE void PrefillTBatch(
|
||||||
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
|
for (size_t layer_idx = 0; layer_idx < config.layer_configs.size();
|
||||||
++layer_idx) {
|
++layer_idx) {
|
||||||
TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
|
TransformerLayer(tbatch_size, single_query_pos, single_query_prefix_end,
|
||||||
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
|
layer_idx, *weights.GetLayer(layer_idx), activations,
|
||||||
activations, single_kv_cache, env);
|
single_kv_cache, env);
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: we unconditionally call StreamToken, even if EOS.
|
// NOTE: we unconditionally call StreamToken, even if EOS.
|
||||||
|
|
@ -303,10 +301,9 @@ static HWY_NOINLINE void PrefillTBatch(
|
||||||
// token-batched `PrefillTBatch`.
|
// token-batched `PrefillTBatch`.
|
||||||
static HWY_NOINLINE void Transformer(
|
static HWY_NOINLINE void Transformer(
|
||||||
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
const QueriesToken& queries_token, const QueriesMutablePos& queries_pos,
|
||||||
const QueriesPos& queries_prefix_end, const hwy::Divisor& div_seq_len,
|
const QueriesPos& queries_prefix_end, const ModelConfig& config,
|
||||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||||
const ModelWeightsPtrs& weights, Activations& activations,
|
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env) {
|
||||||
const KVCaches& kv_caches, MatMulEnv& env) {
|
|
||||||
const size_t num_queries = queries_token.size();
|
const size_t num_queries = queries_token.size();
|
||||||
HWY_DASSERT(num_queries == queries_pos.size());
|
HWY_DASSERT(num_queries == queries_pos.size());
|
||||||
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
HWY_DASSERT(num_queries == queries_prefix_end.size());
|
||||||
|
|
@ -326,8 +323,8 @@ static HWY_NOINLINE void Transformer(
|
||||||
|
|
||||||
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
for (size_t layer_idx = 0; layer_idx < weights.c_layers.size(); ++layer_idx) {
|
||||||
TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
|
TransformerLayer(/*num_tokens=*/1, queries_pos, queries_prefix_end,
|
||||||
div_seq_len, layer_idx, *weights.GetLayer(layer_idx),
|
layer_idx, *weights.GetLayer(layer_idx), activations,
|
||||||
activations, kv_caches, env);
|
kv_caches, env);
|
||||||
|
|
||||||
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
if (HWY_UNLIKELY(runtime_config.activations_observer)) {
|
||||||
runtime_config.activations_observer(queries_pos, layer_idx, activations);
|
runtime_config.activations_observer(queries_pos, layer_idx, activations);
|
||||||
|
|
@ -340,10 +337,10 @@ static HWY_NOINLINE void Transformer(
|
||||||
static HWY_NOINLINE void PrefillQBatch(
|
static HWY_NOINLINE void PrefillQBatch(
|
||||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||||
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
const QueriesMutablePos& queries_pos, const QueriesPos& queries_prefix_end,
|
||||||
const size_t max_prompt_size, const hwy::Divisor& div_seq_len,
|
const size_t max_prompt_size, const ModelConfig& config,
|
||||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||||
const ModelWeightsPtrs& weights, Activations& activations,
|
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
||||||
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos) {
|
hwy::BitSet4096<>& non_eos) {
|
||||||
PROFILER_ZONE("Gen.Prefill");
|
PROFILER_ZONE("Gen.Prefill");
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t num_queries = queries_prompt.size();
|
||||||
HWY_DASSERT(num_queries == queries_pos.size());
|
HWY_DASSERT(num_queries == queries_pos.size());
|
||||||
|
|
@ -380,8 +377,8 @@ static HWY_NOINLINE void PrefillQBatch(
|
||||||
// Do not call DecodeStepT because it computes logits for token
|
// Do not call DecodeStepT because it computes logits for token
|
||||||
// probabilities, which are not required for the prompt tokens.
|
// probabilities, which are not required for the prompt tokens.
|
||||||
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
||||||
queries_pos, queries_prefix_end, div_seq_len, config,
|
queries_pos, queries_prefix_end, config, runtime_config,
|
||||||
runtime_config, weights, activations, kv_caches, env);
|
weights, activations, kv_caches, env);
|
||||||
|
|
||||||
prefill_active.Foreach([&](size_t qi) {
|
prefill_active.Foreach([&](size_t qi) {
|
||||||
const int token = queries_prompt[qi][pos_in_prompt];
|
const int token = queries_prompt[qi][pos_in_prompt];
|
||||||
|
|
@ -393,19 +390,6 @@ static HWY_NOINLINE void PrefillQBatch(
|
||||||
} // pos_in_prompt
|
} // pos_in_prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: inline.
|
|
||||||
void RangeChecks(const ModelConfig& weights_config,
|
|
||||||
size_t& max_generated_tokens, const size_t prompt_size) {
|
|
||||||
if (!weights_config.use_local_attention) {
|
|
||||||
if (max_generated_tokens > weights_config.seq_len) {
|
|
||||||
HWY_WARN("max_generated_tokens %zu > kSeqLen %u, truncating.",
|
|
||||||
max_generated_tokens, weights_config.seq_len);
|
|
||||||
max_generated_tokens = weights_config.seq_len;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
HWY_ASSERT(prompt_size > 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
|
// Also writes the token to activations.gen_tokens for subsequent DecodeStepT,
|
||||||
// and updates `non_eos` if the query is at the end of its sequence.
|
// and updates `non_eos` if the query is at the end of its sequence.
|
||||||
static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
||||||
|
|
@ -432,17 +416,17 @@ static void StreamAndUpdateEOS(const size_t qi, const size_t pos, int token,
|
||||||
static void DecodeStepT(
|
static void DecodeStepT(
|
||||||
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
const size_t query_idx_start, const QueriesPromptTokens& queries_prompt,
|
||||||
const QueriesMutablePos& queries_mutable_pos,
|
const QueriesMutablePos& queries_mutable_pos,
|
||||||
const QueriesPos& queries_prefix_end, const hwy::Divisor div_seq_len,
|
const QueriesPos& queries_prefix_end, const ModelConfig& config,
|
||||||
const ModelConfig& config, const RuntimeConfig& runtime_config,
|
const RuntimeConfig& runtime_config, const ModelWeightsPtrs& weights,
|
||||||
const ModelWeightsPtrs& weights, const SampleFunc& sample_token,
|
const SampleFunc& sample_token, Activations& activations,
|
||||||
Activations& activations, const KVCaches& kv_caches, MatMulEnv& env,
|
const KVCaches& kv_caches, MatMulEnv& env, hwy::BitSet4096<>& non_eos,
|
||||||
hwy::BitSet4096<>& non_eos, TimingInfo& timing_info) {
|
TimingInfo& timing_info) {
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t num_queries = queries_prompt.size();
|
||||||
HWY_DASSERT(num_queries == activations.x.Rows());
|
HWY_DASSERT(num_queries == activations.x.Rows());
|
||||||
|
|
||||||
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
Transformer(QueriesToken(activations.gen_tokens.data(), num_queries),
|
||||||
queries_mutable_pos, queries_prefix_end, div_seq_len, config,
|
queries_mutable_pos, queries_prefix_end, config, runtime_config,
|
||||||
runtime_config, weights, activations, kv_caches, env);
|
weights, activations, kv_caches, env);
|
||||||
|
|
||||||
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
|
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
|
||||||
|
|
||||||
|
|
@ -530,6 +514,7 @@ static void GenerateT(
|
||||||
size_t max_prompt_size = 0;
|
size_t max_prompt_size = 0;
|
||||||
bool all_prefix_end_are_zero = true;
|
bool all_prefix_end_are_zero = true;
|
||||||
size_t prefill_tokens = 0;
|
size_t prefill_tokens = 0;
|
||||||
|
const size_t seq_len = kv_caches[0].SeqLen();
|
||||||
for (size_t qi = 0; qi < num_queries; ++qi) {
|
for (size_t qi = 0; qi < num_queries; ++qi) {
|
||||||
const PromptTokens& prompt = queries_prompt[qi];
|
const PromptTokens& prompt = queries_prompt[qi];
|
||||||
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
|
max_prompt_size = HWY_MAX(max_prompt_size, prompt.size());
|
||||||
|
|
@ -542,9 +527,12 @@ static void GenerateT(
|
||||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
|
HWY_ASSERT(prompt.size() != 0 && prompt[0] != config.eos_id);
|
||||||
|
|
||||||
all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
|
all_prefix_end_are_zero &= queries_prefix_end[qi] == 0;
|
||||||
}
|
|
||||||
|
|
||||||
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
|
// We use a single divisor, so all sequence lengths must be the same.
|
||||||
|
HWY_ASSERT(kv_caches[qi].SeqLen() == seq_len);
|
||||||
|
}
|
||||||
|
HWY_ASSERT(prefill_tokens < seq_len);
|
||||||
|
activations.div_seq_len = hwy::Divisor(static_cast<uint32_t>(seq_len));
|
||||||
|
|
||||||
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
// Lacks a constructor to bulk-set, hence initialized by Prefill* which have
|
||||||
// qi loops anyway.
|
// qi loops anyway.
|
||||||
|
|
@ -555,13 +543,12 @@ static void GenerateT(
|
||||||
if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
|
if ((num_queries > max_prompt_size) && all_prefix_end_are_zero) {
|
||||||
activations.SetBatchSize(num_queries); // required before PrefillQBatch
|
activations.SetBatchSize(num_queries); // required before PrefillQBatch
|
||||||
PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
PrefillQBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||||
queries_prefix_end, max_prompt_size, div_seq_len, config,
|
queries_prefix_end, max_prompt_size, config, runtime_config,
|
||||||
runtime_config, weights, activations, kv_caches, env,
|
weights, activations, kv_caches, env, non_eos);
|
||||||
non_eos);
|
|
||||||
} else {
|
} else {
|
||||||
PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
PrefillTBatch(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||||
queries_prefix_end, div_seq_len, config, runtime_config,
|
queries_prefix_end, config, runtime_config, weights,
|
||||||
weights, activations, kv_caches, env, non_eos);
|
activations, kv_caches, env, non_eos);
|
||||||
activations.SetBatchSize(num_queries); // Restore after PrefillTBatch.
|
activations.SetBatchSize(num_queries); // Restore after PrefillTBatch.
|
||||||
}
|
}
|
||||||
HWY_DASSERT(num_queries == non_eos.Count());
|
HWY_DASSERT(num_queries == non_eos.Count());
|
||||||
|
|
@ -579,7 +566,11 @@ static void GenerateT(
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
size_t max_gen_steps = runtime_config.max_generated_tokens;
|
||||||
RangeChecks(config, max_gen_steps, max_prompt_size);
|
if (prefill_tokens + max_gen_steps > seq_len) {
|
||||||
|
HWY_WARN("prefill %zu + max_gen_steps %zu > seq_len %zu, truncating.",
|
||||||
|
prefill_tokens, max_gen_steps, seq_len);
|
||||||
|
max_gen_steps = seq_len - prefill_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
|
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
|
||||||
|
|
||||||
|
|
@ -587,8 +578,8 @@ static void GenerateT(
|
||||||
timing_info.generate_start = hwy::platform::Now();
|
timing_info.generate_start = hwy::platform::Now();
|
||||||
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
for (size_t gen = 0; gen < max_gen_steps && non_eos.Any(); ++gen) {
|
||||||
DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
|
DecodeStepT(query_idx_start, queries_prompt, queries_mutable_pos,
|
||||||
queries_prefix_end, div_seq_len, config, runtime_config,
|
queries_prefix_end, config, runtime_config, weights,
|
||||||
weights, sample_token, activations, kv_caches, env, non_eos,
|
sample_token, activations, kv_caches, env, non_eos,
|
||||||
timing_info);
|
timing_info);
|
||||||
}
|
}
|
||||||
timing_info.NotifyGenerateDone();
|
timing_info.NotifyGenerateDone();
|
||||||
|
|
@ -661,10 +652,11 @@ void GenerateImageTokensT(const ModelConfig& config,
|
||||||
HWY_ABORT("Model does not support generating image tokens.");
|
HWY_ABORT("Model does not support generating image tokens.");
|
||||||
}
|
}
|
||||||
RuntimeConfig prefill_runtime_config = runtime_config;
|
RuntimeConfig prefill_runtime_config = runtime_config;
|
||||||
ModelConfig vit_config = GetVitConfig(config);
|
const ModelConfig vit_config = GetVitConfig(config);
|
||||||
|
const size_t num_tokens = vit_config.max_seq_len;
|
||||||
prefill_runtime_config.prefill_tbatch_size =
|
prefill_runtime_config.prefill_tbatch_size =
|
||||||
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
|
num_tokens / (vit_config.pool_dim * vit_config.pool_dim);
|
||||||
Activations prefill_activations(vit_config, vit_config.seq_len, env.row_ptrs);
|
Activations prefill_activations(vit_config, num_tokens, env.row_ptrs);
|
||||||
// Weights are for the full PaliGemma model, not just the ViT part.
|
// Weights are for the full PaliGemma model, not just the ViT part.
|
||||||
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
PrefillVit(config, weights, prefill_runtime_config, image, image_tokens,
|
||||||
prefill_activations, env);
|
prefill_activations, env);
|
||||||
|
|
@ -692,7 +684,8 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||||
reader_(loader.weights),
|
reader_(loader.weights),
|
||||||
model_(reader_, loader.tokenizer, loader.wrapping),
|
model_(reader_, loader.tokenizer, loader.wrapping),
|
||||||
weights_(model_.Config()),
|
weights_(model_.Config()),
|
||||||
chat_template_(model_.Tokenizer(), model_.Config().model) {
|
chat_template_(model_.Tokenizer(), model_.Config().model),
|
||||||
|
inference_(inference) {
|
||||||
weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
|
weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_,
|
||||||
env.ctx.pools.Pool());
|
env.ctx.pools.Pool());
|
||||||
reader_.CloseFile();
|
reader_.CloseFile();
|
||||||
|
|
|
||||||
|
|
@ -117,6 +117,7 @@ class Gemma {
|
||||||
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
|
const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
|
||||||
const ModelWeightsPtrs& Weights() const { return weights_; }
|
const ModelWeightsPtrs& Weights() const { return weights_; }
|
||||||
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
|
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
|
||||||
|
const InferenceArgs& Inference() const { return inference_; }
|
||||||
|
|
||||||
void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
|
void Save(const Path& weights_path, hwy::ThreadPool& pool) const;
|
||||||
|
|
||||||
|
|
@ -159,6 +160,7 @@ class Gemma {
|
||||||
std::vector<MatOwner> mat_owners_;
|
std::vector<MatOwner> mat_owners_;
|
||||||
ModelWeightsPtrs weights_;
|
ModelWeightsPtrs weights_;
|
||||||
GemmaChatTemplate chat_template_;
|
GemmaChatTemplate chat_template_;
|
||||||
|
InferenceArgs inference_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -35,11 +35,6 @@
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// Allow changing k parameter of `SampleTopK` as a compiler flag
|
|
||||||
#ifndef GEMMA_TOPK
|
|
||||||
#define GEMMA_TOPK 1
|
|
||||||
#endif // !GEMMA_TOPK
|
|
||||||
|
|
||||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||||
LoaderArgs(const std::string& tokenizer_path,
|
LoaderArgs(const std::string& tokenizer_path,
|
||||||
|
|
@ -115,6 +110,7 @@ using ActivationsObserverFunc =
|
||||||
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
std::function<void(const QueriesPos& queries_pos, int, const Activations&)>;
|
||||||
|
|
||||||
// RuntimeConfig holds configuration for a single generation run.
|
// RuntimeConfig holds configuration for a single generation run.
|
||||||
|
// TODO: move into InferenceArgs, use that directly.
|
||||||
struct RuntimeConfig {
|
struct RuntimeConfig {
|
||||||
// If not empty, batch_stream_token is called for each token in the batch,
|
// If not empty, batch_stream_token is called for each token in the batch,
|
||||||
// instead of stream_token.
|
// instead of stream_token.
|
||||||
|
|
@ -137,7 +133,7 @@ struct RuntimeConfig {
|
||||||
// Sampling-related parameters.
|
// Sampling-related parameters.
|
||||||
float temperature; // Temperature for sampling.
|
float temperature; // Temperature for sampling.
|
||||||
|
|
||||||
size_t top_k = GEMMA_TOPK; // Top-k for sampling.
|
size_t top_k = 1; // Top-k for sampling.
|
||||||
std::mt19937* gen; // Random number generator used for sampling.
|
std::mt19937* gen; // Random number generator used for sampling.
|
||||||
|
|
||||||
int verbosity; // Controls verbosity of printed messages.
|
int verbosity; // Controls verbosity of printed messages.
|
||||||
|
|
@ -170,6 +166,7 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
|
|
||||||
int verbosity;
|
int verbosity;
|
||||||
|
|
||||||
|
size_t seq_len;
|
||||||
size_t max_generated_tokens;
|
size_t max_generated_tokens;
|
||||||
|
|
||||||
size_t prefill_tbatch_size;
|
size_t prefill_tbatch_size;
|
||||||
|
|
@ -192,6 +189,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
"developer/debug info).\n Default = 1.",
|
"developer/debug info).\n Default = 1.",
|
||||||
1); // Changed verbosity level to 1 since it's user-facing
|
1); // Changed verbosity level to 1 since it's user-facing
|
||||||
|
|
||||||
|
visitor(seq_len, "seq_len", size_t{2048},
|
||||||
|
"Sequence length, capped by ModelConfig.max_seq_len.");
|
||||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||||
"Maximum number of tokens to generate.");
|
"Maximum number of tokens to generate.");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,21 +15,25 @@
|
||||||
|
|
||||||
#include "gemma/kv_cache.h"
|
#include "gemma/kv_cache.h"
|
||||||
|
|
||||||
#include <algorithm> // std::copy
|
#include <stddef.h>
|
||||||
|
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
|
#include "gemma/gemma_args.h"
|
||||||
#include "util/mat.h" // ZeroInit
|
#include "util/mat.h" // ZeroInit
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/base.h" // HWY_MAX
|
||||||
#include "hwy/base.h" // ZeroBytes
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
void KVCache::ZeroGriffinCache() {
|
void KVCache::ZeroGriffinCache() {
|
||||||
if (griffin_layers == 0) return;
|
if (conv1d_cache.Rows() == 0) return;
|
||||||
ZeroInit(conv1d_cache);
|
ZeroInit(conv1d_cache);
|
||||||
ZeroInit(rglru_cache);
|
ZeroInit(rglru_cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static size_t GriffinLayers(const ModelConfig& config) {
|
||||||
|
return config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock);
|
||||||
|
}
|
||||||
|
|
||||||
static size_t GriffinConv1dCols(const ModelConfig& config) {
|
static size_t GriffinConv1dCols(const ModelConfig& config) {
|
||||||
size_t conv1d_width = 0;
|
size_t conv1d_width = 0;
|
||||||
for (const auto& layer_config : config.layer_configs) {
|
for (const auto& layer_config : config.layer_configs) {
|
||||||
|
|
@ -40,43 +44,41 @@ static size_t GriffinConv1dCols(const ModelConfig& config) {
|
||||||
return conv1d_width * config.model_dim;
|
return conv1d_width * config.model_dim;
|
||||||
}
|
}
|
||||||
|
|
||||||
// prefill_tbatch_size is the maximum number of tokens from one query to
|
// Number of rows for KV cache. Note that both rows and cols are u32, and
|
||||||
// prefill at a time.
|
// the total number of elements can exceed 2^32.
|
||||||
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
|
static size_t CappedSeqLen(const ModelConfig& config,
|
||||||
: griffin_layers(
|
const InferenceArgs& inference_args) {
|
||||||
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
|
if (inference_args.seq_len > config.max_seq_len) {
|
||||||
conv1d_cache("conv1d_cache",
|
HWY_WARN("Capping seq_len %zu to config.max_seq_len %u.",
|
||||||
Extents2D(griffin_layers, GriffinConv1dCols(config)),
|
inference_args.seq_len, config.max_seq_len);
|
||||||
MatPadding::kOdd),
|
return config.max_seq_len;
|
||||||
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
|
|
||||||
MatPadding::kOdd) {
|
|
||||||
// TODO: move to MatStorageT.
|
|
||||||
const size_t size_cache_pos = config.CachePosSize();
|
|
||||||
if (size_cache_pos != 0) {
|
|
||||||
// Allocate more so that prefill can always access one batch, even if
|
|
||||||
// near the end of the sequence.
|
|
||||||
seq_len = config.seq_len + prefill_tbatch_size;
|
|
||||||
kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
|
|
||||||
}
|
}
|
||||||
|
return inference_args.seq_len;
|
||||||
}
|
}
|
||||||
|
|
||||||
KVCache KVCache::Copy(const ModelConfig& weights_config,
|
KVCache::KVCache(const Extents2D& conv1d_extents,
|
||||||
size_t prefill_tbatch_size) {
|
const Extents2D& rglru_extents, const Extents2D& kv_extents)
|
||||||
KVCache copy(weights_config, prefill_tbatch_size);
|
: conv1d_cache("conv1d_cache", conv1d_extents, MatPadding::kOdd),
|
||||||
|
rglru_cache("rglru_cache", rglru_extents, MatPadding::kOdd),
|
||||||
|
kv_cache("kv", kv_extents, MatPadding::kOdd) {}
|
||||||
|
|
||||||
const size_t size_cache_pos = weights_config.CachePosSize();
|
KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args)
|
||||||
if (size_cache_pos != 0) {
|
: KVCache(Extents2D(GriffinLayers(config), GriffinConv1dCols(config)),
|
||||||
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
|
Extents2D(GriffinLayers(config), config.model_dim),
|
||||||
copy.kv_cache.get());
|
Extents2D(CappedSeqLen(config, inference_args),
|
||||||
}
|
config.KVCacheCols())) {}
|
||||||
|
|
||||||
if (conv1d_cache.HasPtr()) {
|
KVCache KVCache::Copy() {
|
||||||
|
KVCache copy(conv1d_cache.Extents(), rglru_cache.Extents(),
|
||||||
|
kv_cache.Extents());
|
||||||
|
|
||||||
|
if (conv1d_cache.Rows() != 0) {
|
||||||
CopyMat(conv1d_cache, copy.conv1d_cache);
|
CopyMat(conv1d_cache, copy.conv1d_cache);
|
||||||
}
|
|
||||||
if (rglru_cache.HasPtr()) {
|
|
||||||
CopyMat(rglru_cache, copy.rglru_cache);
|
CopyMat(rglru_cache, copy.rglru_cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CopyMat(kv_cache, copy.kv_cache);
|
||||||
|
|
||||||
return copy;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,29 +19,34 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include "gemma/configs.h" // ModelConfig
|
#include "gemma/configs.h" // ModelConfig
|
||||||
|
#include "gemma/gemma_args.h"
|
||||||
#include "util/mat.h"
|
#include "util/mat.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
struct KVCache {
|
struct KVCache {
|
||||||
KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
KVCache(const ModelConfig& config, const InferenceArgs& inference_args);
|
||||||
|
|
||||||
// Returns a deep copy of the KVCache.
|
// Returns a deep copy of the KVCache. Use explicit function instead of
|
||||||
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
|
// copy ctor to make the cost explicit.
|
||||||
|
KVCache Copy();
|
||||||
|
|
||||||
size_t griffin_layers = 0;
|
|
||||||
// griffin_layers, griffin_conv1d_cols * config.model_dim
|
|
||||||
MatStorageT<float> conv1d_cache;
|
|
||||||
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim
|
|
||||||
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
|
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
|
||||||
// and rglru_cache.
|
// and rglru_cache.
|
||||||
void ZeroGriffinCache();
|
void ZeroGriffinCache();
|
||||||
|
|
||||||
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
|
size_t SeqLen() const { return kv_cache.Rows(); }
|
||||||
|
|
||||||
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
|
// [griffin_layers, griffin_conv1d_cols * model_dim]
|
||||||
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
|
MatStorageT<float> conv1d_cache;
|
||||||
|
MatStorageT<float> rglru_cache; // [griffin_layers, model_dim]
|
||||||
|
|
||||||
|
MatStorageT<float> kv_cache; // [seq_len, layers * kv_heads * qkv_dim * 2]
|
||||||
|
|
||||||
|
private:
|
||||||
|
// For use by other ctor and Copy()
|
||||||
|
KVCache(const Extents2D& conv1d_extents, const Extents2D& rglru_extents,
|
||||||
|
const Extents2D& kv_extents);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -256,7 +256,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
|
||||||
MatMulEnv env(MakeMatMulEnv(threading));
|
MatMulEnv env(MakeMatMulEnv(threading));
|
||||||
if (inference.verbosity >= 2) env.print_best = true;
|
if (inference.verbosity >= 2) env.print_best = true;
|
||||||
const Gemma gemma(loader, inference, env);
|
const Gemma gemma(loader, inference, env);
|
||||||
KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
|
KVCache kv_cache(gemma.GetModelConfig(), inference);
|
||||||
|
|
||||||
if (inference.verbosity >= 1) {
|
if (inference.verbosity >= 1) {
|
||||||
std::string instructions =
|
std::string instructions =
|
||||||
|
|
|
||||||
14
gemma/vit.cc
14
gemma/vit.cc
|
|
@ -68,7 +68,8 @@ class VitAttention {
|
||||||
const size_t qkv_dim = layer_config_.qkv_dim;
|
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||||
const size_t heads = layer_config_.heads;
|
const size_t heads = layer_config_.heads;
|
||||||
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
|
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
|
||||||
const size_t seq_len = activations_.seq_len;
|
const size_t seq_len =
|
||||||
|
static_cast<size_t>(activations_.div_seq_len.GetDivisor());
|
||||||
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
||||||
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
||||||
|
|
||||||
|
|
@ -124,7 +125,8 @@ class VitAttention {
|
||||||
const size_t qkv_dim = layer_config_.qkv_dim;
|
const size_t qkv_dim = layer_config_.qkv_dim;
|
||||||
const size_t heads = layer_config_.heads;
|
const size_t heads = layer_config_.heads;
|
||||||
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
|
HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA");
|
||||||
const size_t seq_len = activations_.seq_len;
|
const size_t seq_len =
|
||||||
|
static_cast<size_t>(activations_.div_seq_len.GetDivisor());
|
||||||
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
|
||||||
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
|
||||||
|
|
||||||
|
|
@ -138,7 +140,7 @@ class VitAttention {
|
||||||
activations_.q.Row(token) + head * 3 * qkv_dim;
|
activations_.q.Row(token) + head * 3 * qkv_dim;
|
||||||
MulByConst(query_scale, q, qkv_dim);
|
MulByConst(query_scale, q, qkv_dim);
|
||||||
float* HWY_RESTRICT head_att =
|
float* HWY_RESTRICT head_att =
|
||||||
activations_.att.Row(token) + head * activations_.seq_len;
|
activations_.att.Row(token) + head * seq_len;
|
||||||
for (size_t i = 0; i < seq_len; ++i) {
|
for (size_t i = 0; i < seq_len; ++i) {
|
||||||
float* HWY_RESTRICT k =
|
float* HWY_RESTRICT k =
|
||||||
activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
|
activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
|
||||||
|
|
@ -275,7 +277,7 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
||||||
MatMulEnv& env) {
|
MatMulEnv& env) {
|
||||||
const size_t model_dim = model_config.vit_config.model_dim;
|
const size_t model_dim = model_config.vit_config.model_dim;
|
||||||
const size_t patch_width = model_config.vit_config.patch_width;
|
const size_t patch_width = model_config.vit_config.patch_width;
|
||||||
const size_t seq_len = model_config.vit_config.seq_len;
|
const size_t num_tokens = model_config.vit_config.seq_len;
|
||||||
const size_t patch_size = patch_width * patch_width * 3;
|
const size_t patch_size = patch_width * patch_width * 3;
|
||||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
|
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
|
||||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
|
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
|
||||||
|
|
@ -285,9 +287,9 @@ static HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
||||||
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
|
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
|
||||||
// image_patches is (256, 14 * 14 * 3)
|
// image_patches is (256, 14 * 14 * 3)
|
||||||
// Must be padded, see `DoDecompressA`.
|
// Must be padded, see `DoDecompressA`.
|
||||||
MatStorageT<float> image_patches("patches", Extents2D(seq_len, patch_size),
|
MatStorageT<float> image_patches("patches", Extents2D(num_tokens, patch_size),
|
||||||
MatPadding::kOdd);
|
MatPadding::kOdd);
|
||||||
for (size_t i = 0; i < seq_len; ++i) {
|
for (size_t i = 0; i < num_tokens; ++i) {
|
||||||
image.GetPatch(i, image_patches.Row(i));
|
image.GetPatch(i, image_patches.Row(i));
|
||||||
}
|
}
|
||||||
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
|
CallMatMul(image_patches, weights.vit_img_embedding_kernel,
|
||||||
|
|
|
||||||
|
|
@ -161,7 +161,7 @@ PYBIND11_MODULE(configs, py_module) {
|
||||||
.def_readwrite("num_layers", &ModelConfig::num_layers)
|
.def_readwrite("num_layers", &ModelConfig::num_layers)
|
||||||
.def_readwrite("model_dim", &ModelConfig::model_dim)
|
.def_readwrite("model_dim", &ModelConfig::model_dim)
|
||||||
.def_readwrite("vocab_size", &ModelConfig::vocab_size)
|
.def_readwrite("vocab_size", &ModelConfig::vocab_size)
|
||||||
.def_readwrite("seq_len", &ModelConfig::seq_len)
|
.def_readwrite("max_seq_len", &ModelConfig::max_seq_len)
|
||||||
// Skip `unused_num_tensor_scales`.
|
// Skip `unused_num_tensor_scales`.
|
||||||
.def_readwrite("att_cap", &ModelConfig::att_cap)
|
.def_readwrite("att_cap", &ModelConfig::att_cap)
|
||||||
.def_readwrite("final_cap", &ModelConfig::final_cap)
|
.def_readwrite("final_cap", &ModelConfig::final_cap)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue