diff --git a/gemma/configs.h b/gemma/configs.h index 30e9a6a..f7c6ac2 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -192,266 +192,6 @@ ModelConfig ConfigFromModel(Model model); // Returns the sub-config for the ViT model of the PaliGemma model. ModelConfig VitConfig(const ModelConfig& config); -template -struct CacheLayerSize { - constexpr size_t operator()() const { - return TConfig::kKVHeads * TConfig::kQKVDim * 2; - } -}; - -template -struct CachePosSize { - constexpr size_t operator()() const { - return TConfig::kGemmaLayers * CacheLayerSize()(); - } -}; - -struct ConfigNoSSM { - static constexpr int kGriffinLayers = 0; - - static constexpr int kConv1dWidth = 0; - static constexpr bool kFFBiases = false; - static constexpr bool kSoftmaxAttnOutputBiases = false; - static constexpr bool kUseHalfRope = false; - static constexpr bool kUseLocalAttention = false; - static constexpr bool kInterleaveQKV = true; - static constexpr int kNumTensorScales = 0; - - static constexpr PostQKType kPostQK = PostQKType::Rope; - static constexpr ActivationType kActivation = ActivationType::Gelu; - static constexpr ResidualType kResidual = ResidualType::Add; - - // Self-extend parameters with defaul values - static constexpr bool kSelfExtend = false; - static constexpr size_t kSelfExtendNgbSize = 0; - static constexpr size_t kSelfExtendGrpSize = 1; -}; - -struct ConfigBaseGemmaV1 : ConfigNoSSM { - static constexpr float kAttCap = 0.0f; - static constexpr float kFinalCap = 0.0f; - static constexpr PostNormType kPostNorm = PostNormType::None; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; -}; - -struct ConfigBaseGemmaV2 : ConfigNoSSM { - static constexpr float kAttCap = 50.0f; - static constexpr float kFinalCap = 30.0f; - static constexpr PostNormType kPostNorm = PostNormType::Scale; -}; - -template -struct ConfigGemma27B : public ConfigBaseGemmaV2 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 8192; - static constexpr int kVocabSize = 256000; - static constexpr std::array kLayerConfig = - FixedLayerConfig<46>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen}); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 4608; - static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864 - static constexpr int kHeads = 32; - static constexpr int kKVHeads = 16; - static constexpr int kQKVDim = 128; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr QueryScaleType kQueryScale = - QueryScaleType::SqrtModelDimDivNumHeads; -}; - -template -struct ConfigGemma9B : public ConfigBaseGemmaV2 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 8192; - static constexpr int kVocabSize = 256000; - static constexpr std::array kLayerConfig = - FixedLayerConfig<42>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen}); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 3584; - static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336 - static constexpr int kHeads = 16; - static constexpr int kKVHeads = 8; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; -}; - -template -struct ConfigGemma7B : public ConfigBaseGemmaV1 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = 256000; - static constexpr std::array kLayerConfig = - FixedLayerConfig<28>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - FixedAttentionWindowSizes<28>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 3072; - static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 - static constexpr int kHeads = 16; - static constexpr int kKVHeads = 16; // standard MHA - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; -}; - -template -struct ConfigGemma2B : public ConfigBaseGemmaV1 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = 256000; - static constexpr std::array kLayerConfig = - FixedLayerConfig<18>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - FixedAttentionWindowSizes<18>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 2048; - static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 - static constexpr int kHeads = 8; - static constexpr int kKVHeads = 1; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; -}; - -template -struct ConfigGemma2_2B : public ConfigBaseGemmaV2 { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 8192; - static constexpr int kVocabSize = 256000; - static constexpr std::array kLayerConfig = - FixedLayerConfig<26>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen}); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 2304; - static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216 - static constexpr int kHeads = 8; - static constexpr int kKVHeads = 4; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; -}; - -template -struct ConfigGemmaTiny : public ConfigNoSSM { - using Weight = TWeight; // make accessible where we only have a TConfig - - static constexpr int kSeqLen = 32; - static constexpr int kVocabSize = 64; - static constexpr std::array kLayerConfig = - FixedLayerConfig<3>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = - FixedAttentionWindowSizes<3>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = kLayers; - static constexpr int kModelDim = 128; - static constexpr int kFFHiddenDim = 256; - static constexpr int kHeads = 4; - static constexpr int kKVHeads = 1; - static constexpr int kQKVDim = 16; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr PostNormType kPostNorm = PostNormType::None; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; - - static constexpr float kAttCap = 0.0f; - // This is required for optimize_test to pass. - static constexpr float kFinalCap = 30.0f; -}; - -template -struct ConfigGriffin2B { - using Weight = TWeight; // make accessible where we only have a TConfig - - // Griffin uses local attention, so kSeqLen is actually the local attention - // window. - static constexpr int kSeqLen = 2048; - static constexpr int kVocabSize = 256000; - static constexpr std::array kLayerConfig = { - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGemma, - LayerAttentionType::kGriffinRecurrentBlock, - LayerAttentionType::kGriffinRecurrentBlock, - }; - static constexpr std::array kAttentionWindowSizes = - FixedAttentionWindowSizes<26>(kSeqLen); - static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = - NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); - static constexpr int kGriffinLayers = - NumLayersOfTypeBefore(kLayerConfig, - LayerAttentionType::kGriffinRecurrentBlock, - kLayers); - static constexpr int kModelDim = 2560; - static constexpr int kFFHiddenDim = 7680; - static constexpr int kHeads = 10; - static constexpr int kKVHeads = 1; - static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = gcpp::kTopK; - static constexpr bool kAbsolutePE = false; - static constexpr PostNormType kPostNorm = PostNormType::None; - - // No SoftCap. - static constexpr float kAttCap = 0.0f; - static constexpr float kFinalCap = 0.0f; - - // SSM config. - static constexpr int kConv1dWidth = 4; - static constexpr bool kFFBiases = true; - static constexpr bool kSoftmaxAttnOutputBiases = true; - static constexpr bool kUseHalfRope = true; - static constexpr bool kUseLocalAttention = true; - static constexpr bool kInterleaveQKV = false; - static constexpr int kNumTensorScales = 140; - static constexpr PostQKType kPostQK = PostQKType::Rope; - static constexpr ActivationType kActivation = ActivationType::Gelu; - static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; - static constexpr ResidualType kResidual = ResidualType::Add; - - // Self-extend parameters with defaul values - static constexpr bool kSelfExtend = false; - static constexpr size_t kSelfExtendNgbSize = 0; - static constexpr size_t kSelfExtendGrpSize = 1; -}; } // namespace gcpp