Add config flag for global timescale & rely on config to deduce wrapping

PiperOrigin-RevId: 823512377
This commit is contained in:
Theotime Combes 2025-10-24 06:54:19 -07:00 committed by Copybara-Service
parent a48e614f64
commit 1bdde1af3c
4 changed files with 23 additions and 22 deletions

View File

@ -82,8 +82,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
// qk is either q or k, so qkv_dim is the length we operate on. // qk is either q or k, so qkv_dim is the length we operate on.
const float* inv_timescale = activations.inv_timescale.PackedScale1(); const float* inv_timescale = activations.inv_timescale.PackedScale1();
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx); const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
// TODO: add a config flag instead of hardcoding the model. if (is_global_layer && activations.config.use_global_timescale) {
if (is_global_layer && IsVLM(activations.config.model)) {
inv_timescale = activations.inv_timescale_global.PackedScale1(); inv_timescale = activations.inv_timescale_global.PackedScale1();
} }
// PostQKType::Rope // PostQKType::Rope

View File

@ -22,8 +22,8 @@
#include <vector> #include <vector>
#include "compression/types.h" // Type #include "compression/types.h" // Type
#include "io/fields.h" // IFields #include "io/fields.h" // IFields
#include "io/io.h" // Path #include "io/io.h" // Path
#include "hwy/base.h" #include "hwy/base.h"
namespace gcpp { namespace gcpp {
@ -238,6 +238,7 @@ static ModelConfig ConfigGemma3_1B() {
config.display_name = "Gemma3_1B"; config.display_name = "Gemma3_1B";
config.model = Model::GEMMA3_1B; config.model = Model::GEMMA3_1B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
config.model_dim = 1152; config.model_dim = 1152;
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024; config.max_seq_len = 32 * 1024;
@ -288,6 +289,7 @@ static ModelConfig ConfigGemma3_4B() {
config.display_name = "Gemma3_4B"; config.display_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B; config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = kGemmaV3VocabSize; config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
@ -337,6 +339,7 @@ static ModelConfig ConfigGemma3_12B() {
config.display_name = "Gemma3_12B"; config.display_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B; config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = kGemmaV3VocabSize; config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
@ -386,6 +389,7 @@ static ModelConfig ConfigGemma3_27B() {
config.display_name = "Gemma3_27B"; config.display_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B; config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = kGemmaV3VocabSize; config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
@ -495,19 +499,19 @@ const char* ModelPrefix(Model model) {
} }
PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) { PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
if (IsPaliGemma(model)) { const PromptWrapping config_wrapping = ConfigFromModel(model).wrapping;
// For models with a fixed wrapping mode, ignore user override.
if (config_wrapping == PromptWrapping::PALIGEMMA ||
config_wrapping == PromptWrapping::GEMMA_VLM) {
if (wrapping != Tristate::kDefault) { if (wrapping != Tristate::kDefault) {
HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models."); HWY_WARN("Ignoring unnecessary --wrapping for model %s.",
ModelPrefix(model));
} }
return PromptWrapping::PALIGEMMA; return config_wrapping;
} }
if (IsVLM(model)) {
if (wrapping != Tristate::kDefault) { // For other models, default to IT unless --wrapping=0 is passed.
HWY_WARN("Ignoring unnecessary --wrapping for VLM models.");
}
return PromptWrapping::GEMMA_VLM;
}
// Default to IT unless --wrapping=0.
return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
: PromptWrapping::GEMMA_IT; : PromptWrapping::GEMMA_IT;
} }

View File

@ -184,13 +184,6 @@ enum class Model {
// in Specifier and thus does not change. // in Specifier and thus does not change.
const char* ModelPrefix(Model model); const char* ModelPrefix(Model model);
// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma.
// This is used for deducing the PromptWrapping for pre-2025 BlobStore.
static inline bool IsVLM(Model model) {
return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B ||
model == Model::GEMMA3_12B || model == Model::GEMMA3_27B;
}
static inline bool IsPaliGemma(Model model) { static inline bool IsPaliGemma(Model model) {
if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 || if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
model == Model::PALIGEMMA2_10B_224 || model == Model::PALIGEMMA2_10B_224 ||
@ -280,7 +273,7 @@ struct LayerConfig : public IFields {
uint32_t kv_heads = 0; uint32_t kv_heads = 0;
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous). uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
bool ff_biases = false; bool ff_biases = false;
bool optimized_gating = true; // for Gemma3 bool optimized_gating = true; // for Gemma3
PostNormType post_norm = PostNormType::None; PostNormType post_norm = PostNormType::None;
LayerAttentionType type = LayerAttentionType::kGemma; LayerAttentionType type = LayerAttentionType::kGemma;
ActivationType activation = ActivationType::Gelu; ActivationType activation = ActivationType::Gelu;
@ -383,6 +376,8 @@ struct ModelConfig : public IFields {
internal.VisitFields(visitor); internal.VisitFields(visitor);
visitor(use_global_timescale);
// Append new fields here, then update `python/configs.cc`. // Append new fields here, then update `python/configs.cc`.
} }
@ -481,6 +476,7 @@ struct ModelConfig : public IFields {
std::vector<std::string> scale_base_names; std::vector<std::string> scale_base_names;
InternalModelConfig internal; InternalModelConfig internal;
bool use_global_timescale = false; // for Gemma 3
}; };
// Returns the sub-config for the ViT model of the PaliGemma model. // Returns the sub-config for the ViT model of the PaliGemma model.

View File

@ -173,6 +173,8 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id) .def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id)
.def_readwrite("scale_base_names", &ModelConfig::scale_base_names) .def_readwrite("scale_base_names", &ModelConfig::scale_base_names)
.def_readwrite("internal", &ModelConfig::internal) .def_readwrite("internal", &ModelConfig::internal)
.def_readwrite("use_global_timescale",
&ModelConfig::use_global_timescale)
.def("add_layer_config", &ModelConfig::AddLayerConfig, .def("add_layer_config", &ModelConfig::AddLayerConfig,
arg("layer_config")) arg("layer_config"))