// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "gemma/configs.h" #include #include #include #include #include "compression/types.h" // Type #include "io/fields.h" // IFields #include "io/io.h" // Path #include "hwy/base.h" namespace gcpp { static constexpr size_t kVocabSize = 256000; static constexpr size_t kGemmaV3VocabSize = 262144; static ModelConfig ConfigNoSSM() { ModelConfig config; config.scale_base_names = {"att_ein", "qkv_ein", "gr_lin_x_w", "gr_lin_y_w", "gr_lin_out_w", "gr_gate_w", "gating_ein", "linear_w"}; return config; } static ModelConfig ConfigBaseGemmaV2() { ModelConfig config = ConfigNoSSM(); config.att_cap = 50.0f; config.final_cap = 30.0f; config.eos_id = 1; config.secondary_eos_id = 107; return config; } static LayerConfig LayerConfigGemma2_27B(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; config.ff_hidden_dim = 16 * 4608 / 2; // = 36864 config.heads = 32; config.kv_heads = 16; config.qkv_dim = 128; config.optimized_gating = false; config.post_norm = PostNormType::Scale; return config; } static ModelConfig ConfigGemma2_27B() { ModelConfig config = ConfigBaseGemmaV2(); config.display_name = "Gemma2_27B"; config.model = Model::GEMMA2_27B; config.model_dim = 4608; config.vocab_size = kVocabSize; config.max_seq_len = 8192; LayerConfig layer_config = LayerConfigGemma2_27B(config.model_dim); config.num_layers = 46; config.layer_configs = 
{config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtModelDimDivNumHeads; config.attention_window_sizes = RepeatedAttentionWindowSizes<46, 2>({4096, config.max_seq_len}); return config; } static LayerConfig LayerConfigGemma2_9B(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; config.ff_hidden_dim = 8 * 3584 / 2; // = 14336 config.heads = 16; config.kv_heads = 8; config.qkv_dim = 256; config.optimized_gating = false; config.post_norm = PostNormType::Scale; return config; } static ModelConfig ConfigGemma2_9B() { ModelConfig config = ConfigBaseGemmaV2(); config.display_name = "Gemma2_9B"; config.model = Model::GEMMA2_9B; config.model_dim = 3584; config.vocab_size = kVocabSize; config.max_seq_len = 8192; LayerConfig layer_config = LayerConfigGemma2_9B(config.model_dim); config.num_layers = 42; config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; config.attention_window_sizes = RepeatedAttentionWindowSizes<42, 2>({4096, config.max_seq_len}); return config; } static LayerConfig LayerConfigGemma2_2B(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; config.ff_hidden_dim = 8 * 2304 / 2; // = 9216 config.heads = 8; config.kv_heads = 4; config.qkv_dim = 256; config.optimized_gating = false; config.post_norm = PostNormType::Scale; return config; } static ModelConfig ConfigGemma2_2B() { ModelConfig config = ConfigBaseGemmaV2(); config.display_name = "Gemma2_2B"; config.model = Model::GEMMA2_2B; config.model_dim = 2304; config.vocab_size = kVocabSize; config.max_seq_len = 8192; LayerConfig layer_config = LayerConfigGemma2_2B(config.model_dim); config.num_layers = 26; config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; config.attention_window_sizes = RepeatedAttentionWindowSizes<26, 2>({4096, config.max_seq_len}); return config; } static LayerConfig LayerConfigGemmaTiny(size_t model_dim) { LayerConfig 
config; config.model_dim = model_dim; config.ff_hidden_dim = 256; config.heads = 4; config.kv_heads = 1; config.qkv_dim = 16; return config; } static ModelConfig ConfigGemmaTiny() { ModelConfig config = ConfigNoSSM(); config.display_name = "GemmaTiny"; config.model = Model::GEMMA_TINY; config.wrapping = PromptWrapping::GEMMA_IT; config.model_dim = 32; config.vocab_size = 32; // at least two f32 vectors config.max_seq_len = 32; LayerConfig layer_config = LayerConfigGemmaTiny(config.model_dim); config.num_layers = 2; config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; config.attention_window_sizes = FixedAttentionWindowSizes<2>(32); config.att_cap = 50.0f; config.final_cap = 30.0f; config.eos_id = 11; config.secondary_eos_id = 11; return config; } static LayerConfig LayerConfigGriffin2B(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; config.griffin_dim = model_dim; config.ff_hidden_dim = 7680; config.heads = 10; config.kv_heads = 1; config.qkv_dim = 256; config.conv1d_width = 4; HWY_DASSERT(config.conv1d_width <= kMaxConv1DWidth); config.ff_biases = true; config.softmax_attn_output_biases = true; config.optimized_gating = false; config.type = LayerAttentionType::kGriffinRecurrentBlock; config.activation = ActivationType::Gelu; config.post_qk = PostQKType::HalfRope; return config; } static ModelConfig ConfigGriffin2B() { ModelConfig config = ConfigNoSSM(); config.display_name = "Griffin2B"; config.model = Model::GRIFFIN_2B; // Griffin uses local attention, so max_seq_len is actually the local // attention window. 
config.model_dim = 2560; config.vocab_size = kVocabSize; config.max_seq_len = 2048; LayerConfig layer_config = LayerConfigGriffin2B(config.model_dim); config.num_layers = 26; config.layer_configs = {config.num_layers, layer_config}; for (size_t i = 2; i < config.num_layers; i += 3) { config.layer_configs[i].type = LayerAttentionType::kGemma; config.layer_configs[i].griffin_dim = 0; } config.attention_window_sizes = FixedAttentionWindowSizes<26>(config.max_seq_len); config.use_local_attention = true; config.final_cap = 0.0f; return config; } static LayerConfig LayerConfigVit(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; config.ff_hidden_dim = 4304; config.heads = 16; config.kv_heads = 16; config.qkv_dim = 72; config.ff_biases = true; config.type = LayerAttentionType::kVit; return config; } // Adds a ViT config (SigLIP SoViT ViT, used in PaliGemma) to the model config. static void AddVitConfig(ModelConfig& config, size_t image_size = 224) { config.vit_config.model_dim = 1152; config.vocab_size = 256000 + 1024 + 128; // = 257152 config.vit_config.image_size = image_size; config.vit_config.patch_width = 14; const size_t num_patches = config.vit_config.image_size / config.vit_config.patch_width; config.vit_config.seq_len = num_patches * num_patches; for (auto& layer_config : config.layer_configs) { layer_config.optimized_gating = false; } LayerConfig vit_layer_config = LayerConfigVit(config.vit_config.model_dim); config.vit_config.layer_configs = {27, vit_layer_config}; config.vit_config.num_scales = 4 * config.vit_config.layer_configs.size(); } ModelConfig GetVitConfig(const ModelConfig& config) { ModelConfig vit_config = ConfigNoSSM(); vit_config.model_dim = config.vit_config.model_dim; vit_config.max_seq_len = config.vit_config.seq_len; vit_config.layer_configs = config.vit_config.layer_configs; vit_config.pool_dim = config.vit_config.pool_dim; vit_config.wrapping = config.wrapping; // The Vit part does not have a vocabulary, the image patches 
are embedded. vit_config.vocab_size = 0; return vit_config; } static ModelConfig ConfigPaliGemma2_3B_224() { ModelConfig config = ConfigGemma2_2B(); config.display_name = "PaliGemma2_3B_224"; config.model = Model::PALIGEMMA2_3B_224; config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config); return config; } static ModelConfig ConfigPaliGemma2_3B_448() { ModelConfig config = ConfigGemma2_2B(); config.display_name = "PaliGemma2_3B_448"; config.model = Model::PALIGEMMA2_3B_448; config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config, /*image_size=*/448); return config; } static ModelConfig ConfigPaliGemma2_10B_224() { ModelConfig config = ConfigGemma2_9B(); config.display_name = "PaliGemma2_10B_224"; config.model = Model::PALIGEMMA2_10B_224; config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config); return config; } static ModelConfig ConfigPaliGemma2_10B_448() { ModelConfig config = ConfigGemma2_9B(); config.display_name = "PaliGemma2_10B_448"; config.model = Model::PALIGEMMA2_10B_448; config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config, /*image_size=*/448); return config; } static ModelConfig ConfigBaseGemmaV3() { ModelConfig config = ConfigNoSSM(); config.att_cap = 0.0f; config.final_cap = 0.0f; config.eos_id = 1; config.secondary_eos_id = 106; return config; } // 1B does not include a vision encoder. 
// Per-layer settings of the Gemma 3 1B language model for `model_dim`.
static LayerConfig LayerConfigGemma3_1B_LM(size_t model_dim) {
  LayerConfig layer;

  // Widths.
  layer.model_dim = model_dim;
  layer.qkv_dim = 256;
  layer.ff_hidden_dim = 6912;

  // Head counts.
  layer.heads = 4;
  layer.kv_heads = 1;

  // Normalization and gating.
  layer.use_qk_norm = true;
  layer.post_norm = PostNormType::Scale;
  layer.optimized_gating = true;

  return layer;
}

