mirror of https://github.com/google/gemma.cpp.git
Add config flag for global timescale & rely on config to deduce wrapping
PiperOrigin-RevId: 823512377
This commit is contained in:
parent
a48e614f64
commit
1bdde1af3c
|
|
@ -82,8 +82,7 @@ void PositionalEncodingQK(float* qk, const size_t layer_idx,
|
||||||
// qk is either q or k, so qkv_dim is the length we operate on.
|
// qk is either q or k, so qkv_dim is the length we operate on.
|
||||||
const float* inv_timescale = activations.inv_timescale.PackedScale1();
|
const float* inv_timescale = activations.inv_timescale.PackedScale1();
|
||||||
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
|
const bool is_global_layer = activations.config.IsGlobalLayer(layer_idx);
|
||||||
// TODO: add a config flag instead of hardcoding the model.
|
if (is_global_layer && activations.config.use_global_timescale) {
|
||||||
if (is_global_layer && IsVLM(activations.config.model)) {
|
|
||||||
inv_timescale = activations.inv_timescale_global.PackedScale1();
|
inv_timescale = activations.inv_timescale_global.PackedScale1();
|
||||||
}
|
}
|
||||||
// PostQKType::Rope
|
// PostQKType::Rope
|
||||||
|
|
|
||||||
|
|
@ -22,8 +22,8 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/types.h" // Type
|
#include "compression/types.h" // Type
|
||||||
#include "io/fields.h" // IFields
|
#include "io/fields.h" // IFields
|
||||||
#include "io/io.h" // Path
|
#include "io/io.h" // Path
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -238,6 +238,7 @@ static ModelConfig ConfigGemma3_1B() {
|
||||||
config.display_name = "Gemma3_1B";
|
config.display_name = "Gemma3_1B";
|
||||||
config.model = Model::GEMMA3_1B;
|
config.model = Model::GEMMA3_1B;
|
||||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
|
config.use_global_timescale = true;
|
||||||
config.model_dim = 1152;
|
config.model_dim = 1152;
|
||||||
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
|
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
|
||||||
config.max_seq_len = 32 * 1024;
|
config.max_seq_len = 32 * 1024;
|
||||||
|
|
@ -288,6 +289,7 @@ static ModelConfig ConfigGemma3_4B() {
|
||||||
config.display_name = "Gemma3_4B";
|
config.display_name = "Gemma3_4B";
|
||||||
config.model = Model::GEMMA3_4B;
|
config.model = Model::GEMMA3_4B;
|
||||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
|
config.use_global_timescale = true;
|
||||||
AddVitConfig(config, /*image_size=*/896);
|
AddVitConfig(config, /*image_size=*/896);
|
||||||
config.vocab_size = kGemmaV3VocabSize;
|
config.vocab_size = kGemmaV3VocabSize;
|
||||||
config.vit_config.pool_dim = 4;
|
config.vit_config.pool_dim = 4;
|
||||||
|
|
@ -337,6 +339,7 @@ static ModelConfig ConfigGemma3_12B() {
|
||||||
config.display_name = "Gemma3_12B";
|
config.display_name = "Gemma3_12B";
|
||||||
config.model = Model::GEMMA3_12B;
|
config.model = Model::GEMMA3_12B;
|
||||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
|
config.use_global_timescale = true;
|
||||||
AddVitConfig(config, /*image_size=*/896);
|
AddVitConfig(config, /*image_size=*/896);
|
||||||
config.vocab_size = kGemmaV3VocabSize;
|
config.vocab_size = kGemmaV3VocabSize;
|
||||||
config.vit_config.pool_dim = 4;
|
config.vit_config.pool_dim = 4;
|
||||||
|
|
@ -386,6 +389,7 @@ static ModelConfig ConfigGemma3_27B() {
|
||||||
config.display_name = "Gemma3_27B";
|
config.display_name = "Gemma3_27B";
|
||||||
config.model = Model::GEMMA3_27B;
|
config.model = Model::GEMMA3_27B;
|
||||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
|
config.use_global_timescale = true;
|
||||||
AddVitConfig(config, /*image_size=*/896);
|
AddVitConfig(config, /*image_size=*/896);
|
||||||
config.vocab_size = kGemmaV3VocabSize;
|
config.vocab_size = kGemmaV3VocabSize;
|
||||||
config.vit_config.pool_dim = 4;
|
config.vit_config.pool_dim = 4;
|
||||||
|
|
@ -495,19 +499,19 @@ const char* ModelPrefix(Model model) {
|
||||||
}
|
}
|
||||||
|
|
||||||
PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
|
PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) {
|
||||||
if (IsPaliGemma(model)) {
|
const PromptWrapping config_wrapping = ConfigFromModel(model).wrapping;
|
||||||
|
|
||||||
|
// For models with a fixed wrapping mode, ignore user override.
|
||||||
|
if (config_wrapping == PromptWrapping::PALIGEMMA ||
|
||||||
|
config_wrapping == PromptWrapping::GEMMA_VLM) {
|
||||||
if (wrapping != Tristate::kDefault) {
|
if (wrapping != Tristate::kDefault) {
|
||||||
HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models.");
|
HWY_WARN("Ignoring unnecessary --wrapping for model %s.",
|
||||||
|
ModelPrefix(model));
|
||||||
}
|
}
|
||||||
return PromptWrapping::PALIGEMMA;
|
return config_wrapping;
|
||||||
}
|
}
|
||||||
if (IsVLM(model)) {
|
|
||||||
if (wrapping != Tristate::kDefault) {
|
// For other models, default to IT unless --wrapping=0 is passed.
|
||||||
HWY_WARN("Ignoring unnecessary --wrapping for VLM models.");
|
|
||||||
}
|
|
||||||
return PromptWrapping::GEMMA_VLM;
|
|
||||||
}
|
|
||||||
// Default to IT unless --wrapping=0.
|
|
||||||
return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
|
return wrapping == Tristate::kFalse ? PromptWrapping::GEMMA_PT
|
||||||
: PromptWrapping::GEMMA_IT;
|
: PromptWrapping::GEMMA_IT;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -184,13 +184,6 @@ enum class Model {
|
||||||
// in Specifier and thus does not change.
|
// in Specifier and thus does not change.
|
||||||
const char* ModelPrefix(Model model);
|
const char* ModelPrefix(Model model);
|
||||||
|
|
||||||
// Gemma3 is multimodal and has a different prompt wrapping than PaliGemma.
|
|
||||||
// This is used for deducing the PromptWrapping for pre-2025 BlobStore.
|
|
||||||
static inline bool IsVLM(Model model) {
|
|
||||||
return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B ||
|
|
||||||
model == Model::GEMMA3_12B || model == Model::GEMMA3_27B;
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline bool IsPaliGemma(Model model) {
|
static inline bool IsPaliGemma(Model model) {
|
||||||
if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
|
if (model == Model::PALIGEMMA2_3B_224 || model == Model::PALIGEMMA2_3B_448 ||
|
||||||
model == Model::PALIGEMMA2_10B_224 ||
|
model == Model::PALIGEMMA2_10B_224 ||
|
||||||
|
|
@ -280,7 +273,7 @@ struct LayerConfig : public IFields {
|
||||||
uint32_t kv_heads = 0;
|
uint32_t kv_heads = 0;
|
||||||
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
|
uint32_t qkv_dim = 0; // length of Q, K, V vectors (contiguous).
|
||||||
bool ff_biases = false;
|
bool ff_biases = false;
|
||||||
bool optimized_gating = true; // for Gemma3
|
bool optimized_gating = true; // for Gemma3
|
||||||
PostNormType post_norm = PostNormType::None;
|
PostNormType post_norm = PostNormType::None;
|
||||||
LayerAttentionType type = LayerAttentionType::kGemma;
|
LayerAttentionType type = LayerAttentionType::kGemma;
|
||||||
ActivationType activation = ActivationType::Gelu;
|
ActivationType activation = ActivationType::Gelu;
|
||||||
|
|
@ -383,6 +376,8 @@ struct ModelConfig : public IFields {
|
||||||
|
|
||||||
internal.VisitFields(visitor);
|
internal.VisitFields(visitor);
|
||||||
|
|
||||||
|
visitor(use_global_timescale);
|
||||||
|
|
||||||
// Append new fields here, then update `python/configs.cc`.
|
// Append new fields here, then update `python/configs.cc`.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -481,6 +476,7 @@ struct ModelConfig : public IFields {
|
||||||
std::vector<std::string> scale_base_names;
|
std::vector<std::string> scale_base_names;
|
||||||
|
|
||||||
InternalModelConfig internal;
|
InternalModelConfig internal;
|
||||||
|
bool use_global_timescale = false; // for Gemma 3
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns the sub-config for the ViT model of the PaliGemma model.
|
// Returns the sub-config for the ViT model of the PaliGemma model.
|
||||||
|
|
|
||||||
|
|
@ -173,6 +173,8 @@ PYBIND11_MODULE(configs, py_module) {
|
||||||
.def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id)
|
.def_readwrite("secondary_eos_id", &ModelConfig::secondary_eos_id)
|
||||||
.def_readwrite("scale_base_names", &ModelConfig::scale_base_names)
|
.def_readwrite("scale_base_names", &ModelConfig::scale_base_names)
|
||||||
.def_readwrite("internal", &ModelConfig::internal)
|
.def_readwrite("internal", &ModelConfig::internal)
|
||||||
|
.def_readwrite("use_global_timescale",
|
||||||
|
&ModelConfig::use_global_timescale)
|
||||||
|
|
||||||
.def("add_layer_config", &ModelConfig::AddLayerConfig,
|
.def("add_layer_config", &ModelConfig::AddLayerConfig,
|
||||||
arg("layer_config"))
|
arg("layer_config"))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue