Introduce new Gemma 9B and 27B configs

PiperOrigin-RevId: 647299080
This commit is contained in:
Paul Chang 2024-06-27 06:44:38 -07:00 committed by Copybara-Service
parent 78e96fdc70
commit 8ac5d66575
3 changed files with 62 additions and 0 deletions

View File

@ -31,18 +31,24 @@ namespace gcpp {
constexpr const char* kModelFlags[] = {
"2b-pt", "2b-it", // Gemma 2B
"7b-pt", "7b-it", // Gemma 7B
"9b-pt", "9b-it", // Gemma 9B
"27b-pt", "27b-it", // Gemma 27B
"gr2b-pt", "gr2b-it", // RecurrentGemma
"tiny", // Gemma Tiny (mostly for debugging)
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
Model::GEMMA_7B, Model::GEMMA_7B, // Gemma 7B
Model::GEMMA_9B, Model::GEMMA_9B, // Gemma 9B
Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
Model::GEMMA_TINY, // Gemma Tiny
};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
ModelTraining::GEMMA_IT, // Gemma Tiny
};

View File

@ -39,6 +39,8 @@ ByteStorageT AllocateSizeof() {
enum class Model {
GEMMA_2B,
GEMMA_7B,
GEMMA_9B,
GEMMA_27B,
GRIFFIN_2B,
GEMMA_TINY,
};
@ -69,6 +71,10 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_7B:
return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_9B:
return FuncT<ConfigGemma9B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA_27B:
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GRIFFIN_2B:
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
default:
@ -121,6 +127,16 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
ARGS; \
break; \
} \
case Model::GEMMA_9B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma9B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GEMMA_27B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma27B<TWEIGHT>>) \
ARGS; \
break; \
} \
case Model::GRIFFIN_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
ARGS; \

View File

@ -99,6 +99,46 @@ struct ConfigNoSSM {
static constexpr int kNumTensorScales = 0;
};
template <typename TWeight>
struct ConfigGemma27B : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
FixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608;
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
static constexpr int kHeads = 32;
static constexpr int kKVHeads = 16;
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = true;
};
template <typename TWeight>
struct ConfigGemma9B : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
FixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584;
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 8;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr bool kPostNormScale = true;
};
template <typename TWeight>
struct ConfigGemma7B : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig