From 8ac5d66575429c4fca19fb394c8926074352c766 Mon Sep 17 00:00:00 2001 From: Paul Chang Date: Thu, 27 Jun 2024 06:44:38 -0700 Subject: [PATCH] Introduce new Gemma 9B and 27B configs PiperOrigin-RevId: 647299080 --- gemma/common.cc | 6 ++++++ gemma/common.h | 16 ++++++++++++++++ gemma/configs.h | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/gemma/common.cc b/gemma/common.cc index d5fa0da..fd10809 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -31,18 +31,24 @@ namespace gcpp { constexpr const char* kModelFlags[] = { "2b-pt", "2b-it", // Gemma 2B "7b-pt", "7b-it", // Gemma 7B + "9b-pt", "9b-it", // Gemma 9B + "27b-pt", "27b-it", // Gemma 27B "gr2b-pt", "gr2b-it", // RecurrentGemma "tiny", // Gemma Tiny (mostly for debugging) }; constexpr Model kModelTypes[] = { Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B + Model::GEMMA_9B, Model::GEMMA_9B, // Gemma 9B + Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma Model::GEMMA_TINY, // Gemma Tiny }; constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 9B + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma ModelTraining::GEMMA_IT, // Gemma Tiny }; diff --git a/gemma/common.h b/gemma/common.h index ff30c68..35f6e78 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -39,6 +39,8 @@ ByteStorageT AllocateSizeof() { enum class Model { GEMMA_2B, GEMMA_7B, + GEMMA_9B, + GEMMA_27B, GRIFFIN_2B, GEMMA_TINY, }; @@ -69,6 +71,10 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) { return FuncT>()(std::forward(args)...); case Model::GEMMA_7B: return FuncT>()(std::forward(args)...); + case Model::GEMMA_9B: + return FuncT>()(std::forward(args)...); + case Model::GEMMA_27B: + return FuncT>()(std::forward(args)...); case Model::GRIFFIN_2B: return FuncT>()(std::forward(args)...); default: @@ -121,6 +127,16 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight, ARGS; \ break; \ } \ + case Model::GEMMA_9B: { \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ + ARGS; \ + break; \ + } \ + case Model::GEMMA_27B: { \ + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ + ARGS; \ + break; \ + } \ case Model::GRIFFIN_2B: { \ HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC>) \ ARGS; \ diff --git a/gemma/configs.h b/gemma/configs.h index c1fc176..32d021d 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -99,6 +99,46 @@ struct ConfigNoSSM { static constexpr int kNumTensorScales = 0; }; +template +struct ConfigGemma27B : public ConfigNoSSM { + using Weight = TWeight; // make accessible where we only have a TConfig + + static constexpr int kSeqLen = gcpp::kSeqLen; + static constexpr int kVocabSize = 256000; + static constexpr std::array kLayerConfig = + FixedLayerConfig<46>(LayerAttentionType::kGemma); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = kLayers; + static constexpr int kModelDim = 4608; + static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864 + static constexpr int kHeads = 32; + static constexpr int kKVHeads = 16; + static constexpr int kQKVDim = 128; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; + static constexpr bool kPostNormScale = true; +}; + +template +struct ConfigGemma9B : public ConfigNoSSM { + using Weight = TWeight; // make accessible where we only have a TConfig + + static constexpr int kSeqLen = gcpp::kSeqLen; + static constexpr int kVocabSize = 256000; + static constexpr std::array kLayerConfig = + FixedLayerConfig<42>(LayerAttentionType::kGemma); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = kLayers; + static constexpr int kModelDim = 3584; + static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336 + static constexpr int kHeads = 16; + static constexpr int kKVHeads = 8; + static constexpr int kQKVDim = 256; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; + static constexpr bool kPostNormScale = true; +}; + template struct ConfigGemma7B : public ConfigNoSSM { using Weight = TWeight; // make accessible where we only have a TConfig