From 5a751a9a4493b618f5fbcccfa0ba8766661329e6 Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Wed, 17 Jul 2024 05:42:39 -0700 Subject: [PATCH] Update gemma-27b to the correct query scaling. PiperOrigin-RevId: 653201646 --- gemma/common.h | 8 +++++--- gemma/configs.h | 11 ++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/gemma/common.h b/gemma/common.h index 663b2ca..f8da552 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -246,9 +246,11 @@ static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling( template GEMMA_CONSTEXPR_SQRT float ChooseQueryScale() { - constexpr size_t kQKVDim = TConfig::kQKVDim; - // QueryScaleType::Sqrt - return 1.0f / Sqrt(static_cast(kQKVDim)); + if (TConfig::kQueryScale == QueryScaleType::SqrtModelDimDivNumHeads) + return 1.0f / + Sqrt(static_cast(TConfig::kModelDim / TConfig::kHeads)); + // QueryScaleType::SqrtKeySize + return 1.0f / Sqrt(static_cast(TConfig::kQKVDim)); } } // namespace gcpp diff --git a/gemma/configs.h b/gemma/configs.h index d637372..efe5476 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -70,7 +70,8 @@ enum class ActivationType { // Attention query scale. enum class QueryScaleType { - Sqrt, + SqrtKeySize, + SqrtModelDimDivNumHeads, }; // Residual connection type. @@ -149,7 +150,6 @@ struct ConfigNoSSM { static constexpr PostQKType kPostQK = PostQKType::Rope; static constexpr ActivationType kActivation = ActivationType::Gelu; - static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt; static constexpr ResidualType kResidual = ResidualType::Add; }; @@ -157,6 +157,7 @@ struct ConfigBaseGemmaV1 : ConfigNoSSM { static constexpr float kAttCap = 0.0f; static constexpr float kFinalCap = 0.0f; static constexpr PostNormType kPostNorm = PostNormType::None; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; }; struct ConfigBaseGemmaV2 : ConfigNoSSM { @@ -184,6 +185,8 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 { static constexpr int kQKVDim = 128; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; static constexpr bool kAbsolutePE = false; + static constexpr QueryScaleType kQueryScale = + QueryScaleType::SqrtModelDimDivNumHeads; }; template @@ -205,6 +208,7 @@ struct ConfigGemma9B : public ConfigBaseGemmaV2 { static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; static constexpr bool kAbsolutePE = false; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; }; template @@ -269,6 +273,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM { static constexpr int kTopK = gcpp::kTopK; static constexpr bool kAbsolutePE = false; static constexpr PostNormType kPostNorm = PostNormType::None; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; static constexpr float kAttCap = 0.0f; // This is required for optimize_test to pass. @@ -343,7 +348,7 @@ struct ConfigGriffin2B { static constexpr int kNumTensorScales = 140; static constexpr PostQKType kPostQK = PostQKType::Rope; static constexpr ActivationType kActivation = ActivationType::Gelu; - static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt; + static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; static constexpr ResidualType kResidual = ResidualType::Add; };