diff --git a/gemma/configs.h b/gemma/configs.h index 47efed6..0ca14fd 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -74,6 +74,8 @@ constexpr size_t NumLayersOfTypeBefore( } struct ConfigNoSSM { + static constexpr int kGriffinLayers = 0; + static constexpr int kConv1dWidth = 0; static constexpr bool kFFBiases = false; static constexpr bool kSoftmaxAttnOutputBiases = false; @@ -92,12 +94,7 @@ struct ConfigGemma7B : public ConfigNoSSM { static constexpr std::array kLayerConfig = FixedLayerConfig<28>(LayerAttentionType::kGemma); static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = - NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); - static constexpr int kGriffinLayers = - NumLayersOfTypeBefore(kLayerConfig, - LayerAttentionType::kGriffinRecurrentBlock, - kLayers); + static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 3072; static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 static constexpr int kHeads = 16; @@ -117,10 +114,7 @@ struct ConfigGemma2B : public ConfigNoSSM { static constexpr std::array kLayerConfig = FixedLayerConfig<18>(LayerAttentionType::kGemma); static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = - NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); - static constexpr int kGriffinLayers = NumLayersOfTypeBefore( - kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers); + static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 2048; static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kHeads = 8; @@ -140,12 +134,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM { static constexpr std::array kLayerConfig = FixedLayerConfig<3>(LayerAttentionType::kGemma); static constexpr int kLayers = kLayerConfig.size(); - static constexpr int kGemmaLayers = - NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); - static constexpr int kGriffinLayers = - NumLayersOfTypeBefore(kLayerConfig, - LayerAttentionType::kGriffinRecurrentBlock, - kLayers); + static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 128; static constexpr int kFFHiddenDim = 256; static constexpr int kHeads = 4;