Add config flag for global timescale & rely on config to deduce wrapping

PiperOrigin-RevId: 823512377
This commit is contained in:
Theotime Combes 2025-10-24 06:54:19 -07:00 committed by Copybara-Service
parent a48e614f64
commit 1bdde1af3c
4 changed files with 23 additions and 22 deletions

View File

@ -82,8 +82,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
// qk is either q or k, so qkv_dim is the length we operate on.
const float* inv_timescale = activations.inv_timescale.PackedScale1();
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
// TODO: add a config flag instead of hardcoding the model.
if (is_global_layer && IsVLM(activations.config.model)) {
if (is_global_layer && activations.config.use_global_timescale) {
inv_timescale = activations.inv_timescale_global.PackedScale1();
}
// PostQKType::Rope

View File

@ -238,6 +238,7 @@ static ModelConfig ConfigGemma3_1B() {
config.display_name = "Gemma3_1B";
config.model = Model::GEMMA3_1B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
config.model_dim = 1152;
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024;
@ -288,6 +289,7 @@ static ModelConfig ConfigGemma3_4B() {
config.display_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4;
@ -337,6 +339,7 @@ static ModelConfig ConfigGemma3_12B() {
config.display_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4;
@ -386,6 +389,7 @@ static ModelConfig ConfigGemma3_27B() {
config.display_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.use_global_timescale = true;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4;
@ -495,19 +499,19 @@ const char* ModelPrefix(Model model) {
}
PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
if (IsPaliGemma(model)) {
const PromptWrapping config_wrapping = ConfigFromModel(model).wrapping;
// For models with a fixed wrapping mode, ignore user override.
if (config_wrapping == PromptWrapping::PALIGEMMA ||
config_wrapping == PromptWrapping::GEMMA_VLM) {
if (wrapping != Tristate::kDefault) {
HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models.");
HWY_WARN("Ignoring unnecessary --wrapping for model %s.",
ModelPrefix(model));
}
return PromptWrapping::PALIGEMMA;
return config_wrapping;
}
if (IsVLM(model)) {
if (wrapping != Tristate::kDefault) {
HWY_WARN("Ignoring unnecessary --wrapping for VLM models.");
}
return PromptWrapping::GEMMA_VLM;
}
// Default to IT unless --wrapping=0.
// For other models, default to IT unless --wrapping=0 is passed.
return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
: PromptWrapping::GEMMA_IT;
}

View File

@ -184,13 +184,6 @@ enum class Model {
// in Specifier and thus does not change.
const char* ModelPrefix(Model model);
// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma.
// This is used for deducing the PromptWrapping for pre-2025 BlobStore.
static inline bool IsVLM(Model model) {
return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B ||
model == Model::GEMMA3_12B || model == Model::GEMMA3_27B;
}
static inline bool IsPaliGemma(Model model) {
if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
model == Model::PALIGEMMA2_10B_224 ||
@ -383,6 +376,8 @@ struct ModelConfig : public IFields {
internal.VisitFields(visitor);
visitor(use_global_timescale);
// Append new fields here, then update `python/configs.cc`.
}
@ -481,6 +476,7 @@ struct ModelConfig : public IFields {
std::vector<std::string> scale_base_names;
InternalModelConfig internal;
bool use_global_timescale = false; // for Gemma 3
};
// Returns the sub-config for the ViT model of the PaliGemma model.

View File

@ -173,6 +173,8 @@ PYBIND11_MODULE(configs, py_module) {
.def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id)
.def_readwrite("scale_base_names", &ModelConfig::scale_base_names)
.def_readwrite("internal", &ModelConfig::internal)
.def_readwrite("use_global_timescale",
&ModelConfig::use_global_timescale)
.def("add_layer_config", &ModelConfig::AddLayerConfig,
arg("layer_config"))