Internal change

PiperOrigin-RevId: 702961613
This commit is contained in:
Phil Culliton 2024-12-04 20:41:07 -08:00 committed by Copybara-Service
parent 6a34e9c547
commit 9dfe2a76be
3 changed files with 29 additions and 1 deletions

View File

@ -39,6 +39,8 @@ constexpr const char* kModelFlags[] = {
"9b-pt", "9b-it", // Gemma2 9B "9b-pt", "9b-it", // Gemma2 9B
"27b-pt", "27b-it", // Gemma2 27B "27b-pt", "27b-it", // Gemma2 27B
"paligemma-224", // PaliGemma 224 "paligemma-224", // PaliGemma 224
"paligemma2-3b-224", // PaliGemma2 3B 224
"paligemma2-10b-224", // PaliGemma2 10B 224
}; };
constexpr Model kModelTypes[] = { constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
@ -49,6 +51,8 @@ constexpr Model kModelTypes[] = {
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
Model::PALIGEMMA_224, // PaliGemma 224 Model::PALIGEMMA_224, // PaliGemma 224
Model::PALIGEMMA2_3B_224, // PaliGemma2 3B 224
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
}; };
constexpr ModelTraining kModelTraining[] = { constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
@ -59,6 +63,8 @@ constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
ModelTraining::PALIGEMMA, // PaliGemma 224 ModelTraining::PALIGEMMA, // PaliGemma 224
ModelTraining::PALIGEMMA, // PaliGemma2 3B 224
ModelTraining::PALIGEMMA, // PaliGemma2 10B 224
}; };
constexpr size_t kNumModelFlags = constexpr size_t kNumModelFlags =

View File

@ -246,6 +246,22 @@ ModelConfig VitConfig(const ModelConfig& config) {
return vit_config; return vit_config;
} }
static ModelConfig ConfigPaliGemma2_3B_224() {
ModelConfig config = ConfigGemma2_2B();
config.model_name = "PaliGemma2_3B_224";
config.model = Model::PALIGEMMA2_3B_224;
AddVitConfig(config);
return config;
}
static ModelConfig ConfigPaliGemma2_10B_224() {
ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_224";
config.model = Model::PALIGEMMA2_10B_224;
AddVitConfig(config);
return config;
}
ModelConfig ConfigFromModel(Model model) { ModelConfig ConfigFromModel(Model model) {
switch (model) { switch (model) {
case Model::GEMMA_2B: case Model::GEMMA_2B:
@ -264,6 +280,10 @@ ModelConfig ConfigFromModel(Model model) {
return ConfigGemmaTiny(); return ConfigGemmaTiny();
case Model::PALIGEMMA_224: case Model::PALIGEMMA_224:
return ConfigPaliGemma_224(); return ConfigPaliGemma_224();
case Model::PALIGEMMA2_3B_224:
return ConfigPaliGemma2_3B_224();
case Model::PALIGEMMA2_10B_224:
return ConfigPaliGemma2_10B_224();
default: default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model)); HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
} }

View File

@ -114,13 +114,15 @@ enum class Model {
GEMMA_TINY, GEMMA_TINY,
GEMMA2_2B, GEMMA2_2B,
PALIGEMMA_224, PALIGEMMA_224,
PALIGEMMA2_3B_224,
PALIGEMMA2_10B_224,
}; };
// Allows the Model enum to be iterated over. // Allows the Model enum to be iterated over.
static constexpr Model kAllModels[] = { static constexpr Model kAllModels[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B, Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B,
Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B, Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B,
Model::PALIGEMMA_224, Model::PALIGEMMA_224, Model::PALIGEMMA2_3B_224, Model::PALIGEMMA2_10B_224,
}; };
struct LayerConfig { struct LayerConfig {