mirror of https://github.com/google/gemma.cpp.git
Rename Gemma9B and Gemma27B to Gemma2_9B and Gemma2_27B.
This is to make it clear that these models are part of the Gemma2 family of models.

PiperOrigin-RevId: 661181682
This commit is contained in:
parent
2ebbe4076f
commit
fd1b0743a7
|
|
@ -48,7 +48,7 @@ class GemmaTest : public ::testing::Test {
|
|||
// Using the turn structure worsens results sometimes.
|
||||
// However, gemma-2 27B seems to need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA_27B) {
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) {
|
||||
std::string mutable_prompt = prompt;
|
||||
auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns.
|
||||
return response;
|
||||
|
|
@ -68,7 +68,7 @@ class GemmaTest : public ::testing::Test {
|
|||
// Using the turn structure worsens results sometimes.
|
||||
// However, gemma-2 27B seems to need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA_27B) {
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B) {
|
||||
for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(response);
|
||||
}
|
||||
|
|
@ -199,10 +199,10 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
|||
// 7B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 2.8f, 0.2f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA_9B:
|
||||
case gcpp::Model::GEMMA2_9B:
|
||||
EXPECT_NEAR(entropy, 1.28f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA_27B:
|
||||
case gcpp::Model::GEMMA2_27B:
|
||||
EXPECT_NEAR(entropy, 1.30f, 0.02f);
|
||||
break;
|
||||
default:
|
||||
|
|
@ -224,10 +224,10 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
|||
// 7B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 1.07f, 0.05f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA_9B:
|
||||
case gcpp::Model::GEMMA2_9B:
|
||||
EXPECT_NEAR(entropy, 0.37f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA_27B:
|
||||
case gcpp::Model::GEMMA2_27B:
|
||||
EXPECT_NEAR(entropy, 0.33f, 0.02f);
|
||||
break;
|
||||
default:
|
||||
|
|
@ -249,10 +249,10 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
|||
// 7B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 0.75f, 0.1f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA_9B:
|
||||
case gcpp::Model::GEMMA2_9B:
|
||||
EXPECT_NEAR(entropy, 0.15f, 0.02f);
|
||||
break;
|
||||
case gcpp::Model::GEMMA_27B:
|
||||
case gcpp::Model::GEMMA2_27B:
|
||||
EXPECT_NEAR(entropy, 0.14f, 0.02f);
|
||||
break;
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -31,29 +31,29 @@ namespace gcpp {
|
|||
constexpr const char* kModelFlags[] = {
|
||||
"2b-pt", "2b-it", // Gemma 2B
|
||||
"7b-pt", "7b-it", // Gemma 7B
|
||||
"9b-pt", "9b-it", // Gemma 9B
|
||||
"27b-pt", "27b-it", // Gemma 27B
|
||||
"gr2b-pt", "gr2b-it", // RecurrentGemma
|
||||
"tiny", // Gemma Tiny (mostly for debugging)
|
||||
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
|
||||
"9b-pt", "9b-it", // Gemma2 9B
|
||||
"27b-pt", "27b-it", // Gemma2 27B
|
||||
};
|
||||
constexpr Model kModelTypes[] = {
|
||||
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
|
||||
Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B
|
||||
Model::GEMMA_9B, Model::GEMMA_9B, // Gemma 9B
|
||||
Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B
|
||||
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
|
||||
Model::GEMMA_TINY, // Gemma Tiny
|
||||
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
|
||||
Model::GEMMA2_9B, Model::GEMMA2_9B, // Gemma2 9B
|
||||
Model::GEMMA2_27B, Model::GEMMA2_27B, // Gemma2 27B
|
||||
};
|
||||
constexpr ModelTraining kModelTraining[] = {
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 9B
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
|
||||
ModelTraining::GEMMA_IT, // Gemma Tiny
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B2
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
|
||||
};
|
||||
|
||||
constexpr size_t kNumModelFlags =
|
||||
|
|
|
|||
|
|
@ -41,8 +41,8 @@ ByteStorageT AllocateSizeof() {
|
|||
enum class Model {
|
||||
GEMMA_2B,
|
||||
GEMMA_7B,
|
||||
GEMMA_9B,
|
||||
GEMMA_27B,
|
||||
GEMMA2_9B,
|
||||
GEMMA2_27B,
|
||||
GRIFFIN_2B,
|
||||
GEMMA_TINY,
|
||||
GEMMA2_2B,
|
||||
|
|
@ -94,10 +94,10 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
|
|||
return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA_7B:
|
||||
return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA_9B:
|
||||
return FuncT<ConfigGemma9B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA_27B:
|
||||
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA2_9B:
|
||||
return FuncT<ConfigGemma2_9B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA2_27B:
|
||||
return FuncT<ConfigGemma2_27B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GRIFFIN_2B:
|
||||
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA2_2B:
|
||||
|
|
@ -143,10 +143,10 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
|||
GEMMA_FOREACH_WEIGHT(X, ConfigGemmaTiny) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma7B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_9B) \
|
||||
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_27B) \
|
||||
static_assert(true, "Allow trailing ;")
|
||||
|
||||
// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
|
||||
|
|
@ -168,16 +168,6 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
|||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GEMMA_9B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma9B<TWEIGHT>>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GEMMA_27B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma27B<TWEIGHT>>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GRIFFIN_2B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
|
||||
ARGS; \
|
||||
|
|
@ -188,6 +178,16 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
|||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GEMMA2_9B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2_9B<TWEIGHT>>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GEMMA2_27B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2_27B<TWEIGHT>>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
|
||||
}
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ struct ConfigBaseGemmaV2 : ConfigNoSSM {
|
|||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma27B : public ConfigBaseGemmaV2 {
|
||||
struct ConfigGemma2_27B : public ConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
|
|
@ -190,7 +190,7 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
|
|||
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma9B : public ConfigBaseGemmaV2 {
|
||||
struct ConfigGemma2_9B : public ConfigBaseGemmaV2 {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
||||
static constexpr int kSeqLen = 8192;
|
||||
|
|
|
|||
|
|
@ -17,5 +17,5 @@
|
|||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/27b_bf16.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_CONFIG ConfigGemma27B<hwy::bfloat16_t>
|
||||
#define GEMMA_CONFIG ConfigGemma2_27B<hwy::bfloat16_t>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
|
|||
|
|
@ -17,5 +17,5 @@
|
|||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/27b_f32.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_CONFIG ConfigGemma27B<float>
|
||||
#define GEMMA_CONFIG ConfigGemma2_27B<float>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
|
|||
|
|
@ -17,5 +17,5 @@
|
|||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/27b_sfp.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_CONFIG ConfigGemma27B<SfpStream>
|
||||
#define GEMMA_CONFIG ConfigGemma2_27B<SfpStream>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
|
|||
|
|
@ -17,5 +17,5 @@
|
|||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/9b_bf16.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_CONFIG ConfigGemma9B<hwy::bfloat16_t>
|
||||
#define GEMMA_CONFIG ConfigGemma2_9B<hwy::bfloat16_t>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
|
|||
|
|
@ -17,5 +17,5 @@
|
|||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/9b_f32.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_CONFIG ConfigGemma9B<float>
|
||||
#define GEMMA_CONFIG ConfigGemma2_9B<float>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
|
|||
|
|
@ -17,5 +17,5 @@
|
|||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/instantiations/9b_sfp.cc"
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#define GEMMA_CONFIG ConfigGemma9B<SfpStream>
|
||||
#define GEMMA_CONFIG ConfigGemma2_9B<SfpStream>
|
||||
#include "gemma/gemma-inl.h"
|
||||
|
|
|
|||
Loading…
Reference in New Issue