Improves autodetection of Gemma3-1B.

Uses the key_norm and query_norm layers to disambiguate between the Gemma2-2B and Gemma3-1B models.
Since Gemma3-1B is not multimodal, the presence of ViT layers cannot distinguish it from Gemma2-2B. KQ normalization, by contrast, is a structural difference between Gemma2 and Gemma3, so it serves as a reliable disambiguator.

PiperOrigin-RevId: 833213331
This commit is contained in:
The gemma.cpp Authors 2025-11-17 01:12:07 -08:00 committed by Copybara-Service
parent 7c1656f2fc
commit b8f6be72b1
3 changed files with 15 additions and 1 deletions

View File

@ -678,7 +678,9 @@ Model DeduceModel(const Path& blob_path, size_t layers, int layer_types) {
return Model::GEMMA3_270M;
case 26:
if (layer_types & kDeducedViT) return Model::GEMMA3_1B;
if (layer_types & (kDeducedViT|kDeducedKqNorm)) {
return Model::GEMMA3_1B;
}
return Model::GEMMA2_2B;
case 27:
return (layer_types & kDeduced448) ? Model::PALIGEMMA2_3B_448

View File

@ -513,6 +513,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);
// Bit flags recording which layer kinds were deduced from the blob keys.
// Each enumerator occupies its own bit so several can be OR-ed into one int.
enum DeducedLayerTypes {
  kDeducedViT = 1 << 1,
  kDeduced448 = 1 << 2,  // For ViT, 448x448 resolution instead of 224x224.
  kDeducedKqNorm = 1 << 3,  // Both key_norm and query_norm keys are present.
};
// layer_types is a bitwise OR of zero or more `DeducedLayerTypes`.

View File

@ -221,6 +221,8 @@ static size_t DeduceNumLayers(const KeyVec& keys) {
// This works with or without type prefixes because it searches for substrings.
static int DeduceLayerTypes(const BlobReader& reader) {
int layer_types = 0;
bool has_key_norm = false;
bool has_query_norm = false;
for (size_t key_idx = 0; key_idx < reader.Keys().size(); ++key_idx) {
const std::string& key = reader.Keys()[key_idx];
if (key.find("qkv_ein_w") != std::string::npos) { // NOLINT
@ -232,6 +234,15 @@ static int DeduceLayerTypes(const BlobReader& reader) {
layer_types |= kDeduced448;
}
}
if (key.find("key_norm") != std::string::npos) { // NOLINT
has_key_norm = true;
}
if (key.find("query_norm") != std::string::npos) { // NOLINT
has_query_norm = true;
}
}
if (has_key_norm && has_query_norm) {
layer_types |= kDeducedKqNorm;
}
return layer_types;
}