Add the Gemma3 270M model configuration.

Touches gemma/configs.cc (new layer/model config functions plus the
ConfigFromModel and ModelPrefix switch cases), gemma/configs.h (the
Model::GEMMA3_270M enum value) and python/configs.cc (pybind enum entry).

NOTE(review): line breaks and indentation were reconstructed from the hunk
line counts; confirm context-line whitespace against the tree before applying.
NOTE(review): the template arguments in static_cast<int>(model) and
class_<TensorInfo>(...) were missing from the mangled text and have been
restored; the latter is inferred from the class name string - verify.

diff --git a/gemma/configs.cc b/gemma/configs.cc
index 562500d..40b95bb 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -471,6 +471,37 @@ static ModelConfig ConfigGemma3_27B() {
   return config;
 }
 
+static LayerConfig LayerConfigGemma3_270M_LM(size_t model_dim) {
+  LayerConfig config;
+  config.model_dim = model_dim;
+  config.ff_hidden_dim = 2048;
+  config.heads = 4;
+  config.kv_heads = 1;
+  config.qkv_dim = 256;
+  config.optimized_gating = true;
+  config.post_norm = PostNormType::Scale;
+  config.use_qk_norm = true;
+  return config;
+}
+
+static ModelConfig ConfigGemma3_270M() {
+  ModelConfig config = ConfigBaseGemmaV3();
+  config.display_name = "Gemma3_270M";
+  config.model = Model::GEMMA3_270M;
+  config.wrapping = PromptWrapping::GEMMA_IT;
+  config.model_dim = 640;
+  config.vocab_size = kGemmaV3VocabSize;  // new vocab size / tokenizer
+  config.max_seq_len = 32 * 1024;
+  LayerConfig layer_config = LayerConfigGemma3_270M_LM(config.model_dim);
+  config.num_layers = 18;
+  config.layer_configs = {config.num_layers, layer_config};
+  config.query_scale = QueryScaleType::SqrtKeySize;
+  // interleaved local / global attention
+  config.attention_window_sizes = RepeatedAttentionWindowSizes<18, 6>(
+      {512, 512, 512, 512, 512, config.max_seq_len});
+  return config;
+}
+
 static ModelConfig ConfigFromModel(Model model) {
   switch (model) {
     case Model::GEMMA2_2B:
@@ -499,6 +530,8 @@ static ModelConfig ConfigFromModel(Model model) {
       return ConfigGemma3_12B();
     case Model::GEMMA3_27B:
       return ConfigGemma3_27B();
+    case Model::GEMMA3_270M:
+      return ConfigGemma3_270M();
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
@@ -534,6 +567,8 @@ const char* ModelPrefix(Model model) {
       return "gemma3-12b";
     case Model::GEMMA3_27B:
       return "gemma3-27b";
+    case Model::GEMMA3_270M:
+      return "gemma3-270m";
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
   }
diff --git a/gemma/configs.h b/gemma/configs.h
index 19e6278..a3a3114 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -175,6 +175,7 @@ enum class Model {
   GEMMA3_1B,
   GEMMA3_12B,
   GEMMA3_27B,
+  GEMMA3_270M,
   kSentinel,
 };
 
diff --git a/python/configs.cc b/python/configs.cc
index b9a4bf6..36cd314 100644
--- a/python/configs.cc
+++ b/python/configs.cc
@@ -91,6 +91,7 @@ PYBIND11_MODULE(configs, py_module) {
       .value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224)
       .value("PALIGEMMA2_3B_448", Model::PALIGEMMA2_3B_448)
       .value("PALIGEMMA2_10B_448", Model::PALIGEMMA2_10B_448)
+      .value("GEMMA3_270M", Model::GEMMA3_270M)
       .value("PALIGEMMA_448", Model::PALIGEMMA_448);
 
   class_<TensorInfo>(py_module, "TensorInfo")