Internal change.

PiperOrigin-RevId: 736015810
This commit is contained in:
Phil Culliton 2025-03-11 23:19:36 -07:00 committed by Copybara-Service
parent 9d83ff202e
commit 4ab601da10
13 changed files with 487 additions and 63 deletions

View File

@ -189,6 +189,7 @@ constexpr bool IsNuqStream() {
enum class PromptWrapping {
GEMMA_IT,
GEMMA_PT,
GEMMA_VLM,
PALIGEMMA,
kSentinel // must be last
};

View File

@ -61,6 +61,7 @@ struct Activations {
// Rope
RowVectorBatch<float> inv_timescale;
RowVectorBatch<float> inv_timescale_global;
// Pointer rather than value: MatMulEnv has no default ctor and is only set in `Allocate`.
MatMulEnv* env;
@ -108,6 +109,8 @@ struct Activations {
inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
post_qk == PostQKType::HalfRope);
inv_timescale_global =
CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
this->env = env;
}
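For reference, a minimal sketch of the assumed inverse-timescale computation: the standard RoPE schedule, where the 1,000,000.0 base passed above is used for Gemma 3's global-attention layers while the sliding-window layers keep the default 10,000 base. The real `CreateInvTimescale` lives in the ops library and may differ in detail.

#include <cmath>
#include <cstddef>
#include <vector>

// Hedged sketch, not the library implementation:
// inv_timescale[i] = base^(-2 * i / rope_dim).
std::vector<float> InvTimescaleSketch(std::size_t qkv_dim, bool half_rope,
                                      double base = 10000.0) {
  const std::size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
  std::vector<float> inv(rope_dim / 2);
  for (std::size_t i = 0; i < inv.size(); ++i) {
    inv[i] = static_cast<float>(
        std::pow(base, -2.0 * static_cast<double>(i) / rope_dim));
  }
  return inv;
}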

View File

@ -44,6 +44,10 @@ constexpr const char* kModelFlags[] = {
"paligemma2-3b-448", // PaliGemma2 3B 448
"paligemma2-10b-224", // PaliGemma2 10B 224
"paligemma2-10b-448", // PaliGemma2 10B 448
"gemma3-4b", // Gemma3 4B
"gemma3-1b", // Gemma3 1B
"gemma3-12b", // Gemma3 12B
"gemma3-27b", // Gemma3 27B
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
@ -59,6 +63,10 @@ constexpr Model kModelTypes[] = {
Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
Model::GEMMA3_4B, // Gemma3 4B
Model::GEMMA3_1B, // Gemma3 1B
Model::GEMMA3_12B, // Gemma3 12B
Model::GEMMA3_27B, // Gemma3 27B
};
constexpr PromptWrapping kPromptWrapping[] = {
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
@ -71,6 +79,10 @@ constexpr PromptWrapping kPromptWrapping[] = {
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
PromptWrapping::GEMMA_VLM, // Gemma3 4B
PromptWrapping::GEMMA_PT, // Gemma3 1B
PromptWrapping::GEMMA_VLM, // Gemma3 12B
PromptWrapping::GEMMA_VLM, // Gemma3 27B
};
constexpr size_t kNumModelFlags = std::size(kModelFlags);

View File

@ -328,6 +328,186 @@ static ModelConfig ConfigPaliGemma2_10B_448() {
return config;
}
static ModelConfig ConfigBaseGemmaV3() {
ModelConfig config = ConfigNoSSM();
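// Gemma 3 drops the attention and final logit soft-caps that Gemma 2 used;
// the QK-norm enabled in the Gemma 3 layer configs plays a similar stabilizing role.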
config.att_cap = 0.0f;
config.final_cap = 0.0f;
return config;
}
// 1B does not include a vision encoder.
static LayerConfig LayerConfigGemma3_1B_LM(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 6912;
config.heads = 4;
config.kv_heads = 1;
config.qkv_dim = 256;
config.optimized_gating = true;
config.post_norm = PostNormType::Scale;
config.use_qk_norm = true;
return config;
}
static ModelConfig ConfigGemma3_1B() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_1B";
config.model = Model::GEMMA3_1B;
config.model_dim = 1152;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
config.layer_configs = {26, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
{512, 512, 512, 512, 512, config.seq_len});
return config;
}
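The 6-entry pattern above is assumed to be tiled cyclically across the 26 layers, i.e. five 512-token sliding-window layers followed by one global layer, repeating; the same scheme appears below with 1024-token windows for the larger models. A minimal sketch of that assumed expansion (the real `RepeatedAttentionWindowSizes` helper may return a different container type):

#include <array>
#include <cstddef>
#include <cstdint>

// Hedged sketch: tile a kPeriod-entry window pattern across kLayers layers.
template <std::size_t kLayers, std::size_t kPeriod>
std::array<uint32_t, kLayers> RepeatedWindowsSketch(
    const std::array<uint32_t, kPeriod>& pattern) {
  std::array<uint32_t, kLayers> out{};
  for (std::size_t layer = 0; layer < kLayers; ++layer) {
    out[layer] = pattern[layer % kPeriod];  // 512,512,512,512,512,seq_len,...
  }
  return out;
}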
static LayerConfig LayerConfigGemma3_4B_LM(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 8 * 2560 / 2; // = 10240
config.heads = 8;
config.kv_heads = 4;
config.qkv_dim = 256;
config.optimized_gating = true;
config.post_norm = PostNormType::Scale;
config.use_qk_norm = true;
return config;
}
// Until we have the SigLIP checkpoints included, we use the LM config directly.
static ModelConfig ConfigGemma3_4B_LM() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B;
config.model_dim = 2560;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
config.layer_configs = {34, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>(
{1024, 1024, 1024, 1024, 1024, config.seq_len});
return config;
}
static ModelConfig ConfigGemma3_4B() {
ModelConfig config = ConfigGemma3_4B_LM();
config.model_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144;
config.vit_config.pool_dim = 4;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
config.vit_config.seq_len = (num_patches * num_patches);
// The above resets optimized gating to false; for Gemma 3 it should be true.
for (auto& layer_config : config.layer_configs) {
layer_config.optimized_gating = true;
}
return config;
}
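Worked numbers for the vision path, assuming `AddVitConfig` keeps the usual 14-pixel SigLIP patch width (an assumption; only image_size = 896 and pool_dim = 4 are set explicitly here):

#include <cstddef>

constexpr std::size_t kImageSize = 896;
constexpr std::size_t kPatchWidth = 14;  // assumed default patch width
constexpr std::size_t kPatchesPerSide = kImageSize / kPatchWidth;         // 64
constexpr std::size_t kVitSeqLen = kPatchesPerSide * kPatchesPerSide;     // 4096
constexpr std::size_t kPoolDim = 4;
constexpr std::size_t kImageTokens = kVitSeqLen / (kPoolDim * kPoolDim);  // 256
static_assert(kImageTokens == 256, "matches AvgPool4x4's 4096 -> 256 rows");

The same seq_len / (pool_dim * pool_dim) division reappears in GenerateImageTokensT and in run.cc when sizing image_tokens.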
static LayerConfig LayerConfigGemma3_12B_LM(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 15360;
config.heads = 16;
config.kv_heads = 8;
config.qkv_dim = 256;
config.optimized_gating = true;
config.post_norm = PostNormType::Scale;
config.use_qk_norm = true;
return config;
}
static ModelConfig ConfigGemma3_12B_LM() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B;
config.model_dim = 3840;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
config.layer_configs = {48, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>(
{1024, 1024, 1024, 1024, 1024, config.seq_len});
return config;
}
static ModelConfig ConfigGemma3_12B() {
ModelConfig config = ConfigGemma3_12B_LM();
config.model_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144;
config.vit_config.pool_dim = 4;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
config.vit_config.seq_len = (num_patches * num_patches);
// The above resets optimized gating to false; for Gemma 3 it should be true.
for (auto& layer_config : config.layer_configs) {
layer_config.optimized_gating = true;
}
return config;
}
static LayerConfig LayerConfigGemma3_27B_LM(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 21504;
config.heads = 32;
config.kv_heads = 16;
config.qkv_dim = 128;
config.optimized_gating = true;
config.post_norm = PostNormType::Scale;
config.use_qk_norm = true;
return config;
}
static ModelConfig ConfigGemma3_27B_LM() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B;
config.model_dim = 5376;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
config.layer_configs = {62, layer_config};
config.num_tensor_scales = 4 * config.layer_configs.size();
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>(
{1024, 1024, 1024, 1024, 1024, config.seq_len});
return config;
}
static ModelConfig ConfigGemma3_27B() {
ModelConfig config = ConfigGemma3_27B_LM();
config.model_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144;
config.vit_config.pool_dim = 4;
const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width;
config.vit_config.seq_len = (num_patches * num_patches);
// The above resets optimized gating to false; for Gemma 3 it should be true.
for (auto& layer_config : config.layer_configs) {
layer_config.optimized_gating = true;
}
return config;
}
ModelConfig ConfigFromModel(Model model) {
switch (model) {
case Model::GEMMA_2B:
@ -356,6 +536,14 @@ ModelConfig ConfigFromModel(Model model) {
return ConfigPaliGemma2_10B_224();
case Model::PALIGEMMA2_10B_448:
return ConfigPaliGemma2_10B_448();
case Model::GEMMA3_4B:
return ConfigGemma3_4B();
case Model::GEMMA3_1B:
return ConfigGemma3_1B();
case Model::GEMMA3_12B:
return ConfigGemma3_12B();
case Model::GEMMA3_27B:
return ConfigGemma3_27B();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}

View File

@ -151,6 +151,10 @@ enum class Model {
PALIGEMMA2_3B_448,
PALIGEMMA2_10B_224,
PALIGEMMA2_10B_448,
GEMMA3_4B,
GEMMA3_1B,
GEMMA3_12B,
GEMMA3_27B,
};
// Allows the Model enum to be iterated over.
@ -159,7 +163,8 @@ static constexpr Model kAllModels[] = {
Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
Model::PALIGEMMA_224, Model::PALIGEMMA_448, Model::PALIGEMMA2_3B_224,
Model::PALIGEMMA2_3B_448, Model::PALIGEMMA2_10B_224,
Model::PALIGEMMA2_10B_448,
Model::PALIGEMMA2_10B_448, Model::GEMMA3_4B, Model::GEMMA3_1B,
Model::GEMMA3_12B, Model::GEMMA3_27B,
};
inline bool EnumValid(Model model) {
@ -202,6 +207,7 @@ struct LayerConfig : public IFields {
visitor(type);
visitor(activation);
visitor(post_qk);
visitor(use_qk_norm);
}
uint32_t model_dim = 0;
@ -218,6 +224,7 @@ struct LayerConfig : public IFields {
LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu;
PostQKType post_qk = PostQKType::Rope;
bool use_qk_norm = false;
};
// Dimensions related to image processing.

View File

@ -219,6 +219,17 @@ class GemmaAttention {
// qk is either q or k, so qkv_dim is the length we operate on.
const size_t qkv_dim = layer_config_.qkv_dim;
const float* inv_timescale = activations_.inv_timescale.Const();
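// Global-attention layers use the 1e6-base timescales created in
// Activations::Allocate as inv_timescale_global; sliding-window layers keep
// the default RoPE base.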
bool is_global_layer =
activations_.weights_config.attention_window_sizes[layer] ==
activations_.seq_len;
// TODO: add a config flag instead of hardcoding the model.
if (is_global_layer &&
(activations_.weights_config.model == Model::GEMMA3_4B ||
activations_.weights_config.model == Model::GEMMA3_12B ||
activations_.weights_config.model == Model::GEMMA3_27B ||
activations_.weights_config.model == Model::GEMMA3_1B)) {
inv_timescale = activations_.inv_timescale_global.Const();
}
// PostQKType::Rope
(void)layer;
if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) {
@ -324,6 +335,10 @@ class GemmaAttention {
}
// Apply further processing to K.
if (layer_weights_.key_norm_scale.data()) {
RMSNormInplace(layer_weights_.key_norm_scale.data(), kv,
qkv_dim);
}
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
});
}
@ -411,6 +426,10 @@ class GemmaAttention {
// Apply rope and scaling to Q.
const size_t pos = queries_pos_[query_idx] + batch_idx;
if (layer_weights_.query_norm_scale.data()) {
RMSNormInplace(layer_weights_.query_norm_scale.data(), q,
qkv_dim);
}
PositionalEncodingQK(q, pos, layer_, query_scale);
const size_t start_pos = StartPos(pos, layer_);
@ -590,6 +609,7 @@ class VitAttention {
RowPtrFromBatch(qkv));
}
// TODO(philculliton): transition fully to MatMul.
HWY_NOINLINE void DotSoftmaxWeightedSum() {
const size_t qkv_dim = layer_config_.qkv_dim;
const size_t heads = layer_config_.heads;
@ -598,34 +618,55 @@ class VitAttention {
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// Compute Q.K, softmax, and weighted V.
pool_.Run(0, layer_config_.heads * num_tokens_,
[&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t head = task % layer_config_.heads;
const size_t token = task / layer_config_.heads;
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.q.Batch(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att =
activations_.att.Batch(token) + head * activations_.seq_len;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT k =
activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(head_att, seq_len);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(token) + head * qkv_dim;
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.q.Batch(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
}
});
// Stage Q and K in dedicated RowVectorBatches (AllocateAlignedRows) so MatMul
// can compute the per-head score matrix C.
RowVectorBatch<float> Q =
AllocateAlignedRows<float>(Extents2D(num_tokens_, qkv_dim));
RowVectorBatch<float> K =
AllocateAlignedRows<float>(Extents2D(seq_len, qkv_dim));
RowVectorBatch<float> C(Extents2D(num_tokens_, seq_len));
// Initialize att_out to zero prior to head loop.
hwy::ZeroBytes(activations_.att_out.All(),
num_tokens_ * heads * qkv_dim * sizeof(float));
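// Per head: stage Q and K out of the interleaved QKV buffer, compute
// C = softmax(query_scale * Q * K^T) via MatMul, then accumulate C * V into
// att_out.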
for (size_t head = 0; head < heads; ++head) {
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t token = task;
float* HWY_RESTRICT q =
activations_.q.Batch(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Batch(token), qkv_dim * sizeof(float));
});
pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t seq_idx = task;
float* HWY_RESTRICT k =
activations_.q.Batch(seq_idx) + head * 3 * qkv_dim + qkv_dim;
hwy::CopyBytes(k, K.Batch(seq_idx), qkv_dim * sizeof(float));
});
// This produces C, a (num_tokens_, seq_len) matrix of Q.K dot products.
MatMul(ConstMatFromBatch(Q.BatchSize(), Q),
ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
RowPtrFromBatch(C));
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT c = C.Batch(task);
Softmax(c, C.Cols());
});
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
size_t token = task;
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(token) + head * qkv_dim;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v =
activations_.q.Batch(i) + head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Batch(token)[i], v, att_out, qkv_dim);
}
});
}
}
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
@ -784,14 +825,30 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
// `batch_idx` indicates which row of `x` to write to.
// `pos` is the *token*'s position, not the start of the batch, because this is
// called for batches of tokens in prefill, but batches of queries in decode.
//
// For GEMMA_VLM, image tokens are copied into the positions that hold the -2
// placeholder token (per the Gemma 3 spec) until the supplied image tokens are
// exhausted. This allows multi-image prompts, provided the calling application
// emits the -2 placeholders together with the appropriate begin/end image
// tokens.
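// Schematically, for a single image the wrapped prompt looks like:
//   [begin-image tokens][-2 x image_batch_size][end-image tokens][text tokens ...]
// and each -2 is replaced here by the next row of `image_tokens`.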
template <typename T>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt,
const ModelWeightsPtrs<T>& weights,
RowVectorBatch<float>& x,
const ImageTokens* image_tokens) {
HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt,
const ModelWeightsPtrs<T>& weights,
RowVectorBatch<float>& x,
const ImageTokens* image_tokens,
size_t& image_token_position) {
// Image tokens just need to be copied.
if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM &&
image_tokens != nullptr && token == -2 &&
image_token_position < image_tokens->BatchSize()) {
hwy::CopyBytes(image_tokens->Batch(image_token_position),
x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0]));
image_token_position++;
return;
}
if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA &&
image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
x.Cols() * sizeof(x.Const()[0]));
return;
@ -816,6 +873,21 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
}
}
// `batch_idx` indicates which row of `x` to write to.
// `pos` is the *token*'s position, not the start of the batch, because this is
// called for batches of tokens in prefill, but batches of queries in decode.
// This overload does not track the image token position across calls.
template <typename T>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt,
const ModelWeightsPtrs<T>& weights,
RowVectorBatch<float>& x,
const ImageTokens* image_tokens) {
size_t image_token_position = 0;
EmbedMMToken<T>(token, batch_idx, pos, pos_in_prompt, weights, x,
image_tokens, image_token_position);
}
template <typename Weights, typename T>
HWY_NOINLINE void ResidualConnection(
size_t num_interleaved, T* HWY_RESTRICT other, T* HWY_RESTRICT x,
@ -990,12 +1062,13 @@ HWY_NOINLINE void Prefill(
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
// Fill activations.x (much faster than TransformerLayer).
size_t image_token_position = 0;
for (size_t ti = 0; ti < tbatch_size; ++ti) {
const size_t pos = queries_pos[qi] + ti;
const size_t pos_in_prompt = tbatch_start + ti;
const int token = queries_prompt[qi][pos_in_prompt];
EmbedToken(token, ti, pos, pos_in_prompt, weights, activations.x,
runtime_config.image_tokens);
EmbedMMToken(token, ti, pos, pos_in_prompt, weights, activations.x,
runtime_config.image_tokens, image_token_position);
}
// Transformer with one batch of tokens from a single query.
@ -1104,8 +1177,14 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
weights.vit_encoder_norm_bias.data_scale1(),
activations.x.All(), vit_model_dim);
activations.x = AvgPool4x4(activations.x);
// Apply soft embedding norm before input projection.
RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
vit_model_dim);
// Apply the head projection into image_tokens, whose width is the LLM's kModelDim.
MatMul(ConstMatFromBatch(num_tokens, activations.x),
MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
ConstMatFromWeights(weights.vit_img_head_kernel),
weights.vit_img_head_bias.data_scale1(), *activations.env,
RowPtrFromBatch(image_tokens));
@ -1133,10 +1212,11 @@ HWY_NOINLINE void Transformer(
}
}
size_t image_token_position = 0;
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
EmbedToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
/*pos_in_prompt=*/0, weights, activations.x,
/*image_tokens=*/nullptr);
EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx],
/*pos_in_prompt=*/0, weights, activations.x,
/*image_tokens=*/nullptr, image_token_position);
}
for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) {
@ -1440,7 +1520,8 @@ void GenerateImageTokensT(const ModelWeightsStorage& model,
}
RuntimeConfig prefill_runtime_config = runtime_config;
ModelConfig vit_config = GetVitConfig(model.Config());
prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len;
prefill_runtime_config.prefill_tbatch_size =
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config);
prefill_activations.Allocate(vit_config.seq_len, env);
// Weights are for the full PaliGemma model, not just the ViT part.
@ -1487,4 +1568,4 @@ void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_

View File

@ -94,10 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
Image image;
ImageTokens image_tokens;
if (have_image) {
image_tokens =
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
image_tokens = ImageTokens(Extents2D(
model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim),
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
const size_t image_size = model.GetModelConfig().vit_config.image_size;
image.Resize(image_size, image_size);
@ -190,7 +192,15 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
size_t prefix_end = 0;
if (have_image) {
runtime_config.image_tokens = &image_tokens;
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
} else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
size_t seq_len = model.GetModelConfig().vit_config.seq_len;
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
prompt =
WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
}
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma.
// See Figure 2 of https://arxiv.org/abs/2407.07726.
@ -209,9 +219,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
// Prepare for the next turn.
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0.
InitGenerator(args, gen);
} else {
// The last token was either EOS, then it should be ignored because it is
// never part of the dialog, see Table 5 in the Gemma-2 paper:
// https://arxiv.org/pdf/2408.00118

View File

@ -64,14 +64,14 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
},
TensorInfo{
.name = "img_head_bias",
.source_names = {"img/head/bias"},
.source_names = {"img/head/bias", "embedder/mm_input_projection/b"},
.axes = {0},
.shape = {config.model_dim},
.min_size = Type::kF32,
},
TensorInfo{
.name = "img_head_kernel",
.source_names = {"img/head/kernel"},
.source_names = {"img/head/kernel", "embedder/mm_input_projection/w"},
.axes = {1, 0},
.shape = {config.model_dim, config.vit_config.model_dim},
.min_size = Type::kBF16,
@ -84,12 +84,21 @@ std::vector<TensorInfo> ModelTensors(const ModelConfig& config) {
config.vit_config.model_dim},
.min_size = Type::kF32,
},
// RMS norm applied to soft tokens prior to pos embedding.
TensorInfo{
.name = "mm_embed_norm",
.source_names = {"embedder/mm_soft_embedding_norm/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
};
}
// Returns the tensors for the given image layer config.
std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config) {
const LayerConfig& layer_config,
const int img_layer_idx) {
return {
// Vit layers.
TensorInfo{
@ -207,28 +216,40 @@ std::vector<TensorInfo> ImageLayerTensors(const ModelConfig& config,
},
TensorInfo{
.name = "ln_0_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"},
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_0/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_0_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"},
.source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_0/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_1_bias",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"},
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_1/bias"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "ln_1_scale",
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"},
.source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale",
"img/Transformer/encoderblock_" +
std::to_string(img_layer_idx) +
"/LayerNorm_1/scale"},
.axes = {0},
.shape = {config.vit_config.model_dim},
.min_size = Type::kBF16,
@ -241,6 +262,20 @@ std::vector<TensorInfo> LLMLayerTensors(const ModelConfig& config,
const LayerConfig& layer_config,
bool reshape_att) {
std::vector<TensorInfo> tensors = {
TensorInfo{
.name = "key_norm",
.source_names = {"attn/_key_norm/scale"},
.axes = {0},
.shape = {layer_config.qkv_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "query_norm",
.source_names = {"attn/_query_norm/scale"},
.axes = {0},
.shape = {layer_config.qkv_dim},
.min_size = Type::kBF16,
},
TensorInfo{
.name = "qkv1_w",
.source_names = {"attn/q_einsum/w"},
@ -529,7 +564,7 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
img_layer_idx < config.vit_config.layer_configs.size()) {
const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
tensors_ = ImageLayerTensors(config, layer_config);
tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
} else if (0 <= llm_layer_idx &&
llm_layer_idx < config.layer_configs.size()) {
const auto& layer_config = config.layer_configs[llm_layer_idx];
@ -569,4 +604,4 @@ const TensorInfo* TensorIndex::FindName(const std::string& name) const {
return &tensors_[it->second];
}
} // namespace gcpp

View File

@ -127,7 +127,9 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
}
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
if (info.wrapping == PromptWrapping::PALIGEMMA) {
if (info.wrapping == PromptWrapping::PALIGEMMA
// || info.wrapping == PromptWrapping::GEMMA_VLM
) {
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
@ -136,4 +138,33 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
return tokens;
}
std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
size_t pos, std::vector<int>& tokens,
size_t image_batch_size, size_t max_image_batch_size) {
HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
std::string begin_image_prompt = "\n\n<start_of_image>";
std::vector<int> begin_image_tokens =
WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
std::string end_image_prompt = "<end_of_image>\n\n";
std::vector<int> end_image_tokens =
WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
for (size_t i = 0; i < num_images; ++i) {
tokens.insert(tokens.begin(), begin_image_tokens.begin(),
begin_image_tokens.end());
tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
-2);
tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
end_image_tokens.begin(), end_image_tokens.end());
}
return tokens;
}
} // namespace gcpp

View File

@ -57,6 +57,10 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos,
std::string& prompt);
std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
size_t pos, std::vector<int>& tokens,
size_t image_batch_size, size_t max_image_batch_size);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_

View File

@ -99,6 +99,8 @@ struct LayerWeightsPtrs {
ffw_gating_biases("ffw_gat_b", tensor_index),
ffw_output_biases("ffw_out_b", tensor_index),
att_weights("att_w", tensor_index),
key_norm_scale("key_norm", tensor_index),
query_norm_scale("query_norm", tensor_index),
layer_config(config) {}
~LayerWeightsPtrs() = default;
@ -204,6 +206,9 @@ struct LayerWeightsPtrs {
att_weights.set_scale(attn_vec_einsum_w.scale());
}
ArrayT<WeightF32OrBF16> key_norm_scale;
ArrayT<WeightF32OrBF16> query_norm_scale;
// Used by ForEachTensor for per-layer tensors.
#define GEMMA_CALL_FUNC(member) \
{ \
@ -282,6 +287,10 @@ struct LayerWeightsPtrs {
GEMMA_CALL_FUNC(post_attention_norm_scale);
GEMMA_CALL_FUNC(post_ffw_norm_scale);
}
if (ptrs[0]->layer_config.use_qk_norm) {
GEMMA_CALL_FUNC(key_norm_scale);
GEMMA_CALL_FUNC(query_norm_scale);
}
if (ptrs[0]->layer_config.ff_biases) {
GEMMA_CALL_FUNC(ffw_gating_biases);
@ -332,6 +341,7 @@ struct ModelWeightsPtrs {
vit_img_pos_embedding("img_pos_emb", tensor_index),
vit_img_head_bias("img_head_bias", tensor_index),
vit_img_head_kernel("img_head_kernel", tensor_index),
mm_embed_norm("mm_embed_norm", tensor_index),
scale_names(config.scale_names),
weights_config(config) {
c_layers.reserve(config.layer_configs.size());
@ -372,6 +382,8 @@ struct ModelWeightsPtrs {
MatPtrT<float> vit_img_head_bias;
MatPtrT<WeightF32OrBF16> vit_img_head_kernel;
MatPtrT<WeightF32OrBF16> mm_embed_norm;
std::unordered_set<std::string> scale_names;
const ModelConfig& weights_config;
@ -488,6 +500,7 @@ struct ModelWeightsPtrs {
GEMMA_CALL_FUNC(vit_img_pos_embedding);
GEMMA_CALL_FUNC(vit_img_head_bias);
GEMMA_CALL_FUNC(vit_img_head_kernel);
GEMMA_CALL_FUNC(mm_embed_norm);
}
for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) {
@ -605,4 +618,4 @@ class ModelWeightsStorage {
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_

View File

@ -202,7 +202,7 @@ class MMStorage {
// Compile-time bounds on matrix dimensions to enable pre-allocating storage
// and reusing it across `MatMul` calls. The resulting allocations are 256 MiB
// per package and 512 MiB, respectively.
static constexpr size_t kMaxM = 2048;
static constexpr size_t kMaxM = 4096;
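// (Presumably raised from 2048 so that the 64 * 64 = 4096 ViT soft tokens of
// Gemma 3's 896-pixel image encoder fit into a single MatMul call.)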
static constexpr size_t kMaxK = 64 * 1024;
static constexpr size_t kMaxN = 256 * 1024;
// Upper bound for per-worker B storage on the stack. Chosen such that one row

View File

@ -802,6 +802,48 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
.prob = topk_logits[topk_sampled_index]};
}
// Performs 4x4 average pooling across row vectors.
// The input has 4096 (64*64) rows; the output has 256 (16*16) rows.
// Each output row is the average of a 4x4 block of input rows.
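// Together with vit_config.pool_dim == 4, this reduces the ViT's 64x64
// soft-token grid to the 16x16 grid of image tokens handed to the LLM.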
template <typename T>
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
Extents2D extents = input.Extents();
// Input validation
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows
// Create output with 256 rows and same number of columns
const size_t out_rows = 256; // 16 * 16 = 256 output rows
RowVectorBatch<T> result(Extents2D{out_rows, extents.cols});
const size_t input_dim = 64; // Input is 64×64
const size_t output_dim = 16; // Output is 16×16
for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {
for (size_t out_col_idx = 0; out_col_idx < output_dim; ++out_col_idx) {
size_t out_idx = out_row_idx * output_dim + out_col_idx;
T* output_row = result.Batch(out_idx);
// Initialize output row to zeros
std::fill(output_row, output_row + extents.cols, 0);
// Average 16 row vectors from a 4x4 block
for (size_t i = 0; i < 4; ++i) {
for (size_t j = 0; j < 4; ++j) {
size_t in_row_idx = out_row_idx * 4 + i;
size_t in_col_idx = out_col_idx * 4 + j;
size_t in_idx = in_row_idx * input_dim + in_col_idx;
const T* input_row = input.Batch(in_idx);
// Add each input row to the output
// TODO(philculliton): use AddFrom in ops-inl for a vectorized loop.
for (size_t col = 0; col < extents.cols; ++col) {
output_row[col] += input_row[col];
}
}
}
// Divide by 16 to get the average
for (size_t col = 0; col < extents.cols; ++col) {
output_row[col] *= T{0.0625};
}
}
}
return result;
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp