From 1bdde1af3c1a8a50d53994f88b0bfac4272214fb Mon Sep 17 00:00:00 2001
From: Theotime Combes
Date: Fri, 24 Oct 2025 06:54:19 -0700
Subject: [PATCH] Add config flag for global timescale & rely on config to
 deduce wrapping

PiperOrigin-RevId: 823512377
---
 gemma/attention.cc |  3 +--
 gemma/configs.cc   | 28 ++++++++++++++++------------
 gemma/configs.h    | 12 ++++--------
 python/configs.cc  |  2 ++
 4 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/gemma/attention.cc b/gemma/attention.cc
index 117b533..668105a 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -82,8 +82,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
   // qk is either q or k, so qkv_dim is the length we operate on.
   const float* inv_timescale = activations.inv_timescale.PackedScale1();
   const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
-  // TODO: add a config flag instead of hardcoding the model.
-  if (is_global_layer && IsVLM(activations.config.model)) {
+  if (is_global_layer && activations.config.use_global_timescale) {
     inv_timescale = activations.inv_timescale_global.PackedScale1();
   }
   // PostQKType::Rope
diff --git a/gemma/configs.cc b/gemma/configs.cc
index 8856203..b8048b8 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -22,8 +22,8 @@
 #include 
 
 #include "compression/types.h"  // Type
-#include "io/fields.h"  // IFields
-#include "io/io.h"  // Path
+#include "io/fields.h"   // IFields
+#include "io/io.h"       // Path
 #include "hwy/base.h"
 
 namespace gcpp {
@@ -238,6 +238,7 @@ static ModelConfig ConfigGemma3_1B() {
   config.display_name = "Gemma3_1B";
   config.model = Model::GEMMA3_1B;
   config.wrapping = PromptWrapping::GEMMA_VLM;
+  config.use_global_timescale = true;
   config.model_dim = 1152;
   config.vocab_size = kGemmaV3VocabSize;  // new vocab size / tokenizer
   config.max_seq_len = 32 * 1024;
@@ -288,6 +289,7 @@ static ModelConfig ConfigGemma3_4B() {
   config.display_name = "Gemma3_4B";
   config.model = Model::GEMMA3_4B;
   config.wrapping = PromptWrapping::GEMMA_VLM;
+  config.use_global_timescale = true;
   AddVitConfig(config, /*image_size=*/896);
   config.vocab_size = kGemmaV3VocabSize;
   config.vit_config.pool_dim = 4;
@@ -337,6 +339,7 @@ static ModelConfig ConfigGemma3_12B() {
   config.display_name = "Gemma3_12B";
   config.model = Model::GEMMA3_12B;
   config.wrapping = PromptWrapping::GEMMA_VLM;
+  config.use_global_timescale = true;
   AddVitConfig(config, /*image_size=*/896);
   config.vocab_size = kGemmaV3VocabSize;
   config.vit_config.pool_dim = 4;
@@ -386,6 +389,7 @@ static ModelConfig ConfigGemma3_27B() {
   config.display_name = "Gemma3_27B";
   config.model = Model::GEMMA3_27B;
   config.wrapping = PromptWrapping::GEMMA_VLM;
+  config.use_global_timescale = true;
   AddVitConfig(config, /*image_size=*/896);
   config.vocab_size = kGemmaV3VocabSize;
   config.vit_config.pool_dim = 4;
@@ -495,19 +499,19 @@ const char* ModelPrefix(Model model) {
 }
 
 PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
-  if (IsPaliGemma(model)) {
+  const PromptWrapping config_wrapping = ConfigFromModel(model).wrapping;
+
+  // For models with a fixed wrapping mode, ignore user override.
+  if (config_wrapping == PromptWrapping::PALIGEMMA ||
+      config_wrapping == PromptWrapping::GEMMA_VLM) {
     if (wrapping != Tristate::kDefault) {
-      HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models.");
+      HWY_WARN("Ignoring unnecessary --wrapping for model %s.",
+               ModelPrefix(model));
     }
-    return PromptWrapping::PALIGEMMA;
+    return config_wrapping;
   }
-  if (IsVLM(model)) {
-    if (wrapping != Tristate::kDefault) {
-      HWY_WARN("Ignoring unnecessary --wrapping for VLM models.");
-    }
-    return PromptWrapping::GEMMA_VLM;
-  }
-  // Default to IT unless --wrapping=0.
+
+  // For other models, default to IT unless --wrapping=0 is passed.
   return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
                                       : PromptWrapping::GEMMA_IT;
 }
diff --git a/gemma/configs.h b/gemma/configs.h
index 275f374..d774481 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -184,13 +184,6 @@ enum class Model {
 // in Specifier and thus does not change.
 const char* ModelPrefix(Model model);
 
-// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma.
-// This is used for deducing the PromptWrapping for pre-2025 BlobStore.
-static inline bool IsVLM(Model model) {
-  return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B ||
-         model == Model::GEMMA3_12B || model == Model::GEMMA3_27B;
-}
-
 static inline bool IsPaliGemma(Model model) {
   if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
       model == Model::PALIGEMMA2_10B_224 ||
@@ -280,7 +273,7 @@ struct LayerConfig : public IFields {
   uint32_t kv_heads = 0;
   uint32_t qkv_dim = 0;  // length of Q, K, V vectors (contiguous).
   bool ff_biases = false;
-  bool optimized_gating = true; // for Gemma3
+  bool optimized_gating = true;  // for Gemma3
   PostNormType post_norm = PostNormType::None;
   LayerAttentionType type = LayerAttentionType::kGemma;
   ActivationType activation = ActivationType::Gelu;
@@ -383,6 +376,8 @@ struct ModelConfig : public IFields {
 
     internal.VisitFields(visitor);
 
+    visitor(use_global_timescale);
+
     // Append new fields here, then update `python/configs.cc`.
   }
 
@@ -481,6 +476,7 @@ struct ModelConfig : public IFields {
   std::vector<std::string> scale_base_names;
 
   InternalModelConfig internal;
+  bool use_global_timescale = false;  // for Gemma 3
 };
 
 // Returns the sub-config for the ViT model of the PaliGemma model.
diff --git a/python/configs.cc b/python/configs.cc
index e544bb0..0d505dc 100644
--- a/python/configs.cc
+++ b/python/configs.cc
@@ -173,6 +173,8 @@ PYBIND11_MODULE(configs, py_module) {
       .def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id)
      .def_readwrite("scale_base_names", &ModelConfig::scale_base_names)
      .def_readwrite("internal", &ModelConfig::internal)
+      .def_readwrite("use_global_timescale",
+                     &ModelConfig::use_global_timescale)
       .def("add_layer_config", &ModelConfig::AddLayerConfig,
            arg("layer_config"))
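
Note (not part of the patch): below is a minimal, self-contained sketch of the behavior the attention.cc hunk relies on. A global-attention layer switches to the global inverse-timescale table only when the model config opts in via use_global_timescale, replacing the removed IsVLM(model) check. All names here (FakeConfig, SelectInvTimescale, the layer-index rule) are illustrative stand-ins, not the real gemma.cpp API.

#include <cstddef>
#include <vector>

// Illustrative stand-in for ModelConfig; only the fields relevant to the
// timescale choice are modeled here.
struct FakeConfig {
  bool use_global_timescale = false;  // new flag introduced by this patch
  // Assume every 6th layer is a global-attention layer, purely for the demo.
  bool IsGlobalLayer(size_t layer_idx) const { return layer_idx % 6 == 5; }
};

// Mirrors the patched logic in PositionalEncodingQK: the config flag, not a
// hardcoded model check, decides whether the global timescale table is used.
const float* SelectInvTimescale(const FakeConfig& config, size_t layer_idx,
                                const std::vector<float>& inv_timescale,
                                const std::vector<float>& inv_timescale_global) {
  if (config.IsGlobalLayer(layer_idx) && config.use_global_timescale) {
    return inv_timescale_global.data();
  }
  return inv_timescale.data();
}

With use_global_timescale left at its default of false, configs that do not set the flag keep the local timescale on every layer, which matches the previous behavior for non-Gemma3 models; the four ConfigGemma3_* configs above opt in explicitly.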