diff --git a/gemma/common.h b/gemma/common.h
index e5a581e..f234786 100644
--- a/gemma/common.h
+++ b/gemma/common.h
@@ -36,7 +36,12 @@ ByteStorageT AllocateSizeof() {
 }
 
 // Model variants: see configs.h for details.
-enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
+enum class Model {
+  GEMMA_2B,
+  GEMMA_7B,
+  GRIFFIN_2B,
+  GEMMA_TINY,
+};
 
 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
diff --git a/gemma/configs.h b/gemma/configs.h
index fb4231a..b59b450 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -73,8 +73,18 @@ constexpr size_t NumLayersOfTypeBefore(
   return count;
 }
 
+struct ConfigNoSSM {
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
+  static constexpr int kNumTensorScales = 0;
+};
+
 template
-struct ConfigGemma7B {
+struct ConfigGemma7B : public ConfigNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -96,19 +106,10 @@ struct ConfigGemma7B {
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
   static constexpr bool kPostNormScale = false;
-
-  // SSM config.
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
-  static constexpr bool kUseHalfRope = false;
-  static constexpr bool kUseLocalAttention = false;
-  static constexpr bool kInterleaveQKV = true;
-  static constexpr int kNumTensorScales = 0;
 };
 
 template
-struct ConfigGemma2B {
+struct ConfigGemma2B : public ConfigNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = gcpp::kSeqLen;
@@ -118,10 +119,8 @@ struct ConfigGemma2B {
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers =
       NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
-  static constexpr int kGriffinLayers =
-      NumLayersOfTypeBefore(kLayerConfig,
-                            LayerAttentionType::kGriffinRecurrentBlock,
-                            kLayers);
+  static constexpr int kGriffinLayers = NumLayersOfTypeBefore(
+      kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers);
   static constexpr int kModelDim = 2048;
   static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
   static constexpr int kHeads = 8;
@@ -130,19 +129,10 @@ struct ConfigGemma2B {
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
   static constexpr bool kPostNormScale = false;
-
-  // SSM config.
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
-  static constexpr bool kUseHalfRope = false;
-  static constexpr bool kUseLocalAttention = false;
-  static constexpr bool kInterleaveQKV = true;
-  static constexpr int kNumTensorScales = 0;
 };
 
 template
-struct ConfigGemmaTiny {
+struct ConfigGemmaTiny : public ConfigNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = 32;
@@ -164,15 +154,6 @@ struct ConfigGemmaTiny {
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
   static constexpr bool kPostNormScale = false;
-
-  // SSM config.
-  static constexpr int kConv1dWidth = 0;
-  static constexpr bool kFFBiases = false;
-  static constexpr bool kSoftmaxAttnOutputBiases = false;
-  static constexpr bool kUseHalfRope = false;
-  static constexpr bool kUseLocalAttention = false;
-  static constexpr bool kInterleaveQKV = true;
-  static constexpr int kNumTensorScales = 0;
 };
 
 template