diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 29c189e..fc580e2 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -48,7 +48,7 @@ class GemmaTest : public ::testing::Test { // Using the turn structure worsens results sometimes. // However, gemma-2 27B seems to need the turn structure to work. // It would be good to make these tests more consistent. - if (s_env->GetModel()->Info().model == Model::GEMMA_27B) { + if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) { std::string mutable_prompt = prompt; auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns. return response; @@ -68,7 +68,7 @@ class GemmaTest : public ::testing::Test { // Using the turn structure worsens results sometimes. // However, gemma-2 27B seems to need the turn structure to work. // It would be good to make these tests more consistent. - if (s_env->GetModel()->Info().model == Model::GEMMA_27B) { + if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) { for (auto [response, n] : s_env->BatchQueryModel(inputs)) { replies.push_back(response); } @@ -199,10 +199,10 @@ TEST_F(GemmaTest, CrossEntropySmall) { // 7B v.1 and v.1.1 produce slightly different results. EXPECT_NEAR(entropy, 2.8f, 0.2f); break; - case gcpp::Model::GEMMA_9B: + case gcpp::Model::GEMMA2_9B: EXPECT_NEAR(entropy, 1.28f, 0.02f); break; - case gcpp::Model::GEMMA_27B: + case gcpp::Model::GEMMA2_27B: EXPECT_NEAR(entropy, 1.30f, 0.02f); break; default: @@ -224,10 +224,10 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) { // 7B v.1 and v.1.1 produce slightly different results. EXPECT_NEAR(entropy, 1.07f, 0.05f); break; - case gcpp::Model::GEMMA_9B: + case gcpp::Model::GEMMA2_9B: EXPECT_NEAR(entropy, 0.37f, 0.02f); break; - case gcpp::Model::GEMMA_27B: + case gcpp::Model::GEMMA2_27B: EXPECT_NEAR(entropy, 0.33f, 0.02f); break; default: @@ -249,10 +249,10 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) { // 7B v.1 and v.1.1 produce slightly different results. EXPECT_NEAR(entropy, 0.75f, 0.1f); break; - case gcpp::Model::GEMMA_9B: + case gcpp::Model::GEMMA2_9B: EXPECT_NEAR(entropy, 0.15f, 0.02f); break; - case gcpp::Model::GEMMA_27B: + case gcpp::Model::GEMMA2_27B: EXPECT_NEAR(entropy, 0.14f, 0.02f); break; default: diff --git a/gemma/common.cc b/gemma/common.cc index 1d50743..4161686 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -31,29 +31,29 @@ namespace gcpp { constexpr const char* kModelFlags[] = { "2b-pt", "2b-it", // Gemma 2B "7b-pt", "7b-it", // Gemma 7B - "9b-pt", "9b-it", // Gemma 9B - "27b-pt", "27b-it", // Gemma 27B "gr2b-pt", "gr2b-it", // RecurrentGemma "tiny", // Gemma Tiny (mostly for debugging) "gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B + "9b-pt", "9b-it", // Gemma2 9B + "27b-pt", "27b-it", // Gemma2 27B }; constexpr Model kModelTypes[] = { Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B - Model::GEMMA_9B, Model::GEMMA_9B, // Gemma 9B - Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma Model::GEMMA_TINY, // Gemma Tiny Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B + Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B + Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B }; constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 9B - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma ModelTraining::GEMMA_IT, // Gemma Tiny - ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B2 + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B }; constexpr size_t kNumModelFlags = diff --git a/gemma/common.h b/gemma/common.h index f0498ba..ea541d9 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -41,8 +41,8 @@ ByteStorageT AllocateSizeof() { enum class Model { GEMMA_2B, GEMMA_7B, - GEMMA_9B, - GEMMA_27B, + GEMMA2_9B, + GEMMA2_27B, GRIFFIN_2B, GEMMA_TINY, GEMMA2_2B, @@ -94,10 +94,10 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) { return FuncT>()(std::forward(args)...); case Model::GEMMA_7B: return FuncT>()(std::forward(args)...); - case Model::GEMMA_9B: - return FuncT>()(std::forward(args)...); - case Model::GEMMA_27B: - return FuncT>()(std::forward(args)...); + case Model::GEMMA2_9B: + return FuncT>()(std::forward(args)...); + case Model::GEMMA2_27B: + return FuncT>()(std::forward(args)...); case Model::GRIFFIN_2B: return FuncT>()(std::forward(args)...); case Model::GEMMA2_2B: @@ -143,10 +143,10 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, GEMMA_FOREACH_WEIGHT(X, ConfigGemmaTiny) \ GEMMA_FOREACH_WEIGHT(X, ConfigGemma2B) \ GEMMA_FOREACH_WEIGHT(X, ConfigGemma7B) \ - GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \ - GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \ GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \ GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \ + GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_9B) \ + GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_27B) \ static_assert(true, "Allow trailing ;") // Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float), @@ -168,16 +168,6 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, ARGS; \ break; \ } \ - case Model::GEMMA_9B: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ - ARGS; \ - break; \ - } \ - case Model::GEMMA_27B: { \ - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ - ARGS; \ - break; \ - } \ case Model::GRIFFIN_2B: { \ HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ ARGS; \ @@ -188,6 +178,16 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, ARGS; \ break; \ } \ + case Model::GEMMA2_9B: { \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ + ARGS; \ + break; \ + } \ + case Model::GEMMA2_27B: { \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ + ARGS; \ + break; \ + } \ default: \ HWY_ABORT("Model type %d unknown.", static_cast(MODEL)); \ } diff --git a/gemma/configs.h b/gemma/configs.h index be995f9..3ddcb41 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -167,7 +167,7 @@ struct ConfigBaseGemmaV2 : ConfigNoSSM { }; template -struct ConfigGemma27B : public ConfigBaseGemmaV2 { +struct ConfigGemma2_27B : public ConfigBaseGemmaV2 { using Weight = TWeight; // make accessible where we only have a TConfig static constexpr int kSeqLen = 8192; @@ -190,7 +190,7 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 { }; template -struct ConfigGemma9B : public ConfigBaseGemmaV2 { +struct ConfigGemma2_9B : public ConfigBaseGemmaV2 { using Weight = TWeight; // make accessible where we only have a TConfig static constexpr int kSeqLen = 8192; diff --git a/gemma/instantiations/27b_bf16.cc b/gemma/instantiations/27b_bf16.cc index bd68679..8698c7a 100644 --- a/gemma/instantiations/27b_bf16.cc +++ b/gemma/instantiations/27b_bf16.cc @@ -17,5 +17,5 @@ #define HWY_TARGET_INCLUDE \ "gemma/instantiations/27b_bf16.cc" #include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_CONFIG ConfigGemma27B +#define GEMMA_CONFIG ConfigGemma2_27B #include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/27b_f32.cc b/gemma/instantiations/27b_f32.cc index 75a200d..f4b5d6c 100644 --- a/gemma/instantiations/27b_f32.cc +++ b/gemma/instantiations/27b_f32.cc @@ -17,5 +17,5 @@ #define HWY_TARGET_INCLUDE \ "gemma/instantiations/27b_f32.cc" #include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_CONFIG ConfigGemma27B +#define GEMMA_CONFIG ConfigGemma2_27B #include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/27b_sfp.cc b/gemma/instantiations/27b_sfp.cc index 5a268e4..7d0072a 100644 --- a/gemma/instantiations/27b_sfp.cc +++ b/gemma/instantiations/27b_sfp.cc @@ -17,5 +17,5 @@ #define HWY_TARGET_INCLUDE \ "gemma/instantiations/27b_sfp.cc" #include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_CONFIG ConfigGemma27B +#define GEMMA_CONFIG ConfigGemma2_27B #include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/9b_bf16.cc b/gemma/instantiations/9b_bf16.cc index df72954..1cd5d13 100644 --- a/gemma/instantiations/9b_bf16.cc +++ b/gemma/instantiations/9b_bf16.cc @@ -17,5 +17,5 @@ #define HWY_TARGET_INCLUDE \ "gemma/instantiations/9b_bf16.cc" #include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_CONFIG ConfigGemma9B +#define GEMMA_CONFIG ConfigGemma2_9B #include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/9b_f32.cc b/gemma/instantiations/9b_f32.cc index 26f0eed..d96b279 100644 --- a/gemma/instantiations/9b_f32.cc +++ b/gemma/instantiations/9b_f32.cc @@ -17,5 +17,5 @@ #define HWY_TARGET_INCLUDE \ "gemma/instantiations/9b_f32.cc" #include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_CONFIG ConfigGemma9B +#define GEMMA_CONFIG ConfigGemma2_9B #include "gemma/gemma-inl.h" diff --git a/gemma/instantiations/9b_sfp.cc b/gemma/instantiations/9b_sfp.cc index 17aefb1..b822524 100644 --- a/gemma/instantiations/9b_sfp.cc +++ b/gemma/instantiations/9b_sfp.cc @@ -17,5 +17,5 @@ #define HWY_TARGET_INCLUDE \ "gemma/instantiations/9b_sfp.cc" #include "hwy/foreach_target.h" // IWYU pragma: keep -#define GEMMA_CONFIG ConfigGemma9B +#define GEMMA_CONFIG ConfigGemma2_9B #include "gemma/gemma-inl.h"