mirror of https://github.com/google/gemma.cpp.git
Introduce new Gemma 9B and 27B configs
PiperOrigin-RevId: 647299080
This commit is contained in:
parent
78e96fdc70
commit
8ac5d66575
|
|
@ -31,18 +31,24 @@ namespace gcpp {
|
|||
// Accepted command-line model selector strings. Entry i pairs with entry i
// of kModelTypes and kModelTraining below; keep all three arrays in sync.
constexpr const char* kModelFlags[] = {
    "2b-pt",    // Gemma 2B, pretrained
    "2b-it",    // Gemma 2B, instruction-tuned
    "7b-pt",    // Gemma 7B, pretrained
    "7b-it",    // Gemma 7B, instruction-tuned
    "9b-pt",    // Gemma 9B, pretrained
    "9b-it",    // Gemma 9B, instruction-tuned
    "27b-pt",   // Gemma 27B, pretrained
    "27b-it",   // Gemma 27B, instruction-tuned
    "gr2b-pt",  // RecurrentGemma, pretrained
    "gr2b-it",  // RecurrentGemma, instruction-tuned
    "tiny",     // Gemma Tiny (mostly for debugging)
};
|
||||
// Model enum value for each entry of kModelFlags (same order). Both the
// "-pt" and "-it" flag of a given size map to the same Model value.
constexpr Model kModelTypes[] = {
    Model::GEMMA_2B,    // "2b-pt"
    Model::GEMMA_2B,    // "2b-it"
    Model::GEMMA_7B,    // "7b-pt"
    Model::GEMMA_7B,    // "7b-it"
    Model::GEMMA_9B,    // "9b-pt"
    Model::GEMMA_9B,    // "9b-it"
    Model::GEMMA_27B,   // "27b-pt"
    Model::GEMMA_27B,   // "27b-it"
    Model::GRIFFIN_2B,  // "gr2b-pt"
    Model::GRIFFIN_2B,  // "gr2b-it"
    Model::GEMMA_TINY,  // "tiny"
};
|
||||
// Training variant for each entry of kModelFlags (same order). "-pt" flags
// map to GEMMA_PT and "-it" flags to GEMMA_IT; "tiny" is treated as IT.
constexpr ModelTraining kModelTraining[] = {
    ModelTraining::GEMMA_PT,  // "2b-pt"
    ModelTraining::GEMMA_IT,  // "2b-it"
    ModelTraining::GEMMA_PT,  // "7b-pt"
    ModelTraining::GEMMA_IT,  // "7b-it"
    ModelTraining::GEMMA_PT,  // "9b-pt"
    ModelTraining::GEMMA_IT,  // "9b-it"
    ModelTraining::GEMMA_PT,  // "27b-pt"
    ModelTraining::GEMMA_IT,  // "27b-it"
    ModelTraining::GEMMA_PT,  // "gr2b-pt"
    ModelTraining::GEMMA_IT,  // "gr2b-it"
    ModelTraining::GEMMA_IT,  // "tiny"
};
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ ByteStorageT AllocateSizeof() {
|
|||
// Identifies which model architecture/size to run. Used to select the
// matching Config* struct in CallForModel() and in the dynamic-dispatch
// macro; keep in sync with the kModelFlags/kModelTypes/kModelTraining
// tables. Do not reorder existing enumerators — their values may be relied
// on elsewhere (NOTE(review): confirm before reordering).
enum class Model {
  GEMMA_2B,
  GEMMA_7B,
  GEMMA_9B,    // added alongside GEMMA_27B in this change
  GEMMA_27B,
  GRIFFIN_2B,  // RecurrentGemma
  GEMMA_TINY,  // small configuration, mostly for debugging
};
|
||||
|
|
@ -69,6 +71,10 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
|
|||
return FuncT<ConfigGemma2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA_7B:
|
||||
return FuncT<ConfigGemma7B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA_9B:
|
||||
return FuncT<ConfigGemma9B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GEMMA_27B:
|
||||
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
case Model::GRIFFIN_2B:
|
||||
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
|
||||
default:
|
||||
|
|
@ -121,6 +127,16 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
|||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GEMMA_9B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma9B<TWEIGHT>>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GEMMA_27B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma27B<TWEIGHT>>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GRIFFIN_2B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B<TWEIGHT>>) \
|
||||
ARGS; \
|
||||
|
|
|
|||
|
|
@ -99,6 +99,46 @@ struct ConfigNoSSM {
|
|||
static constexpr int kNumTensorScales = 0;
|
||||
};
|
||||
|
||||
// Compile-time hyperparameters for the Gemma 27B model. TWeight is the
// weight storage type; all other members are compile-time constants consumed
// via the TConfig template parameter.
template <typename TWeight>
struct ConfigGemma27B : public ConfigNoSSM {
  using Weight = TWeight;  // make accessible where we only have a TConfig

  // Maximum sequence length; shared gcpp-wide default.
  static constexpr int kSeqLen = gcpp::kSeqLen;
  static constexpr int kVocabSize = 256000;
  // 46 layers, all with the same (kGemma) attention type.
  static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
      FixedLayerConfig<46>(LayerAttentionType::kGemma);
  static constexpr int kLayers = kLayerConfig.size();
  static constexpr int kGemmaLayers = kLayers;  // every layer is a Gemma layer
  static constexpr int kModelDim = 4608;
  static constexpr int kFFHiddenDim = 16 * 4608 / 2;  // = 36864
  static constexpr int kHeads = 32;
  // Fewer KV heads than query heads (32 query / 16 KV), so query heads share
  // key/value heads.
  static constexpr int kKVHeads = 16;
  static constexpr int kQKVDim = 128;  // query size == key size == value size
  // Shared gcpp-wide top-k default for sampling — confirm against usage.
  static constexpr int kTopK = gcpp::kTopK;
  static constexpr bool kAbsolutePE = false;  // no absolute position embeddings
  // NOTE(review): presumably enables an extra normalization scale applied
  // after attention/FFW blocks — verify against the consumers of this flag.
  static constexpr bool kPostNormScale = true;
};
|
||||
|
||||
// Compile-time hyperparameters for the Gemma 9B model. TWeight is the
// weight storage type; all other members are compile-time constants consumed
// via the TConfig template parameter.
template <typename TWeight>
struct ConfigGemma9B : public ConfigNoSSM {
  using Weight = TWeight;  // make accessible where we only have a TConfig

  // Maximum sequence length; shared gcpp-wide default.
  static constexpr int kSeqLen = gcpp::kSeqLen;
  static constexpr int kVocabSize = 256000;
  // 42 layers, all with the same (kGemma) attention type.
  static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
      FixedLayerConfig<42>(LayerAttentionType::kGemma);
  static constexpr int kLayers = kLayerConfig.size();
  static constexpr int kGemmaLayers = kLayers;  // every layer is a Gemma layer
  static constexpr int kModelDim = 3584;
  static constexpr int kFFHiddenDim = 8 * 3584 / 2;  // = 14336
  static constexpr int kHeads = 16;
  // Fewer KV heads than query heads (16 query / 8 KV), so query heads share
  // key/value heads.
  static constexpr int kKVHeads = 8;
  static constexpr int kQKVDim = 256;  // query size == key size == value size
  // Shared gcpp-wide top-k default for sampling — confirm against usage.
  static constexpr int kTopK = gcpp::kTopK;
  static constexpr bool kAbsolutePE = false;  // no absolute position embeddings
  // NOTE(review): presumably enables an extra normalization scale applied
  // after attention/FFW blocks — verify against the consumers of this flag.
  static constexpr bool kPostNormScale = true;
};
|
||||
|
||||
template <typename TWeight>
|
||||
struct ConfigGemma7B : public ConfigNoSSM {
|
||||
using Weight = TWeight; // make accessible where we only have a TConfig
|
||||
|
|
|
|||
Loading…
Reference in New Issue