Update gemma-27b to the correct query scaling.

PiperOrigin-RevId: 653201646
This commit is contained in:
Daniel Keysers 2024-07-17 05:42:39 -07:00 committed by Copybara-Service
parent 992a2cbbc0
commit 5a751a9a44
2 changed files with 13 additions and 6 deletions

View File

@ -246,9 +246,11 @@ static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
template <class TConfig> template <class TConfig>
GEMMA_CONSTEXPR_SQRT float ChooseQueryScale() { GEMMA_CONSTEXPR_SQRT float ChooseQueryScale() {
constexpr size_t kQKVDim = TConfig::kQKVDim; if (TConfig::kQueryScale == QueryScaleType::SqrtModelDimDivNumHeads)
// QueryScaleType::Sqrt return 1.0f /
return 1.0f / Sqrt(static_cast<float>(kQKVDim)); Sqrt(static_cast<float>(TConfig::kModelDim / TConfig::kHeads));
// QueryScaleType::SqrtKeySize
return 1.0f / Sqrt(static_cast<float>(TConfig::kQKVDim));
} }
} // namespace gcpp } // namespace gcpp

View File

@ -70,7 +70,8 @@ enum class ActivationType {
// Attention query scale. // Attention query scale.
enum class QueryScaleType { enum class QueryScaleType {
Sqrt, SqrtKeySize,
SqrtModelDimDivNumHeads,
}; };
// Residual connection type. // Residual connection type.
@ -149,7 +150,6 @@ struct ConfigNoSSM {
static constexpr PostQKType kPostQK = PostQKType::Rope; static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu; static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
static constexpr ResidualType kResidual = ResidualType::Add; static constexpr ResidualType kResidual = ResidualType::Add;
}; };
@ -157,6 +157,7 @@ struct ConfigBaseGemmaV1 : ConfigNoSSM {
static constexpr float kAttCap = 0.0f; static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f; static constexpr float kFinalCap = 0.0f;
static constexpr PostNormType kPostNorm = PostNormType::None; static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
}; };
struct ConfigBaseGemmaV2 : ConfigNoSSM { struct ConfigBaseGemmaV2 : ConfigNoSSM {
@ -184,6 +185,8 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
static constexpr int kQKVDim = 128; // query size == key size == value size static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale =
QueryScaleType::SqrtModelDimDivNumHeads;
}; };
template <typename TWeight> template <typename TWeight>
@ -205,6 +208,7 @@ struct ConfigGemma9B : public ConfigBaseGemmaV2 {
static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
}; };
template <typename TWeight> template <typename TWeight>
@ -269,6 +273,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
static constexpr int kTopK = gcpp::kTopK; static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false; static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None; static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr float kAttCap = 0.0f; static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass. // This is required for optimize_test to pass.
@ -343,7 +348,7 @@ struct ConfigGriffin2B {
static constexpr int kNumTensorScales = 140; static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope; static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu; static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt; static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr ResidualType kResidual = ResidualType::Add; static constexpr ResidualType kResidual = ResidualType::Add;
}; };