diff --git a/gemma/common.cc b/gemma/common.cc index 75d9282..d5fa0da 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -28,36 +28,47 @@ namespace gcpp { +constexpr const char* kModelFlags[] = { + "2b-pt", "2b-it", // Gemma 2B + "7b-pt", "7b-it", // Gemma 7B + "gr2b-pt", "gr2b-it", // RecurrentGemma + "tiny", // Gemma Tiny (mostly for debugging) +}; +constexpr Model kModelTypes[] = { + Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B + Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B + Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma + Model::GEMMA_TINY, // Gemma Tiny +}; +constexpr ModelTraining kModelTraining[] = { + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma + ModelTraining::GEMMA_IT, // Gemma Tiny +}; + +constexpr size_t kNumModelFlags = std::end(kModelFlags) - std::begin(kModelFlags); +static_assert(kNumModelFlags == + std::end(kModelTypes) - std::begin(kModelTypes)); +static_assert(kNumModelFlags == + std::end(kModelTraining) - std::begin(kModelTraining)); + const char* ParseModelTypeAndTraining(const std::string& model_flag, Model& model, ModelTraining& training) { - constexpr const char* kModelFlags[] = { - "2b-pt", "7b-pt", "gr2b-pt", "2b-it", "7b-it", "gr2b-it", "tiny", - }; - constexpr Model kModelTypes[] = { - Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B, - Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY, - }; - constexpr ModelTraining kModelTraining[] = { - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, - ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, - ModelTraining::GEMMA_IT, - }; - - constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags); - static char kErrorMessageBuffer[kNum * 8 + 1024] = + static char kErrorMessageBuffer[kNumModelFlags * 8 + 1024] = "Invalid or missing model flag, need to specify one of "; - for (size_t i = 0; i + 1 < kNum; i++) { + for (size_t i = 0; i + 1 < kNumModelFlags; i++) { strcat(kErrorMessageBuffer, kModelFlags[i]); // NOLINT strcat(kErrorMessageBuffer, ", "); // NOLINT } - strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT + strcat(kErrorMessageBuffer, kModelFlags[kNumModelFlags - 1]); // NOLINT strcat(kErrorMessageBuffer, "."); // NOLINT std::string model_type_lc = model_flag; std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc), [](unsigned char c) { return std::tolower(c); }); - for (size_t i = 0; i < kNum; i++) { + for (size_t i = 0; i < kNumModelFlags; i++) { if (kModelFlags[i] == model_type_lc) { model = kModelTypes[i]; training = kModelTraining[i]; @@ -69,14 +80,10 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag, } const char* ModelString(Model model, ModelTraining training) { - if (model == Model::GEMMA_TINY) return "tiny"; - static_assert(static_cast(ModelTraining::GEMMA_IT) == 0); - constexpr const char* k2B[] = {"2b-it", "2b-pt"}; - constexpr const char* k7B[] = {"7b-it", "7b-pt"}; - constexpr const char* kGr2B[] = {"gr2b-it", "gr2b-pt"}; - if (model == Model::GEMMA_2B) return k2B[static_cast(training)]; - if (model == Model::GEMMA_7B) return k7B[static_cast(training)]; - if (model == Model::GRIFFIN_2B) return kGr2B[static_cast(training)]; + for (size_t i = 0; i < kNumModelFlags; i++) { + if (kModelTypes[i] == model && kModelTraining[i] == training) + return kModelFlags[i]; + } HWY_ABORT("Unknown model %d training %d\n", static_cast(model), static_cast(training)); }