Internal change.

PiperOrigin-RevId: 736014152
This commit is contained in:
Phil Culliton 2025-03-11 23:10:08 -07:00 committed by Copybara-Service
parent 2bdf26d81d
commit 9d83ff202e
3 changed files with 13 additions and 2 deletions

View File

@ -186,11 +186,16 @@ constexpr bool IsNuqStream() {
}
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };
enum class PromptWrapping {
GEMMA_IT,
GEMMA_PT,
PALIGEMMA,
kSentinel // must be last
};
inline bool EnumValid(PromptWrapping type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(PromptWrapping::PALIGEMMA);
static_cast<int>(type) < static_cast<int>(PromptWrapping::kSentinel);
}
// Tensor types for loading weights. Note that not all types are supported as

View File

@ -290,6 +290,7 @@ ModelConfig GetVitConfig(const ModelConfig& config) {
vit_config.model_dim = config.vit_config.model_dim;
vit_config.seq_len = config.vit_config.seq_len;
vit_config.layer_configs = config.vit_config.layer_configs;
vit_config.pool_dim = config.vit_config.pool_dim;
// The Vit part does not have a vocabulary, the image patches are embedded.
vit_config.vocab_size = 0;
return vit_config;

View File

@ -22,6 +22,7 @@
#include <algorithm>
#include <array>
#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>
@ -236,6 +237,7 @@ struct VitConfig : public IFields {
visitor(patch_width);
visitor(image_size);
visitor(layer_configs);
visitor(pool_dim);
}
uint32_t model_dim = 0;
@ -243,6 +245,7 @@ struct VitConfig : public IFields {
uint32_t num_scales = 0;
uint32_t patch_width = 14;
uint32_t image_size = 224;
uint32_t pool_dim = 1;
std::vector<LayerConfig> layer_configs;
};
@ -304,6 +307,7 @@ struct ModelConfig : public IFields {
visitor(attention_window_sizes);
visitor(norm_num_groups);
visitor(vit_config);
visitor(pool_dim);
}
// Major version of the model family. It is used as a fallback to distinguish
@ -329,6 +333,7 @@ struct ModelConfig : public IFields {
uint32_t norm_num_groups = 1;
// Dimensions related to image processing.
VitConfig vit_config;
uint32_t pool_dim = 1; // used only for VitConfig copy
};
// Returns the config for the given model.