Rename Gemma9B and Gemma27B to Gemma2_9B and Gemma2_27B.

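This makes it clear that these models belong to the Gemma2 family. The rename covers the `Model` enum values (`GEMMA_9B` -> `GEMMA2_9B`, `GEMMA_27B` -> `GEMMA2_27B`) and the config structs (`ConfigGemma9B` -> `ConfigGemma2_9B`, `ConfigGemma27B` -> `ConfigGemma2_27B`); the model flag strings (`9b-pt`, `9b-it`, `27b-pt`, `27b-it`) are unchanged.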

PiperOrigin-RevId: 661181682
This commit is contained in:
Apoorv Reddy 2024-08-09 02:08:40 -07:00 committed by Copybara-Service
parent 2ebbe4076f
commit fd1b0743a7
10 changed files with 41 additions and 41 deletions

View File

@ -48,7 +48,7 @@ class GemmaTest : public ::testing::Test {
// Using the turn structure worsens results sometimes. // Using the turn structure worsens results sometimes.
// However, gemma-2 27B seems to need the turn structure to work. // However, gemma-2 27B seems to need the turn structure to work.
// It would be good to make these tests more consistent. // It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA_27B) { if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) {
std::string mutable_prompt = prompt; std::string mutable_prompt = prompt;
auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns. auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns.
return response; return response;
@ -68,7 +68,7 @@ class GemmaTest : public ::testing::Test {
// Using the turn structure worsens results sometimes. // Using the turn structure worsens results sometimes.
// However, gemma-2 27B seems to need the turn structure to work. // However, gemma-2 27B seems to need the turn structure to work.
// It would be good to make these tests more consistent. // It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA_27B) { if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) {
for (auto [response, n] : s_env->BatchQueryModel(inputs)) { for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
replies.push_back(response); replies.push_back(response);
} }
@ -199,10 +199,10 @@ TEST_F(GemmaTest, CrossEntropySmall) {
// 7B v.1 and v.1.1 produce slightly different results. // 7B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 2.8f, 0.2f); EXPECT_NEAR(entropy, 2.8f, 0.2f);
break; break;
case gcpp::Model::GEMMA_9B: case gcpp::Model::GEMMA2_9B:
EXPECT_NEAR(entropy, 1.28f, 0.02f); EXPECT_NEAR(entropy, 1.28f, 0.02f);
break; break;
case gcpp::Model::GEMMA_27B: case gcpp::Model::GEMMA2_27B:
EXPECT_NEAR(entropy, 1.30f, 0.02f); EXPECT_NEAR(entropy, 1.30f, 0.02f);
break; break;
default: default:
@ -224,10 +224,10 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
// 7B v.1 and v.1.1 produce slightly different results. // 7B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 1.07f, 0.05f); EXPECT_NEAR(entropy, 1.07f, 0.05f);
break; break;
case gcpp::Model::GEMMA_9B: case gcpp::Model::GEMMA2_9B:
EXPECT_NEAR(entropy, 0.37f, 0.02f); EXPECT_NEAR(entropy, 0.37f, 0.02f);
break; break;
case gcpp::Model::GEMMA_27B: case gcpp::Model::GEMMA2_27B:
EXPECT_NEAR(entropy, 0.33f, 0.02f); EXPECT_NEAR(entropy, 0.33f, 0.02f);
break; break;
default: default:
@ -249,10 +249,10 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) {
// 7B v.1 and v.1.1 produce slightly different results. // 7B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 0.75f, 0.1f); EXPECT_NEAR(entropy, 0.75f, 0.1f);
break; break;
case gcpp::Model::GEMMA_9B: case gcpp::Model::GEMMA2_9B:
EXPECT_NEAR(entropy, 0.15f, 0.02f); EXPECT_NEAR(entropy, 0.15f, 0.02f);
break; break;
case gcpp::Model::GEMMA_27B: case gcpp::Model::GEMMA2_27B:
EXPECT_NEAR(entropy, 0.14f, 0.02f); EXPECT_NEAR(entropy, 0.14f, 0.02f);
break; break;
default: default:

View File

@ -31,29 +31,29 @@ namespace gcpp {
constexpr const char* kModelFlags[] = { constexpr const char* kModelFlags[] = {
"2b-pt", "2b-it", // Gemma 2B "2b-pt", "2b-it", // Gemma 2B
"7b-pt", "7b-it", // Gemma 7B "7b-pt", "7b-it", // Gemma 7B
"9b-pt", "9b-it", // Gemma 9B
"27b-pt", "27b-it", // Gemma 27B
"gr2b-pt", "gr2b-it", // RecurrentGemma "gr2b-pt", "gr2b-it", // RecurrentGemma
"tiny", // Gemma Tiny (mostly for debugging) "tiny", // Gemma Tiny (mostly for debugging)
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B "gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
"9b-pt", "9b-it", // Gemma2 9B
"27b-pt", "27b-it", // Gemma2 27B
}; };
constexpr Model kModelTypes[] = { constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B
Model::GEMMA_9B, Model::GEMMA_9B, // Gemma 9B
Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
Model::GEMMA_TINY, // Gemma Tiny Model::GEMMA_TINY, // Gemma Tiny
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
}; };
constexpr ModelTraining kModelTraining[] = { constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
ModelTraining::GEMMA_IT, // Gemma Tiny ModelTraining::GEMMA_IT, // Gemma Tiny
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B2 ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
}; };
constexpr size_t kNumModelFlags = constexpr size_t kNumModelFlags =

View File

@ -41,8 +41,8 @@ ByteStorageT AllocateSizeof() {
enum class Model { enum class Model {
GEMMA_2B, GEMMA_2B,
GEMMA_7B, GEMMA_7B,
GEMMA_9B, GEMMA2_9B,
GEMMA_27B, GEMMA2_27B,
GRIFFIN_2B, GRIFFIN_2B,
GEMMA_TINY, GEMMA_TINY,
GEMMA2_2B, GEMMA2_2B,
@ -94,10 +94,10 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...); return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_7B: case Model::GEMMA_7B:
return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...); return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_9B: case Model::GEMMA2_9B:
return FuncT<ConfigGemma9B<TWeight>>()(std::forward<TArgs>(args)...); return FuncT<ConfigGemma2_9B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_27B: case Model::GEMMA2_27B:
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...); return FuncT<ConfigGemma2_27B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GRIFFIN_2B: case Model::GRIFFIN_2B:
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...); return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_2B: case Model::GEMMA2_2B:
@ -143,10 +143,10 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
GEMMA_FOREACH_WEIGHT(X, ConfigGemmaTiny) \ GEMMA_FOREACH_WEIGHT(X, ConfigGemmaTiny) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2B) \ GEMMA_FOREACH_WEIGHT(X, ConfigGemma2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma7B) \ GEMMA_FOREACH_WEIGHT(X, ConfigGemma7B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \ GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \ GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_9B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_27B) \
static_assert(true, "Allow trailing ;") static_assert(true, "Allow trailing ;")
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float), // Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
@ -168,16 +168,6 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
ARGS; \ ARGS; \
break; \ break; \
} \ } \
case Model::GEMMA_9B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma9B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GEMMA_27B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma27B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GRIFFIN_2B: { \ case Model::GRIFFIN_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \ HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
ARGS; \ ARGS; \
@ -188,6 +178,16 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
ARGS; \ ARGS; \
break; \ break; \
} \ } \
case Model::GEMMA2_9B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2_9B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GEMMA2_27B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2_27B<TWEIGHT>>) \
ARGS; \
break; \
} \
default: \ default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \ HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
} }

View File

@ -167,7 +167,7 @@ struct ConfigBaseGemmaV2 : ConfigNoSSM {
}; };
template <typename TWeight> template <typename TWeight>
struct ConfigGemma27B : public ConfigBaseGemmaV2 { struct ConfigGemma2_27B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192; static constexpr int kSeqLen = 8192;
@ -190,7 +190,7 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
}; };
template <typename TWeight> template <typename TWeight>
struct ConfigGemma9B : public ConfigBaseGemmaV2 { struct ConfigGemma2_9B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 8192; static constexpr int kSeqLen = 8192;

View File

@ -17,5 +17,5 @@
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE \
"gemma/instantiations/27b_bf16.cc" "gemma/instantiations/27b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma27B<hwy::bfloat16_t> #define GEMMA_CONFIG ConfigGemma2_27B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"

View File

@ -17,5 +17,5 @@
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE \
"gemma/instantiations/27b_f32.cc" "gemma/instantiations/27b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma27B<float> #define GEMMA_CONFIG ConfigGemma2_27B<float>
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"

View File

@ -17,5 +17,5 @@
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE \
"gemma/instantiations/27b_sfp.cc" "gemma/instantiations/27b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma27B<SfpStream> #define GEMMA_CONFIG ConfigGemma2_27B<SfpStream>
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"

View File

@ -17,5 +17,5 @@
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE \
"gemma/instantiations/9b_bf16.cc" "gemma/instantiations/9b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma9B<hwy::bfloat16_t> #define GEMMA_CONFIG ConfigGemma2_9B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"

View File

@ -17,5 +17,5 @@
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE \
"gemma/instantiations/9b_f32.cc" "gemma/instantiations/9b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma9B<float> #define GEMMA_CONFIG ConfigGemma2_9B<float>
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"

View File

@ -17,5 +17,5 @@
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE \
"gemma/instantiations/9b_sfp.cc" "gemma/instantiations/9b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#define GEMMA_CONFIG ConfigGemma9B<SfpStream> #define GEMMA_CONFIG ConfigGemma2_9B<SfpStream>
#include "gemma/gemma-inl.h" #include "gemma/gemma-inl.h"