// Builds the ModelConfig for Model::GEMMA3_1B: 26 identical layers on top of
// the shared Gemma 3 base settings.
static ModelConfig ConfigGemma3_1B() {
  ModelConfig cfg = ConfigBaseGemmaV3();

  // Identity.
  cfg.model = Model::GEMMA3_1B;
  cfg.display_name = "Gemma3_1B";
  cfg.wrapping = PromptWrapping::GEMMA_VLM;

  // Sizes.
  cfg.vocab_size = kGemmaV3VocabSize;  // new vocab size / tokenizer
  cfg.max_seq_len = 32 * 1024;
  cfg.model_dim = 1152;
  cfg.num_layers = 26;

  // All layers share one LayerConfig.
  const LayerConfig per_layer = LayerConfigGemma3_1B_LM(cfg.model_dim);
  cfg.layer_configs = {cfg.num_layers, per_layer};

  cfg.query_scale = QueryScaleType::SqrtKeySize;
  // interleaved local / global attention
  cfg.attention_window_sizes = RepeatedAttentionWindowSizes<26, 6>(
      {512, 512, 512, 512, 512, cfg.max_seq_len});
  return cfg;
}

// Per-layer settings of the Gemma 3 4B language model for `model_dim`.
static LayerConfig LayerConfigGemma3_4B_LM(size_t model_dim) {
  LayerConfig layer;

  // Widths.
  layer.model_dim = model_dim;
  layer.qkv_dim = 256;
  layer.ff_hidden_dim = 8 * 2560 / 2;  // = 10240

  // Head counts.
  layer.heads = 8;
  layer.kv_heads = 4;

  // Normalization and gating.
  layer.use_qk_norm = true;
  layer.post_norm = PostNormType::Scale;
  layer.optimized_gating = true;

  return layer;
}

// Until we have the SigLIP checkpoints included, we use the LM config directly.
static ModelConfig ConfigGemma3_4B_LM() { ModelConfig config = ConfigBaseGemmaV3(); config.display_name = "Gemma3_4B"; config.model = Model::GEMMA3_4B; config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 2560; config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer config.max_seq_len = 32 * 1024; LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim); config.num_layers = 34; config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; // interleaved local / global attention config.attention_window_sizes = RepeatedAttentionWindowSizes<34, 6>( {1024, 1024, 1024, 1024, 1024, config.max_seq_len}); return config; } static ModelConfig ConfigGemma3_4B() { ModelConfig config = ConfigGemma3_4B_LM(); config.display_name = "Gemma3_4B"; config.model = Model::GEMMA3_4B; config.wrapping = PromptWrapping::GEMMA_VLM; AddVitConfig(config, /*image_size=*/896); config.vocab_size = kGemmaV3VocabSize; config.vit_config.pool_dim = 4; const size_t num_patches = config.vit_config.image_size / config.vit_config.patch_width; config.vit_config.seq_len = (num_patches * num_patches); // The above resets optimized gating to false; for Gemma 3 it should be true. 
for (auto& layer_config : config.layer_configs) { layer_config.optimized_gating = true; } return config; } static LayerConfig LayerConfigGemma3_12B_LM(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; config.ff_hidden_dim = 15360; config.heads = 16; config.kv_heads = 8; config.qkv_dim = 256; config.optimized_gating = true; config.post_norm = PostNormType::Scale; config.use_qk_norm = true; return config; } static ModelConfig ConfigGemma3_12B_LM() { ModelConfig config = ConfigBaseGemmaV3(); config.display_name = "Gemma3_12B"; config.model = Model::GEMMA3_12B; config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 3840; config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer config.max_seq_len = 32 * 1024; LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim); config.num_layers = 48; config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; // interleaved local / global attention config.attention_window_sizes = RepeatedAttentionWindowSizes<48, 6>( {1024, 1024, 1024, 1024, 1024, config.max_seq_len}); return config; } static ModelConfig ConfigGemma3_12B() { ModelConfig config = ConfigGemma3_12B_LM(); config.display_name = "Gemma3_12B"; config.model = Model::GEMMA3_12B; config.wrapping = PromptWrapping::GEMMA_VLM; AddVitConfig(config, /*image_size=*/896); config.vocab_size = kGemmaV3VocabSize; config.vit_config.pool_dim = 4; const size_t num_patches = config.vit_config.image_size / config.vit_config.patch_width; config.vit_config.seq_len = (num_patches * num_patches); // The above resets optimized gating to false; for Gemma 3 it should be true. 
for (auto& layer_config : config.layer_configs) { layer_config.optimized_gating = true; } return config; } static LayerConfig LayerConfigGemma3_27B_LM(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; config.ff_hidden_dim = 21504; config.heads = 32; config.kv_heads = 16; config.qkv_dim = 128; config.optimized_gating = true; config.post_norm = PostNormType::Scale; config.use_qk_norm = true; return config; } static ModelConfig ConfigGemma3_27B_LM() { ModelConfig config = ConfigBaseGemmaV3(); config.display_name = "Gemma3_27B"; config.model = Model::GEMMA3_27B; config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 5376; config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer config.max_seq_len = 32 * 1024; LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim); config.num_layers = 62; config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; // interleaved local / global attention config.attention_window_sizes = RepeatedAttentionWindowSizes<62, 6>( {1024, 1024, 1024, 1024, 1024, config.max_seq_len}); return config; } static ModelConfig ConfigGemma3_27B() { ModelConfig config = ConfigGemma3_27B_LM(); config.display_name = "Gemma3_27B"; config.model = Model::GEMMA3_27B; config.wrapping = PromptWrapping::GEMMA_VLM; AddVitConfig(config, /*image_size=*/896); config.vocab_size = kGemmaV3VocabSize; config.vit_config.pool_dim = 4; const size_t num_patches = config.vit_config.image_size / config.vit_config.patch_width; config.vit_config.seq_len = (num_patches * num_patches); // The above resets optimized gating to false; for Gemma 3 it should be true. 
for (auto& layer_config : config.layer_configs) { layer_config.optimized_gating = true; } return config; } static LayerConfig LayerConfigGemma3_270M_LM(size_t model_dim) { LayerConfig config; config.model_dim = model_dim; config.ff_hidden_dim = 2048; config.heads = 4; config.kv_heads = 1; config.qkv_dim = 256; config.optimized_gating = true; config.post_norm = PostNormType::Scale; config.use_qk_norm = true; return config; } static ModelConfig ConfigGemma3_270M() { ModelConfig config = ConfigBaseGemmaV3(); config.display_name = "Gemma3_270M"; config.model = Model::GEMMA3_270M; config.wrapping = PromptWrapping::GEMMA_IT; config.model_dim = 640; config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer config.max_seq_len = 32 * 1024; LayerConfig layer_config = LayerConfigGemma3_270M_LM(config.model_dim); config.num_layers = 18; config.layer_configs = {config.num_layers, layer_config}; config.query_scale = QueryScaleType::SqrtKeySize; // interleaved local / global attention config.attention_window_sizes = RepeatedAttentionWindowSizes<18, 6>( {512, 512, 512, 512, 512, config.max_seq_len}); return config; } static ModelConfig ConfigFromModel(Model model) { switch (model) { case Model::GEMMA2_2B: return ConfigGemma2_2B(); case Model::GEMMA2_9B: return ConfigGemma2_9B(); case Model::GEMMA2_27B: return ConfigGemma2_27B(); case Model::GRIFFIN_2B: return ConfigGriffin2B(); case Model::GEMMA_TINY: return ConfigGemmaTiny(); case Model::PALIGEMMA2_3B_224: return ConfigPaliGemma2_3B_224(); case Model::PALIGEMMA2_3B_448: return ConfigPaliGemma2_3B_448(); case Model::PALIGEMMA2_10B_224: return ConfigPaliGemma2_10B_224(); case Model::PALIGEMMA2_10B_448: return ConfigPaliGemma2_10B_448(); case Model::GEMMA3_4B: return ConfigGemma3_4B(); case Model::GEMMA3_1B: return ConfigGemma3_1B(); case Model::GEMMA3_12B: return ConfigGemma3_12B(); case Model::GEMMA3_27B: return ConfigGemma3_27B(); case Model::GEMMA3_270M: return ConfigGemma3_270M(); default: HWY_ABORT("Model type %d 
unknown.", static_cast(model)); } } const char* ModelPrefix(Model model) { switch (model) { case Model::UNKNOWN: return "unknown"; case Model::GEMMA2_2B: return "gemma2-2b"; case Model::GEMMA2_9B: return "9b"; case Model::GEMMA2_27B: return "27b"; case Model::GRIFFIN_2B: return "gr2b"; case Model::GEMMA_TINY: return "tiny"; case Model::PALIGEMMA2_3B_224: return "paligemma2-3b-224"; case Model::PALIGEMMA2_3B_448: return "paligemma2-3b-448"; case Model::PALIGEMMA2_10B_224: return "paligemma2-10b-224"; case Model::PALIGEMMA2_10B_448: return "paligemma2-10b-448"; case Model::GEMMA3_4B: return "gemma3-4b"; case Model::GEMMA3_1B: return "gemma3-1b"; case Model::GEMMA3_12B: return "gemma3-12b"; case Model::GEMMA3_27B: return "gemma3-27b"; case Model::GEMMA3_270M: return "gemma3-270m"; default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } } PromptWrapping ChooseWrapping(const Model model, Tristate wrapping) { if (IsPaliGemma(model)) { if (wrapping != Tristate::kDefault) { HWY_WARN("Ignoring unnecessary --wrapping for PaliGemma models."); } return PromptWrapping::PALIGEMMA; } if (IsVLM(model)) { if (wrapping != Tristate::kDefault) { HWY_WARN("Ignoring unnecessary --wrapping for VLM models."); } return PromptWrapping::GEMMA_VLM; } // Default to IT unless --wrapping=0. return wrapping == Tristate::kFalse ? 
PromptWrapping::GEMMA_PT : PromptWrapping::GEMMA_IT; } ModelConfig::ModelConfig(const Model model, Type weight, PromptWrapping wrapping) { HWY_ASSERT(weight != Type::kUnknown); HWY_ASSERT(wrapping != PromptWrapping::kSentinel); this->model = model; if (model != Model::UNKNOWN) *this = ConfigFromModel(model); HWY_ASSERT(this->model == model); this->weight = weight; this->wrapping = wrapping; } static Model FindModel(const std::string& specifier) { Model found_model = Model::UNKNOWN; ForEachModel([&](Model model) { // Some model names are prefixes of other model names const std::string prefix = std::string(ModelPrefix(model)) + "-"; if (specifier.rfind(prefix, 0) == 0) { // Starts with prefix. // We only expect one match. HWY_ASSERT_M(found_model == Model::UNKNOWN, specifier.c_str()); found_model = model; } }); HWY_ASSERT_M(found_model != Model::UNKNOWN, specifier.c_str()); return found_model; } static Type FindType(const std::string& specifier) { Type found_type = Type::kUnknown; for (size_t i = 1; i < kNumTypes; ++i) { const Type type = static_cast(i); if (specifier.find(TypeName(type)) != std::string::npos) { // NOLINT // We only expect one match. HWY_ASSERT_M(found_type == Type::kUnknown, specifier.c_str()); found_type = type; } } HWY_ASSERT_M(found_type != Type::kUnknown, specifier.c_str()); return found_type; } static PromptWrapping FindWrapping(const std::string& specifier) { PromptWrapping found_wrapping = PromptWrapping::kSentinel; for (size_t i = 0; i < static_cast(PromptWrapping::kSentinel); ++i) { const PromptWrapping w = static_cast(i); if (specifier.find(WrappingSuffix(w)) != std::string::npos) { // NOLINT // We expect zero or one match. HWY_ASSERT_M(found_wrapping == PromptWrapping::kSentinel, specifier.c_str()); found_wrapping = w; } } if (found_wrapping == PromptWrapping::kSentinel) { return ChooseWrapping(FindModel(specifier)); } return found_wrapping; } // Obtains model/weight/wrapping by finding prefix and suffix strings. 
ModelConfig::ModelConfig(const std::string& specifier) : ModelConfig(FindModel(specifier), FindType(specifier), FindWrapping(specifier)) {} std::string ModelConfig::Specifier() const { HWY_ASSERT(model != Model::UNKNOWN); HWY_ASSERT(weight != Type::kUnknown); HWY_ASSERT(wrapping != PromptWrapping::kSentinel); std::string base_name = ModelPrefix(model); base_name += '-'; base_name += TypeName(weight); if (wrapping != PromptWrapping::GEMMA_VLM && wrapping != PromptWrapping::PALIGEMMA) { base_name += WrappingSuffix(wrapping); } return base_name; } // Returns whether all fields match. static bool AllEqual(const IFields& a, const IFields& b, bool print) { const std::vector serialized_a = a.Write(); const std::vector serialized_b = b.Write(); if (serialized_a != serialized_b) { if (print) { fprintf(stderr, "%s differs. Recommend generating a diff:\n", a.Name()); a.Print(); b.Print(); } return false; } return true; } bool LayerConfig::TestEqual(const LayerConfig& other, bool print) const { return AllEqual(*this, other, print); } bool VitConfig::TestEqual(const VitConfig& other, bool print) const { return AllEqual(*this, other, print); } bool ModelConfig::TestEqual(const ModelConfig& other, bool print) const { // Early out to guard the loop below; a differing number of layers will anyway // cause a mismatch. if (layer_configs.size() != other.layer_configs.size()) { if (print) { HWY_WARN("Layer configs size mismatch %zu vs %zu", layer_configs.size(), other.layer_configs.size()); } return false; } // Copy so we can 'ignore' fields by setting them to the same value. ModelConfig a = *this; ModelConfig b = other; // Called by `OverwriteWithCanonical`, so ignore the fields it will set. 
a.display_name = b.display_name; a.model = b.model; // The following are not yet set by config_converter.py, so we here ignore // them for purposes of comparison, and there overwrite the converter's config // with the canonical ModelConfig constructed via (deduced) enum, so that // these fields will be set. // `vit_config` is also not yet set, but we must not ignore it because // otherwise PaliGemma models will be indistinguishable for `configs_test`. a.pool_dim = b.pool_dim; // ViT a.eos_id = b.eos_id; a.secondary_eos_id = b.secondary_eos_id; a.scale_base_names = b.scale_base_names; for (size_t i = 0; i < a.layer_configs.size(); ++i) { a.layer_configs[i].optimized_gating = b.layer_configs[i].optimized_gating; } return AllEqual(a, b, print); } // Constructs the canonical ModelConfig for each model. If there is one for // which TestEqual returns true, overwrites `*this` with that and returns true. bool ModelConfig::OverwriteWithCanonical() { bool found = false; const bool print = false; ForEachModel([&](Model model) { const ModelConfig config(model, weight, wrapping); if (config.TestEqual(*this, print)) { HWY_ASSERT(!found); // Should only find one. found = true; *this = config; } }); return found; } Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { switch (layers) { case 2: return Model::GEMMA_TINY; case 18: return Model::GEMMA3_270M; case 26: if (layer_types & kDeducedGriffin) return Model::GRIFFIN_2B; if (layer_types & kDeducedViT) return Model::GEMMA3_1B; return Model::GEMMA2_2B; case 27: return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448 : Model::PALIGEMMA2_3B_224; case 34: return Model::GEMMA3_4B; case 42: if (layer_types & kDeducedViT) { return (layer_types & kDeduced448) ? Model::PALIGEMMA2_10B_448 : Model::PALIGEMMA2_10B_224; } return Model::GEMMA2_9B; case 46: return Model::GEMMA2_27B; case 48: return Model::GEMMA3_12B; case 62: return Model::GEMMA3_27B; // TODO: detect these. 
/* return Model::GEMMA2_772M; return Model::PALIGEMMA2_772M_224; */ default: HWY_WARN("Failed to deduce model type from %s, layer count %zu types %x.", blob_path.path.c_str(), layers, layer_types); return Model::UNKNOWN; } } } // namespace gcpp