Internal change.

PiperOrigin-RevId: 788463042
This commit is contained in:
Jeremiah Harmsen 2025-07-29 08:20:36 -07:00 committed by Copybara-Service
parent e76e29ce11
commit 33fabd4ed1
1 changed file with 9 additions and 7 deletions

View File

@ -30,6 +30,8 @@ namespace gcpp {
static constexpr size_t kVocabSize = 256000; static constexpr size_t kVocabSize = 256000;
static constexpr size_t kGemmaV3VocabSize = 262144;
static ModelConfig ConfigNoSSM() { static ModelConfig ConfigNoSSM() {
ModelConfig config; ModelConfig config;
config.scale_base_names = {"att_ein", "qkv_ein", "gr_lin_x_w", config.scale_base_names = {"att_ein", "qkv_ein", "gr_lin_x_w",
@ -309,7 +311,7 @@ static ModelConfig ConfigGemma3_1B() {
config.model = Model::GEMMA3_1B; config.model = Model::GEMMA3_1B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 1152; config.model_dim = 1152;
config.vocab_size = 262144; // new vocab size / tokenizer config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024; config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim); LayerConfig layer_config = LayerConfigGemma3_1B_LM(config.model_dim);
config.num_layers = 26; config.num_layers = 26;
@ -341,7 +343,7 @@ static ModelConfig ConfigGemma3_4B_LM() {
config.model = Model::GEMMA3_4B; config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 2560; config.model_dim = 2560;
config.vocab_size = 262144; // new vocab size / tokenizer config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024; config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim); LayerConfig layer_config = LayerConfigGemma3_4B_LM(config.model_dim);
config.num_layers = 34; config.num_layers = 34;
@ -359,7 +361,7 @@ static ModelConfig ConfigGemma3_4B() {
config.model = Model::GEMMA3_4B; config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144; config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
const size_t num_patches = const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width; config.vit_config.image_size / config.vit_config.patch_width;
@ -390,7 +392,7 @@ static ModelConfig ConfigGemma3_12B_LM() {
config.model = Model::GEMMA3_12B; config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 3840; config.model_dim = 3840;
config.vocab_size = 262144; // new vocab size / tokenizer config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024; config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim); LayerConfig layer_config = LayerConfigGemma3_12B_LM(config.model_dim);
config.num_layers = 48; config.num_layers = 48;
@ -408,7 +410,7 @@ static ModelConfig ConfigGemma3_12B() {
config.model = Model::GEMMA3_12B; config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144; config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
const size_t num_patches = const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width; config.vit_config.image_size / config.vit_config.patch_width;
@ -439,7 +441,7 @@ static ModelConfig ConfigGemma3_27B_LM() {
config.model = Model::GEMMA3_27B; config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 5376; config.model_dim = 5376;
config.vocab_size = 262144; // new vocab size / tokenizer config.vocab_size = kGemmaV3VocabSize; // new vocab size / tokenizer
config.max_seq_len = 32 * 1024; config.max_seq_len = 32 * 1024;
LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim); LayerConfig layer_config = LayerConfigGemma3_27B_LM(config.model_dim);
config.num_layers = 62; config.num_layers = 62;
@ -457,7 +459,7 @@ static ModelConfig ConfigGemma3_27B() {
config.model = Model::GEMMA3_27B; config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM; config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144; config.vocab_size = kGemmaV3VocabSize;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
const size_t num_patches = const size_t num_patches =
config.vit_config.image_size / config.vit_config.patch_width; config.vit_config.image_size / config.vit_config.patch_width;