Internal change

PiperOrigin-RevId: 794620076
This commit is contained in:
Phil Culliton 2025-08-13 09:47:05 -07:00 committed by Copybara-Service
parent 71406cf6d0
commit d044801c1d
3 changed files with 37 additions and 0 deletions

View File

@ -471,6 +471,37 @@ static ModelConfig ConfigGemma3_27B() {
return config;
}
static LayerConfig LayerConfigGemma3_270M_LM(size_t model_dim) {
LayerConfig config;
config.model_dim = model_dim;
config.ff_hidden_dim = 2048;
config.heads = 4;
config.kv_heads = 1;
config.qkv_dim = 256;
config.optimized_gating = true;
config.post_norm = PostNormType::Scale;
config.use_qk_norm = true;
return config;
}
static ModelConfig ConfigGemma3_270M() {
ModelConfig config = ConfigBaseGemmaV3();
config.display_name = "Gemma3_270M";
config.model = Model::GEMMA3_270M;
config.wrapping = PromptWrapping::GEMMA_IT;
config.model_dim = 640;
config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_270M_LM(config.model_dim);
config.num_layers = 18;
config.layer_configs = {config.num_layers, layer_config};
config.query_scale = QueryScaleType::SqrtKeySize;
// interleaved local / global attention
config.attention_window_sizes = RepeatedAttentionWindowSizes<18, 6>(
{512, 512, 512, 512, 512, config.max_seq_len});
return config;
}
static ModelConfig ConfigFromModel(Model model) {
switch (model) {
case Model::GEMMA2_2B:
@ -499,6 +530,8 @@ static ModelConfig ConfigFromModel(Model model) {
return ConfigGemma3_12B();
case Model::GEMMA3_27B:
return ConfigGemma3_27B();
case Model::GEMMA3_270M:
return ConfigGemma3_270M();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
@ -534,6 +567,8 @@ const char* ModelPrefix(Model model) {
return "gemma3-12b";
case Model::GEMMA3_27B:
return "gemma3-27b";
case Model::GEMMA3_270M:
return "gemma3-270m";
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}

View File

@ -175,6 +175,7 @@ enum class Model {
GEMMA3_1B,
GEMMA3_12B,
GEMMA3_27B,
GEMMA3_270M,
kSentinel,
};

View File

@ -91,6 +91,7 @@ PYBIND11_MODULE(configs, py_module) {
.value("PALIGEMMA2_10B_224", Model::PALIGEMMA2_10B_224)
.value("PALIGEMMA2_3B_448", Model::PALIGEMMA2_3B_448)
.value("PALIGEMMA2_10B_448", Model::PALIGEMMA2_10B_448)
.value("GEMMA3_270M", Model::GEMMA3_270M)
.value("PALIGEMMA_448", Model::PALIGEMMA_448);
class_<TensorInfo>(py_module, "TensorInfo")