From 4ab601da1089ea74eb9e84e22657a240d8182be6 Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Tue, 11 Mar 2025 23:19:36 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 736015810 --- compression/shared.h | 1 + gemma/activations.h | 3 + gemma/common.cc | 12 +++ gemma/configs.cc | 188 ++++++++++++++++++++++++++++++++++++++++++ gemma/configs.h | 9 +- gemma/gemma-inl.h | 165 ++++++++++++++++++++++++++---------- gemma/run.cc | 23 ++++-- gemma/tensor_index.cc | 53 ++++++++++-- gemma/tokenizer.cc | 33 +++++++- gemma/tokenizer.h | 4 + gemma/weights.h | 15 +++- ops/matmul.h | 2 +- ops/ops-inl.h | 42 ++++++++++ 13 files changed, 487 insertions(+), 63 deletions(-) diff --git a/compression/shared.h b/compression/shared.h index 29b07cc..a5c87ae 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -189,6 +189,7 @@ constexpr bool IsNuqStream() { enum class PromptWrapping { GEMMA_IT, GEMMA_PT, + GEMMA_VLM, PALIGEMMA, kSentinel // must be last }; diff --git a/gemma/activations.h b/gemma/activations.h index f08470c..86345e2 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -61,6 +61,7 @@ struct Activations { // Rope RowVectorBatch inv_timescale; + RowVectorBatch inv_timescale_global; // Dynamic because no default ctor and only initialized in `Allocate`. MatMulEnv* env; @@ -108,6 +109,8 @@ struct Activations { inv_timescale = CreateInvTimescale(layer_config.qkv_dim, post_qk == PostQKType::HalfRope); + inv_timescale_global = + CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0); this->env = env; } diff --git a/gemma/common.cc b/gemma/common.cc index dc37e3e..0d8977b 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -44,6 +44,10 @@ constexpr const char* kModelFlags[] = { "paligemma2-3b-448", // PaliGemma2 3B 448 "paligemma2-10b-224", // PaliGemma2 10B 224 "paligemma2-10b-448", // PaliGemma2 10B 448 + "gemma3-4b", // Gemma3 4B + "gemma3-1b", // Gemma3 1B + "gemma3-12b", // Gemma3 12B + "gemma3-27b", // Gemma3 27B }; constexpr Model kModelTypes[] = { Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B @@ -59,6 +63,10 @@ constexpr Model kModelTypes[] = { Model::PALIGEMMA2_3B_448, // PaliGemma2 3B 448 Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224 Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448 + Model::GEMMA3_4B, // Gemma3 4B + Model::GEMMA3_1B, // Gemma3 1B + Model::GEMMA3_12B, // Gemma3 12B + Model::GEMMA3_27B, // Gemma3 27B }; constexpr PromptWrapping kPromptWrapping[] = { PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B @@ -71,6 +79,10 @@ constexpr PromptWrapping kPromptWrapping[] = { PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448 PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448 PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448 + PromptWrapping::GEMMA_VLM, // Gemma3 4B + PromptWrapping::GEMMA_PT, // Gemma3 1B + PromptWrapping::GEMMA_VLM, // Gemma3 12B + PromptWrapping::GEMMA_VLM, // Gemma3 27B }; constexpr size_t kNumModelFlags = std::size(kModelFlags); diff --git a/gemma/configs.cc b/gemma/configs.cc index 9372ee1..d058faf 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -328,6 +328,186 @@ static ModelConfig ConfigPaliGemma2_10B_448() { return config; } +static ModelConfig ConfigBaseGemmaV3() { + ModelConfig config = ConfigNoSSM(); + config.att_cap = 0.0f; + config.final_cap = 0.0f; + return config; +} + +// 1B does not include a vision encoder. 
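+// All four Gemma 3 text configs below share the same layer shape: grouped-
+// query attention with `heads` query heads over `kv_heads` KV heads of width
+// `qkv_dim`, plus post-norm and QK-norm. A rough sketch of the derived sizes
+// (helper names here are illustrative only, not part of this change):
+//
+//   size_t QueriesPerKVHead(const LayerConfig& lc) {  // 1B: 4 / 1 = 4
+//     return lc.heads / lc.kv_heads;
+//   }
+//   size_t FusedQKVCols(const LayerConfig& lc) {  // output width of the
+//     return (lc.heads + 2 * lc.kv_heads) * lc.qkv_dim;  // QKV projection
+//   }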
+static LayerConfig LayerConfigGemma3_1B_LM(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 6912; + config.heads = 4; + config.kv_heads = 1; + config.qkv_dim = 256; + config.optimized_gating = true; + config.post_norm = PostNormType::Scale; + config.use_qk_norm = true; + return config; +} + +static ModelConfig ConfigGemma3_1B() { + ModelConfig config = ConfigBaseGemmaV3(); + config.model_name = "Gemma3_1B"; + config.model = Model::GEMMA3_1B; + config.model_dim = 1152; + config.vocab_size = 262144; // new vocab size / tokenizer + config.seq_len = 32 * 1024; + LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim); + config.layer_configs = {26, layer_config}; + config.num_tensor_scales = 4 * config.layer_configs.size(); + config.query_scale = QueryScaleType::SqrtKeySize; + // interleaved local / global attention + config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>( + {512, 512, 512, 512, 512, config.seq_len}); + return config; +} + +static LayerConfig LayerConfigGemma3_4B_LM(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 8 * 2560 / 2; // = 10240 + config.heads = 8; + config.kv_heads = 4; + config.qkv_dim = 256; + config.optimized_gating = true; + config.post_norm = PostNormType::Scale; + config.use_qk_norm = true; + return config; +} + +// Until we have the SigLIP checkpoints included, we use the LM config directly. +static ModelConfig ConfigGemma3_4B_LM() { + ModelConfig config = ConfigBaseGemmaV3(); + config.model_name = "Gemma3_4B"; + config.model = Model::GEMMA3_4B; + config.model_dim = 2560; + config.vocab_size = 262144; // new vocab size / tokenizer + config.seq_len = 32 * 1024; + LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim); + config.layer_configs = {34, layer_config}; + config.num_tensor_scales = 4 * config.layer_configs.size(); + config.query_scale = QueryScaleType::SqrtKeySize; + // interleaved local / global attention + config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>( + {1024, 1024, 1024, 1024, 1024, config.seq_len}); + return config; +} + +static ModelConfig ConfigGemma3_4B() { + ModelConfig config = ConfigGemma3_4B_LM(); + config.model_name = "Gemma3_4B"; + config.model = Model::GEMMA3_4B; + AddVitConfig(config, /*image_size=*/896); + config.vocab_size = 262144; + config.vit_config.pool_dim = 4; + const size_t num_patches = + config.vit_config.image_size / config.vit_config.patch_width; + config.vit_config.seq_len = (num_patches * num_patches); + // The above resets optimized gating to false; for Gemma 3 it should be true. 
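+  // Soft-token arithmetic implied by the ViT settings above, assuming the
+  // AddVitConfig default patch_width of 14:
+  //   num_patches = 896 / 14 = 64 per side, so ViT seq_len = 64 * 64 = 4096,
+  //   and pool_dim = 4 later reduces that to 4096 / (4 * 4) = 256 soft tokens.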
+ for (auto& layer_config : config.layer_configs) { + layer_config.optimized_gating = true; + } + return config; +} + +static LayerConfig LayerConfigGemma3_12B_LM(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 15360; + config.heads = 16; + config.kv_heads = 8; + config.qkv_dim = 256; + config.optimized_gating = true; + config.post_norm = PostNormType::Scale; + config.use_qk_norm = true; + return config; +} + +static ModelConfig ConfigGemma3_12B_LM() { + ModelConfig config = ConfigBaseGemmaV3(); + config.model_name = "Gemma3_12B"; + config.model = Model::GEMMA3_12B; + config.model_dim = 3840; + config.vocab_size = 262144; // new vocab size / tokenizer + config.seq_len = 32 * 1024; + LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim); + config.layer_configs = {48, layer_config}; + config.num_tensor_scales = 4 * config.layer_configs.size(); + config.query_scale = QueryScaleType::SqrtKeySize; + // interleaved local / global attention + config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>( + {1024, 1024, 1024, 1024, 1024, config.seq_len}); + return config; +} + +static ModelConfig ConfigGemma3_12B() { + ModelConfig config = ConfigGemma3_12B_LM(); + config.model_name = "Gemma3_12B"; + config.model = Model::GEMMA3_12B; + AddVitConfig(config, /*image_size=*/896); + config.vocab_size = 262144; + config.vit_config.pool_dim = 4; + const size_t num_patches = + config.vit_config.image_size / config.vit_config.patch_width; + config.vit_config.seq_len = (num_patches * num_patches); + // The above resets optimized gating to false; for Gemma 3 it should be true. + for (auto& layer_config : config.layer_configs) { + layer_config.optimized_gating = true; + } + return config; +} + +static LayerConfig LayerConfigGemma3_27B_LM(size_t model_dim) { + LayerConfig config; + config.model_dim = model_dim; + config.ff_hidden_dim = 21504; + config.heads = 32; + config.kv_heads = 16; + config.qkv_dim = 128; + config.optimized_gating = true; + config.post_norm = PostNormType::Scale; + config.use_qk_norm = true; + return config; +} + +static ModelConfig ConfigGemma3_27B_LM() { + ModelConfig config = ConfigBaseGemmaV3(); + config.model_name = "Gemma3_27B"; + config.model = Model::GEMMA3_27B; + config.model_dim = 5376; + config.vocab_size = 262144; // new vocab size / tokenizer + config.seq_len = 32 * 1024; + LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim); + config.layer_configs = {62, layer_config}; + config.num_tensor_scales = 4 * config.layer_configs.size(); + config.query_scale = QueryScaleType::SqrtKeySize; + // interleaved local / global attention + config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>( + {1024, 1024, 1024, 1024, 1024, config.seq_len}); + return config; +} + +static ModelConfig ConfigGemma3_27B() { + ModelConfig config = ConfigGemma3_27B_LM(); + config.model_name = "Gemma3_27B"; + config.model = Model::GEMMA3_27B; + AddVitConfig(config, /*image_size=*/896); + config.vocab_size = 262144; + config.vit_config.pool_dim = 4; + const size_t num_patches = + config.vit_config.image_size / config.vit_config.patch_width; + config.vit_config.seq_len = (num_patches * num_patches); + // The above resets optimized gating to false; for Gemma 3 it should be true. 
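+  // (27B is the only Gemma 3 config here with qkv_dim = 128 rather than 256;
+  // its 32 query heads still map 2:1 onto 16 KV heads, and query scaling
+  // remains 1/sqrt(qkv_dim) via QueryScaleType::SqrtKeySize.)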
+ for (auto& layer_config : config.layer_configs) { + layer_config.optimized_gating = true; + } + return config; +} + ModelConfig ConfigFromModel(Model model) { switch (model) { case Model::GEMMA_2B: @@ -356,6 +536,14 @@ ModelConfig ConfigFromModel(Model model) { return ConfigPaliGemma2_10B_224(); case Model::PALIGEMMA2_10B_448: return ConfigPaliGemma2_10B_448(); + case Model::GEMMA3_4B: + return ConfigGemma3_4B(); + case Model::GEMMA3_1B: + return ConfigGemma3_1B(); + case Model::GEMMA3_12B: + return ConfigGemma3_12B(); + case Model::GEMMA3_27B: + return ConfigGemma3_27B(); default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } diff --git a/gemma/configs.h b/gemma/configs.h index d7078b3..a5dba12 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -151,6 +151,10 @@ enum class Model { PALIGEMMA2_3B_448, PALIGEMMA2_10B_224, PALIGEMMA2_10B_448, + GEMMA3_4B, + GEMMA3_1B, + GEMMA3_12B, + GEMMA3_27B, }; // Allows the Model enum to be iterated over. @@ -159,7 +163,8 @@ static constexpr Model kAllModels[] = { Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B, Model::PALIGEMMA_224, Model::PALIGEMMA_448, Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_3B_448, Model::PALIGEMMA2_10B_224, - Model::PALIGEMMA2_10B_448, + Model::PALIGEMMA2_10B_448, Model::GEMMA3_4B, Model::GEMMA3_1B, + Model::GEMMA3_12B, Model::GEMMA3_27B, }; inline bool EnumValid(Model model) { @@ -202,6 +207,7 @@ struct LayerConfig : public IFields { visitor(type); visitor(activation); visitor(post_qk); + visitor(use_qk_norm); } uint32_t model_dim = 0; @@ -218,6 +224,7 @@ struct LayerConfig : public IFields { LayerAttentionType type = LayerAttentionType::kGemma; ActivationType activation = ActivationType::Gelu; PostQKType post_qk = PostQKType::Rope; + bool use_qk_norm = false; }; // Dimensions related to image processing. diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index eab1a4d..e971fbe 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -219,6 +219,17 @@ class GemmaAttention { // qk is either q or k, so qkv_dim is the length we operate on. const size_t qkv_dim = layer_config_.qkv_dim; const float* inv_timescale = activations_.inv_timescale.Const(); + bool is_global_layer = + activations_.weights_config.attention_window_sizes[layer] == + activations_.seq_len; + // TODO: add a config flag instead of hardcoding the model. + if (is_global_layer && + (activations_.weights_config.model == Model::GEMMA3_4B || + activations_.weights_config.model == Model::GEMMA3_12B || + activations_.weights_config.model == Model::GEMMA3_27B || + activations_.weights_config.model == Model::GEMMA3_1B)) { + inv_timescale = activations_.inv_timescale_global.Const(); + } // PostQKType::Rope (void)layer; if (layer_weights_.layer_config.post_qk == PostQKType::HalfRope) { @@ -324,6 +335,10 @@ class GemmaAttention { } // Apply further processing to K. + if (layer_weights_.key_norm_scale.data()) { + RMSNormInplace(layer_weights_.key_norm_scale.data(), kv, + qkv_dim); + } PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f); }); } @@ -411,6 +426,10 @@ class GemmaAttention { // Apply rope and scaling to Q. const size_t pos = queries_pos_[query_idx] + batch_idx; + if (layer_weights_.query_norm_scale.data()) { + RMSNormInplace(layer_weights_.query_norm_scale.data(), q, + qkv_dim); + } PositionalEncodingQK(q, pos, layer_, query_scale); const size_t start_pos = StartPos(pos, layer_); @@ -590,6 +609,7 @@ class VitAttention { RowPtrFromBatch(qkv)); } + // TODO(philculliton): transition fully to MatMul. 
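+  // The rewrite below stages attention as one MatMul per head:
+  //   C = Q * K^T, with Q of shape (num_tokens, qkv_dim) and K of shape
+  //   (seq_len, qkv_dim), so C (num_tokens, seq_len) holds the attention
+  //   logits; each row of C is softmaxed, then att_out[t] += C[t][i] * V[i].
+  // The 1/sqrt(qkv_dim) scale is folded into Q before the MatMul instead of
+  // being applied to C afterwards.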
HWY_NOINLINE void DotSoftmaxWeightedSum() { const size_t qkv_dim = layer_config_.qkv_dim; const size_t heads = layer_config_.heads; @@ -598,34 +618,55 @@ class VitAttention { const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); - // Compute Q.K, softmax, and weighted V. - pool_.Run(0, layer_config_.heads * num_tokens_, - [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t head = task % layer_config_.heads; - const size_t token = task / layer_config_.heads; - // Compute Q.K scores, which are "logits" stored in head_att. - float* HWY_RESTRICT q = - activations_.q.Batch(token) + head * 3 * qkv_dim; - MulByConst(query_scale, q, qkv_dim); - float* HWY_RESTRICT head_att = - activations_.att.Batch(token) + head * activations_.seq_len; - for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT k = - activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim; - head_att[i] = Dot(q, k, qkv_dim); // score = q.k - } - // SoftMax yields "probabilities" in head_att. - Softmax(head_att, seq_len); - // Compute weighted sum of v into att_out. - float* HWY_RESTRICT att_out = - activations_.att_out.Batch(token) + head * qkv_dim; - hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); - for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT v = activations_.q.Batch(i) + - head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); - } - }); + // Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents) + RowVectorBatch Q = + AllocateAlignedRows(Extents2D(num_tokens_, qkv_dim)); + RowVectorBatch K = + AllocateAlignedRows(Extents2D(seq_len, qkv_dim)); + RowVectorBatch C(Extents2D(num_tokens_, seq_len)); + + // Initialize att_out to zero prior to head loop. + hwy::ZeroBytes(activations_.att_out.All(), + num_tokens_ * heads * qkv_dim * sizeof(float)); + + for (size_t head = 0; head < heads; ++head) { + pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t token = task; + float* HWY_RESTRICT q = + activations_.q.Batch(token) + head * 3 * qkv_dim; + // TODO: shift to MatMul with A.scale once MatMul is confirmed working + MulByConst(query_scale, q, qkv_dim); + hwy::CopyBytes(q, Q.Batch(token), qkv_dim * sizeof(float)); + }); + + pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t seq_idx = task; + float* HWY_RESTRICT k = + activations_.q.Batch(seq_idx) + head * 3 * qkv_dim + qkv_dim; + hwy::CopyBytes(k, K.Batch(seq_idx), qkv_dim * sizeof(float)); + }); + + // this produces C, a (num_tokens_, seq_len) matrix of dot products + MatMul(ConstMatFromBatch(Q.BatchSize(), Q), + ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env, + RowPtrFromBatch(C)); + + pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + float* HWY_RESTRICT c = C.Batch(task); + Softmax(c, C.Cols()); + }); + + pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + size_t token = task; + float* HWY_RESTRICT att_out = + activations_.att_out.Batch(token) + head * qkv_dim; + for (size_t i = 0; i < seq_len; ++i) { + float* HWY_RESTRICT v = + activations_.q.Batch(i) + head * 3 * qkv_dim + 2 * qkv_dim; + MulByConstAndAdd(C.Batch(token)[i], v, att_out, qkv_dim); + } + }); + } } // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and @@ -784,14 +825,30 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, // `batch_idx` indicates which row of `x` to write to. 
// `pos` is the *token*'s position, not the start of the batch, because this is // called for batches of tokens in prefill, but batches of queries in decode. +// +// For GEMMA_VLM, image tokens are copied into -2 locations (per the Gemma 3 +// spec) until we run out of image tokens. This allows for a multi-image prompt +// if -2 locations with appropriate begin/end image tokens are created by the +// calling application. template -HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, - size_t pos_in_prompt, - const ModelWeightsPtrs& weights, - RowVectorBatch& x, - const ImageTokens* image_tokens) { +HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos, + size_t pos_in_prompt, + const ModelWeightsPtrs& weights, + RowVectorBatch& x, + const ImageTokens* image_tokens, + size_t& image_token_position) { // Image tokens just need to be copied. - if (image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) { + if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM && + image_tokens != nullptr && token == -2 && + image_token_position < image_tokens->BatchSize()) { + hwy::CopyBytes(image_tokens->Batch(image_token_position), + x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0])); + image_token_position++; + return; + } + + if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA && + image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) { hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0])); return; @@ -816,6 +873,21 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, } } +// `batch_idx` indicates which row of `x` to write to. +// `pos` is the *token*'s position, not the start of the batch, because this is +// called for batches of tokens in prefill, but batches of queries in decode. +// This version of the function doesn't track internal image token position. +template +HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, + size_t pos_in_prompt, + const ModelWeightsPtrs& weights, + RowVectorBatch& x, + const ImageTokens* image_tokens) { + size_t image_token_position = 0; + EmbedMMToken(token, batch_idx, pos, pos_in_prompt, weights, x, + image_tokens, image_token_position); +} + template HWY_NOINLINE void ResidualConnection( size_t num_interleaved, T* HWY_RESTRICT other, T* HWY_RESTRICT x, @@ -990,12 +1062,13 @@ HWY_NOINLINE void Prefill( HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start); // Fill activations.x (much faster than TransformerLayer). + size_t image_token_position = 0; for (size_t ti = 0; ti < tbatch_size; ++ti) { const size_t pos = queries_pos[qi] + ti; const size_t pos_in_prompt = tbatch_start + ti; const int token = queries_prompt[qi][pos_in_prompt]; - EmbedToken(token, ti, pos, pos_in_prompt, weights, activations.x, - runtime_config.image_tokens); + EmbedMMToken(token, ti, pos, pos_in_prompt, weights, activations.x, + runtime_config.image_tokens, image_token_position); } // Transformer with one batch of tokens from a single query. @@ -1104,8 +1177,14 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights, weights.vit_encoder_norm_bias.data_scale1(), activations.x.All(), vit_model_dim); + activations.x = AvgPool4x4(activations.x); + + // Apply soft embedding norm before input projection. + RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(), + vit_model_dim); + // Apply head embedding into image_tokens of size of the LLM kModelDim. 
- MatMul(ConstMatFromBatch(num_tokens, activations.x), + MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x), ConstMatFromWeights(weights.vit_img_head_kernel), weights.vit_img_head_bias.data_scale1(), *activations.env, RowPtrFromBatch(image_tokens)); @@ -1133,10 +1212,11 @@ HWY_NOINLINE void Transformer( } } + size_t image_token_position = 0; for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - EmbedToken(queries_token[query_idx], query_idx, queries_pos[query_idx], - /*pos_in_prompt=*/0, weights, activations.x, - /*image_tokens=*/nullptr); + EmbedMMToken(queries_token[query_idx], query_idx, queries_pos[query_idx], + /*pos_in_prompt=*/0, weights, activations.x, + /*image_tokens=*/nullptr, image_token_position); } for (size_t layer = 0; layer < weights.c_layers.size(); ++layer) { @@ -1440,7 +1520,8 @@ void GenerateImageTokensT(const ModelWeightsStorage& model, } RuntimeConfig prefill_runtime_config = runtime_config; ModelConfig vit_config = GetVitConfig(model.Config()); - prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len; + prefill_runtime_config.prefill_tbatch_size = + vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim); Activations prefill_activations(vit_config); prefill_activations.Allocate(vit_config.seq_len, env); // Weights are for the full PaliGemma model, not just the ViT part. @@ -1487,4 +1568,4 @@ void GenerateImageTokens( // NOLINT(misc-definitions-in-headers) } // namespace gcpp HWY_AFTER_NAMESPACE(); -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ \ No newline at end of file diff --git a/gemma/run.cc b/gemma/run.cc index 2052100..e4c2d2e 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -94,10 +94,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, Image image; ImageTokens image_tokens; if (have_image) { - image_tokens = - ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len, - model.GetModelConfig().model_dim)); - HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA); + size_t pool_dim = model.GetModelConfig().vit_config.pool_dim; + image_tokens = ImageTokens(Extents2D( + model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim), + model.GetModelConfig().model_dim)); + HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA || + model.Info().wrapping == PromptWrapping::GEMMA_VLM); HWY_ASSERT(image.ReadPPM(args.image_file.path)); const size_t image_size = model.GetModelConfig().vit_config.image_size; image.Resize(image_size, image_size); @@ -190,7 +192,15 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, size_t prefix_end = 0; if (have_image) { runtime_config.image_tokens = &image_tokens; - prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0); + if (model.Info().wrapping == PromptWrapping::PALIGEMMA) { + prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0); + } else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) { + size_t seq_len = model.GetModelConfig().vit_config.seq_len; + size_t pool_dim = model.GetModelConfig().vit_config.pool_dim; + prompt = + WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt, + image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim)); + } prompt_size = prompt.size(); // The end of the prefix for prefix-LM style attention in Paligemma. // See Figure 2 of https://arxiv.org/abs/2407.07726. @@ -209,9 +219,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, // Prepare for the next turn. 
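+      // Positions are no longer reset to 0 between turns here; abs_pos
+      // carries forward and only the EOS bookkeeping below still runs.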
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) { - abs_pos = 0; // Start a new turn at position 0. - InitGenerator(args, gen); - } else { // The last token was either EOS, then it should be ignored because it is // never part of the dialog, see Table 5 in the Gemma-2 paper: // https://arxiv.org/pdf/2408.00118 diff --git a/gemma/tensor_index.cc b/gemma/tensor_index.cc index 354a1b4..4308c9d 100644 --- a/gemma/tensor_index.cc +++ b/gemma/tensor_index.cc @@ -64,14 +64,14 @@ std::vector ModelTensors(const ModelConfig& config) { }, TensorInfo{ .name = "img_head_bias", - .source_names = {"img/head/bias"}, + .source_names = {"img/head/bias", "embedder/mm_input_projection/b"}, .axes = {0}, .shape = {config.model_dim}, .min_size = Type::kF32, }, TensorInfo{ .name = "img_head_kernel", - .source_names = {"img/head/kernel"}, + .source_names = {"img/head/kernel", "embedder/mm_input_projection/w"}, .axes = {1, 0}, .shape = {config.model_dim, config.vit_config.model_dim}, .min_size = Type::kBF16, @@ -84,12 +84,21 @@ std::vector ModelTensors(const ModelConfig& config) { config.vit_config.model_dim}, .min_size = Type::kF32, }, + // RMS norm applied to soft tokens prior to pos embedding. + TensorInfo{ + .name = "mm_embed_norm", + .source_names = {"embedder/mm_soft_embedding_norm/scale"}, + .axes = {0}, + .shape = {config.vit_config.model_dim}, + .min_size = Type::kBF16, + }, }; } // Returns the tensors for the given image layer config. std::vector ImageLayerTensors(const ModelConfig& config, - const LayerConfig& layer_config) { + const LayerConfig& layer_config, + const int img_layer_idx) { return { // Vit layers. TensorInfo{ @@ -207,28 +216,40 @@ std::vector ImageLayerTensors(const ModelConfig& config, }, TensorInfo{ .name = "ln_0_bias", - .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"}, + .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias", + "img/Transformer/encoderblock_" + + std::to_string(img_layer_idx) + + "/LayerNorm_0/bias"}, .axes = {0}, .shape = {config.vit_config.model_dim}, .min_size = Type::kBF16, }, TensorInfo{ .name = "ln_0_scale", - .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"}, + .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale", + "img/Transformer/encoderblock_" + + std::to_string(img_layer_idx) + + "/LayerNorm_0/scale"}, .axes = {0}, .shape = {config.vit_config.model_dim}, .min_size = Type::kBF16, }, TensorInfo{ .name = "ln_1_bias", - .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"}, + .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias", + "img/Transformer/encoderblock_" + + std::to_string(img_layer_idx) + + "/LayerNorm_1/bias"}, .axes = {0}, .shape = {config.vit_config.model_dim}, .min_size = Type::kBF16, }, TensorInfo{ .name = "ln_1_scale", - .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"}, + .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale", + "img/Transformer/encoderblock_" + + std::to_string(img_layer_idx) + + "/LayerNorm_1/scale"}, .axes = {0}, .shape = {config.vit_config.model_dim}, .min_size = Type::kBF16, @@ -241,6 +262,20 @@ std::vector LLMLayerTensors(const ModelConfig& config, const LayerConfig& layer_config, bool reshape_att) { std::vector tensors = { + TensorInfo{ + .name = "key_norm", + .source_names = {"attn/_key_norm/scale"}, + .axes = {0}, + .shape = {layer_config.qkv_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "query_norm", + .source_names = {"attn/_query_norm/scale"}, + .axes = {0}, + .shape 
= {layer_config.qkv_dim},
+          .min_size = Type::kBF16,
+      },
       TensorInfo{
           .name = "qkv1_w",
           .source_names = {"attn/q_einsum/w"},
@@ -529,7 +564,7 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
   } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
              img_layer_idx < config.vit_config.layer_configs.size()) {
     const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
-    tensors_ = ImageLayerTensors(config, layer_config);
+    tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
   } else if (0 <= llm_layer_idx &&
              llm_layer_idx < config.layer_configs.size()) {
     const auto& layer_config = config.layer_configs[llm_layer_idx];
@@ -569,4 +604,4 @@ const TensorInfo* TensorIndex::FindName(const std::string& name) const {
   return &tensors_[it->second];
 }
 
-}  // namespace gcpp
+}  // namespace gcpp
\ No newline at end of file
diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc
index 9e0b827..e48abae 100644
--- a/gemma/tokenizer.cc
+++ b/gemma/tokenizer.cc
@@ -127,7 +127,9 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
   }
 
   // PaliGemma separator. The SEP token "\n" is always tokenized separately.
-  if (info.wrapping == PromptWrapping::PALIGEMMA) {
+  if (info.wrapping == PromptWrapping::PALIGEMMA
+      // || info.wrapping == PromptWrapping::GEMMA_VLM
+  ) {
     std::vector<int> sep_tokens;
     HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
     tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
@@ -136,4 +138,33 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
   return tokens;
 }
 
+std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
+                         size_t pos, std::vector<int>& tokens,
+                         size_t image_batch_size, size_t max_image_batch_size) {
+  HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
+  size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
+
+  std::vector<int> sep_tokens;
+  HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
+
+  std::string begin_image_prompt = "\n\n<start_of_image>";
+  std::vector<int> begin_image_tokens =
+      WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
+
+  std::string end_image_prompt = "<end_of_image>\n\n";
+  std::vector<int> end_image_tokens =
+      WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
+
+  for (size_t i = 0; i < num_images; ++i) {
+    tokens.insert(tokens.begin(), begin_image_tokens.begin(),
+                  begin_image_tokens.end());
+    tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
+                  -2);
+    tokens.insert(
+        tokens.begin() + begin_image_tokens.size() + image_batch_size,
+        end_image_tokens.begin(), end_image_tokens.end());
+  }
+
+  return tokens;
+}
+
 }  // namespace gcpp
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index e2bb611..a5c4c4f 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -57,6 +57,10 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                  const ModelInfo& info, size_t pos,
                                  std::string& prompt);
 
+std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
+                         size_t pos, std::vector<int>& tokens,
+                         size_t image_batch_size, size_t max_image_batch_size);
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TOKENIZER_H_
diff --git a/gemma/weights.h b/gemma/weights.h
index 1a71777..ba427ac 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -99,6 +99,8 @@ struct LayerWeightsPtrs {
         ffw_gating_biases("ffw_gat_b", tensor_index),
         ffw_output_biases("ffw_out_b", tensor_index),
         att_weights("att_w", tensor_index),
+        key_norm_scale("key_norm", tensor_index),
+        query_norm_scale("query_norm", tensor_index),
         layer_config(config) {}
 
   ~LayerWeightsPtrs() = 
default; @@ -204,6 +206,9 @@ struct LayerWeightsPtrs { att_weights.set_scale(attn_vec_einsum_w.scale()); } + ArrayT key_norm_scale; + ArrayT query_norm_scale; + // Used by ForEachTensor for per-layer tensors. #define GEMMA_CALL_FUNC(member) \ { \ @@ -282,6 +287,10 @@ struct LayerWeightsPtrs { GEMMA_CALL_FUNC(post_attention_norm_scale); GEMMA_CALL_FUNC(post_ffw_norm_scale); } + if (ptrs[0]->layer_config.use_qk_norm) { + GEMMA_CALL_FUNC(key_norm_scale); + GEMMA_CALL_FUNC(query_norm_scale); + } if (ptrs[0]->layer_config.ff_biases) { GEMMA_CALL_FUNC(ffw_gating_biases); @@ -332,6 +341,7 @@ struct ModelWeightsPtrs { vit_img_pos_embedding("img_pos_emb", tensor_index), vit_img_head_bias("img_head_bias", tensor_index), vit_img_head_kernel("img_head_kernel", tensor_index), + mm_embed_norm("mm_embed_norm", tensor_index), scale_names(config.scale_names), weights_config(config) { c_layers.reserve(config.layer_configs.size()); @@ -372,6 +382,8 @@ struct ModelWeightsPtrs { MatPtrT vit_img_head_bias; MatPtrT vit_img_head_kernel; + MatPtrT mm_embed_norm; + std::unordered_set scale_names; const ModelConfig& weights_config; @@ -488,6 +500,7 @@ struct ModelWeightsPtrs { GEMMA_CALL_FUNC(vit_img_pos_embedding); GEMMA_CALL_FUNC(vit_img_head_bias); GEMMA_CALL_FUNC(vit_img_head_kernel); + GEMMA_CALL_FUNC(mm_embed_norm); } for (int layer_idx = 0; layer_idx < ptrs[0]->c_layers.size(); ++layer_idx) { @@ -605,4 +618,4 @@ class ModelWeightsStorage { } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ \ No newline at end of file diff --git a/ops/matmul.h b/ops/matmul.h index 72ccd4a..b7f0ac6 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -202,7 +202,7 @@ class MMStorage { // Compile-time bounds on matrix dimensions to enable pre-allocating storage // and reusing it across `MatMul` calls. The resulting allocations are 256 MiB // per package and 512 MiB, respectively. - static constexpr size_t kMaxM = 2048; + static constexpr size_t kMaxM = 4096; static constexpr size_t kMaxK = 64 * 1024; static constexpr size_t kMaxN = 256 * 1024; // Upper bound for per-worker B storage on the stack. 
Chosen such that one row
diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index a207381..52f72bd 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -802,6 +802,48 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
       .prob = topk_logits[topk_sampled_index]};
 }
 
+// Performs 4x4 average pooling across row vectors.
+// Input has 4096 (64*64) rows, output has 256 (16*16) rows.
+// Each output row is the average of a 4x4 block of input rows.
+template <typename T>
+RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
+  Extents2D extents = input.Extents();
+  // Input validation.
+  HWY_DASSERT(extents.rows == 4096);  // 64 * 64 = 4096 input rows
+  // Create output with 256 rows and the same number of columns.
+  const size_t out_rows = 256;  // 16 * 16 = 256 output rows
+  RowVectorBatch<T> result(Extents2D{out_rows, extents.cols});
+  const size_t input_dim = 64;   // Input is 64x64.
+  const size_t output_dim = 16;  // Output is 16x16.
+  for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {
+    for (size_t out_col_idx = 0; out_col_idx < output_dim; ++out_col_idx) {
+      size_t out_idx = out_row_idx * output_dim + out_col_idx;
+      T* output_row = result.Batch(out_idx);
+      // Initialize output row to zeros.
+      std::fill(output_row, output_row + extents.cols, T{0});
+      // Average 16 row vectors from a 4x4 block.
+      for (size_t i = 0; i < 4; ++i) {
+        for (size_t j = 0; j < 4; ++j) {
+          size_t in_row_idx = out_row_idx * 4 + i;
+          size_t in_col_idx = out_col_idx * 4 + j;
+          size_t in_idx = in_row_idx * input_dim + in_col_idx;
+          const T* input_row = input.Batch(in_idx);
+          // Add each input row to the output.
+          // TODO(philculliton): use AddFrom in ops-inl for a vectorized loop.
+          for (size_t col = 0; col < extents.cols; ++col) {
+            output_row[col] += input_row[col];
+          }
+        }
+      }
+      // Divide by 16 to get the average.
+      for (size_t col = 0; col < extents.cols; ++col) {
+        output_row[col] *= T{0.0625};
+      }
+    }
+  }
+  return result;
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
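A minimal usage sketch for AvgPool4x4 under the Gemma 3 ViT shapes (the
64x64 = 4096-row encoder output is pooled to 16x16 = 256 rows; the column
count, i.e. the ViT model_dim, is preserved). The width 1152 below is
illustrative only:

  RowVectorBatch<float> x = AllocateAlignedRows<float>(Extents2D(4096, 1152));
  // ... fill x with the image encoder output ...
  RowVectorBatch<float> pooled = AvgPool4x4(x);
  HWY_DASSERT(pooled.Extents().rows == 256);   // 4096 / (4 * 4)
  HWY_DASSERT(pooled.Extents().cols == 1152);  // columns unchanged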