From 9d83ff202ec9c4ea7983ea7820981471b0e0a5d1 Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Tue, 11 Mar 2025 23:10:08 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 736014152 --- compression/shared.h | 9 +++++++-- gemma/configs.cc | 1 + gemma/configs.h | 5 +++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/compression/shared.h b/compression/shared.h index a947fee..29b07cc 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -186,11 +186,16 @@ constexpr bool IsNuqStream() { } // Instruction-tuned models require extra 'turn structure' tokens in prompts. -enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA }; +enum class PromptWrapping { + GEMMA_IT, + GEMMA_PT, + PALIGEMMA, + kSentinel // must be last +}; inline bool EnumValid(PromptWrapping type) { return static_cast(type) >= 0 && - static_cast(type) <= static_cast(PromptWrapping::PALIGEMMA); + static_cast(type) < static_cast(PromptWrapping::kSentinel); } // Tensor types for loading weights. Note that not all types are supported as diff --git a/gemma/configs.cc b/gemma/configs.cc index 2c4f887..9372ee1 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -290,6 +290,7 @@ ModelConfig GetVitConfig(const ModelConfig& config) { vit_config.model_dim = config.vit_config.model_dim; vit_config.seq_len = config.vit_config.seq_len; vit_config.layer_configs = config.vit_config.layer_configs; + vit_config.pool_dim = config.vit_config.pool_dim; // The Vit part does not have a vocabulary, the image patches are embedded. vit_config.vocab_size = 0; return vit_config; diff --git a/gemma/configs.h b/gemma/configs.h index 42693e6..d7078b3 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -236,6 +237,7 @@ struct VitConfig : public IFields { visitor(patch_width); visitor(image_size); visitor(layer_configs); + visitor(pool_dim); } uint32_t model_dim = 0; @@ -243,6 +245,7 @@ struct VitConfig : public IFields { uint32_t num_scales = 0; uint32_t patch_width = 14; uint32_t image_size = 224; + uint32_t pool_dim = 1; std::vector layer_configs; }; @@ -304,6 +307,7 @@ struct ModelConfig : public IFields { visitor(attention_window_sizes); visitor(norm_num_groups); visitor(vit_config); + visitor(pool_dim); } // Major version of the model family. It is used as a fallback to distinguish @@ -329,6 +333,7 @@ struct ModelConfig : public IFields { uint32_t norm_num_groups = 1; // Dimensions related to image processing. VitConfig vit_config; + uint32_t pool_dim = 1; // used only for VitConfig copy }; // Returns the config for the given model.