mirror of https://github.com/google/gemma.cpp.git
Update gemma-27b to the correct query scaling.
PiperOrigin-RevId: 653201646
This commit is contained in:
parent
992a2cbbc0
commit
5a751a9a44
|
|
@ -246,9 +246,11 @@ static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
|
|||
|
||||
template <class TConfig>
|
||||
GEMMA_CONSTEXPR_SQRT float ChooseQueryScale() {
|
||||
constexpr size_t kQKVDim = TConfig::kQKVDim;
|
||||
// QueryScaleType::Sqrt
|
||||
return 1.0f / Sqrt(static_cast<float>(kQKVDim));
|
||||
if (TConfig::kQueryScale == QueryScaleType::SqrtModelDimDivNumHeads)
|
||||
return 1.0f /
|
||||
Sqrt(static_cast<float>(TConfig::kModelDim / TConfig::kHeads));
|
||||
// QueryScaleType::SqrtKeySize
|
||||
return 1.0f / Sqrt(static_cast<float>(TConfig::kQKVDim));
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -70,7 +70,8 @@ enum class ActivationType {
|
|||
|
||||
// Attention query scale.
|
||||
enum class QueryScaleType {
|
||||
Sqrt,
|
||||
SqrtKeySize,
|
||||
SqrtModelDimDivNumHeads,
|
||||
};
|
||||
|
||||
// Residual connection type.
|
||||
|
|
@ -149,7 +150,6 @@ struct ConfigNoSSM {
|
|||
|
||||
static constexpr PostQKType kPostQK = PostQKType::Rope;
|
||||
static constexpr ActivationType kActivation = ActivationType::Gelu;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
|
||||
|
|
@ -157,6 +157,7 @@ struct ConfigBaseGemmaV1 : ConfigNoSSM {
|
|||
static constexpr float kAttCap = 0.0f;
|
||||
static constexpr float kFinalCap = 0.0f;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
struct ConfigBaseGemmaV2 : ConfigNoSSM {
|
||||
|
|
@ -184,6 +185,8 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
|
|||
static constexpr int kQKVDim = 128; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale =
|
||||
QueryScaleType::SqrtModelDimDivNumHeads;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
|
|
@ -205,6 +208,7 @@ struct ConfigGemma9B : public ConfigBaseGemmaV2 {
|
|||
static constexpr int kQKVDim = 256; // query size == key size == value size
|
||||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
};
|
||||
|
||||
template <typename TWeight>
|
||||
|
|
@ -269,6 +273,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
|
|||
static constexpr int kTopK = gcpp::kTopK;
|
||||
static constexpr bool kAbsolutePE = false;
|
||||
static constexpr PostNormType kPostNorm = PostNormType::None;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
|
||||
static constexpr float kAttCap = 0.0f;
|
||||
// This is required for optimize_test to pass.
|
||||
|
|
@ -343,7 +348,7 @@ struct ConfigGriffin2B {
|
|||
static constexpr int kNumTensorScales = 140;
|
||||
static constexpr PostQKType kPostQK = PostQKType::Rope;
|
||||
static constexpr ActivationType kActivation = ActivationType::Gelu;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
|
||||
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
|
||||
static constexpr ResidualType kResidual = ResidualType::Add;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue