diff --git a/gemma/configs.cc b/gemma/configs.cc index b8048b8..70fd8fa 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -678,7 +678,9 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) { return Model::GEMMA3_270M; case 26: - if (layer_types & kDeducedViT) return Model::GEMMA3_1B; + if (layer_types & (kDeducedViT|kDeducedKqNorm)) { + return Model::GEMMA3_1B; + } return Model::GEMMA2_2B; case 27: return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448 diff --git a/gemma/configs.h b/gemma/configs.h index e2cb5e2..5de74bf 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -513,6 +513,7 @@ ModelConfig GetVitConfig(const ModelConfig& config); enum DeducedLayerTypes { kDeducedViT = 2, kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224. + kDeducedKqNorm = 8, }; // layer_types is one or more of `DeducedLayerTypes`. diff --git a/gemma/model_store.cc b/gemma/model_store.cc index 2f3e1ec..204dee9 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -221,6 +221,8 @@ static size_t DeduceNumLayers(const KeyVec& keys) { // This works with or without type prefixes because it searches for substrings. static int DeduceLayerTypes(const BlobReader& reader) { int layer_types = 0; + bool has_key_norm = false; + bool has_query_norm = false; for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) { const std::string& key = reader.Keys()[key_idx]; if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT @@ -232,6 +234,15 @@ static int DeduceLayerTypes(const BlobReader& reader) { layer_types |= kDeduced448; } } + if (key.find("key_norm") != std::string::npos) { // NOLINT + has_key_norm = true; + } + if (key.find("query_norm") != std::string::npos) { // NOLINT + has_query_norm = true; + } + } + if (has_key_norm && has_query_norm) { + layer_types |= kDeducedKqNorm; } return layer_types; }