Update gemma-27b to the correct query scaling.

PiperOrigin-RevId: 653201646
This commit is contained in:
Daniel Keysers 2024-07-17 05:42:39 -07:00 committed by Copybara-Service
parent 992a2cbbc0
commit 5a751a9a44
2 changed files with 13 additions and 6 deletions

View File

@ -246,9 +246,11 @@ static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
template <class TConfig>
GEMMA_CONSTEXPR_SQRT float ChooseQueryScale() {
constexpr size_t kQKVDim = TConfig::kQKVDim;
// QueryScaleType::Sqrt
return 1.0f / Sqrt(static_cast<float>(kQKVDim));
if (TConfig::kQueryScale == QueryScaleType::SqrtModelDimDivNumHeads)
return 1.0f /
Sqrt(static_cast<float>(TConfig::kModelDim / TConfig::kHeads));
// QueryScaleType::SqrtKeySize
return 1.0f / Sqrt(static_cast<float>(TConfig::kQKVDim));
}
} // namespace gcpp

View File

@ -70,7 +70,8 @@ enum class ActivationType {
// Attention query scale.
enum class QueryScaleType {
Sqrt,
SqrtKeySize,
SqrtModelDimDivNumHeads,
};
// Residual connection type.
@ -149,7 +150,6 @@ struct ConfigNoSSM {
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
static constexpr ResidualType kResidual = ResidualType::Add;
};
@ -157,6 +157,7 @@ struct ConfigBaseGemmaV1 : ConfigNoSSM {
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
struct ConfigBaseGemmaV2 : ConfigNoSSM {
@ -184,6 +185,8 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale =
QueryScaleType::SqrtModelDimDivNumHeads;
};
template <typename TWeight>
@ -205,6 +208,7 @@ struct ConfigGemma9B : public ConfigBaseGemmaV2 {
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};
template <typename TWeight>
@ -269,6 +273,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass.
@ -343,7 +348,7 @@ struct ConfigGriffin2B {
static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr ResidualType kResidual = ResidualType::Add;
};