From 8cf3966be451068c7038db0ee39ca6d2891ac95b Mon Sep 17 00:00:00 2001
From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com>
Date: Fri, 18 Oct 2024 14:25:28 +0530
Subject: [PATCH] compile success: set default self extend values in noSSM and
 griffin

---
 gemma/configs.h   | 260 ++++++++++++++++++++++++++++++++++++++++++++++
 gemma/gemma-inl.h |  19 ++++++++++-
 2 files changed, 278 insertions(+), 1 deletion(-)

diff --git a/gemma/configs.h b/gemma/configs.h
index f7c6ac2..30e9a6a 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -192,6 +192,266 @@ ModelConfig ConfigFromModel(Model model);
 // Returns the sub-config for the ViT model of the PaliGemma model.
 ModelConfig VitConfig(const ModelConfig& config);
 
+template <class TConfig>
+struct CacheLayerSize {
+  constexpr size_t operator()() const {
+    return TConfig::kKVHeads * TConfig::kQKVDim * 2;
+  }
+};
+
+template <class TConfig>
+struct CachePosSize {
+  constexpr size_t operator()() const {
+    return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
+  }
+};
+
+struct ConfigNoSSM {
+  static constexpr int kGriffinLayers = 0;
+
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr bool kUseHalfRope = false;
+  static constexpr bool kUseLocalAttention = false;
+  static constexpr bool kInterleaveQKV = true;
+  static constexpr int kNumTensorScales = 0;
+
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr ActivationType kActivation = ActivationType::Gelu;
+  static constexpr ResidualType kResidual = ResidualType::Add;
+
+  // Self-extend parameters with default values (disabled by default).
+  static constexpr bool kSelfExtend = false;
+  static constexpr size_t kSelfExtendNgbSize = 0;
+  static constexpr size_t kSelfExtendGrpSize = 1;
+};
+
+struct ConfigBaseGemmaV1 : ConfigNoSSM {
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+};
+
+struct ConfigBaseGemmaV2 : ConfigNoSSM {
+  static constexpr float kAttCap = 50.0f;
+  static constexpr float kFinalCap = 30.0f;
+  static constexpr PostNormType kPostNorm = PostNormType::Scale;
+};
+
+template <typename TWeight>
+struct ConfigGemma27B : public ConfigBaseGemmaV2 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = 8192;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
+      FixedLayerConfig<46>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 46> kAttentionWindowSizes =
+      RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 4608;
+  static constexpr int kFFHiddenDim = 16 * 4608 / 2;  // = 36864
+  static constexpr int kHeads = 32;
+  static constexpr int kKVHeads = 16;
+  static constexpr int kQKVDim = 128;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr QueryScaleType kQueryScale =
+      QueryScaleType::SqrtModelDimDivNumHeads;
+};
+
+template <typename TWeight>
+struct ConfigGemma9B : public ConfigBaseGemmaV2 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = 8192;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
+      FixedLayerConfig<42>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 42> kAttentionWindowSizes =
+      RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 3584;
+  static constexpr int kFFHiddenDim = 8 * 3584 / 2;  // = 14336
+  static constexpr int kHeads = 16;
+  static constexpr int kKVHeads = 8;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+};
+
+template <typename TWeight>
+struct ConfigGemma7B : public ConfigBaseGemmaV1 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
+      FixedLayerConfig<28>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 28> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<28>(kSeqLen);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 3072;
+  static constexpr int kFFHiddenDim = 16 * 3072 / 2;  // = 24576
+  static constexpr int kHeads = 16;
+  static constexpr int kKVHeads = 16;  // standard MHA
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+};
+
+template <typename TWeight>
+struct ConfigGemma2B : public ConfigBaseGemmaV1 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
+      FixedLayerConfig<18>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 18> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<18>(kSeqLen);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 2048;
+  static constexpr int kFFHiddenDim = 16 * 2048 / 2;  // = 16384
+  static constexpr int kHeads = 8;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+};
+
+template <typename TWeight>
+struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = 8192;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
+      FixedLayerConfig<26>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
+      RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 2304;
+  static constexpr int kFFHiddenDim = 8 * 2304 / 2;  // = 9216
+  static constexpr int kHeads = 8;
+  static constexpr int kKVHeads = 4;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+};
+
+template <typename TWeight>
+struct ConfigGemmaTiny : public ConfigNoSSM {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  static constexpr int kSeqLen = 32;
+  static constexpr int kVocabSize = 64;
+  static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
+      FixedLayerConfig<3>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 3> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<3>(kSeqLen);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kModelDim = 128;
+  static constexpr int kFFHiddenDim = 256;
+  static constexpr int kHeads = 4;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 16;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+
+  static constexpr float kAttCap = 0.0f;
+  // This is required for optimize_test to pass.
+  static constexpr float kFinalCap = 30.0f;
+};
+
+template <typename TWeight>
+struct ConfigGriffin2B {
+  using Weight = TWeight;  // make accessible where we only have a TConfig
+
+  // Griffin uses local attention, so kSeqLen is actually the local attention
+  // window.
+  static constexpr int kSeqLen = 2048;
+  static constexpr int kVocabSize = 256000;
+  static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGemma,
+      LayerAttentionType::kGriffinRecurrentBlock,
+      LayerAttentionType::kGriffinRecurrentBlock,
+  };
+  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<26>(kSeqLen);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers =
+      NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
+  static constexpr int kGriffinLayers =
+      NumLayersOfTypeBefore(kLayerConfig,
+                            LayerAttentionType::kGriffinRecurrentBlock,
+                            kLayers);
+  static constexpr int kModelDim = 2560;
+  static constexpr int kFFHiddenDim = 7680;
+  static constexpr int kHeads = 10;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 256;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
+
+  // No SoftCap.
+  static constexpr float kAttCap = 0.0f;
+  static constexpr float kFinalCap = 0.0f;
+
+  // SSM config.
+  static constexpr int kConv1dWidth = 4;
+  static constexpr bool kFFBiases = true;
+  static constexpr bool kSoftmaxAttnOutputBiases = true;
+  static constexpr bool kUseHalfRope = true;
+  static constexpr bool kUseLocalAttention = true;
+  static constexpr bool kInterleaveQKV = false;
+  static constexpr int kNumTensorScales = 140;
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr ActivationType kActivation = ActivationType::Gelu;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
+  static constexpr ResidualType kResidual = ResidualType::Add;
+
+  // Self-extend parameters with default values (disabled by default).
+  static constexpr bool kSelfExtend = false;
+  static constexpr size_t kSelfExtendNgbSize = 0;
+  static constexpr size_t kSelfExtendGrpSize = 1;
+};
 }  // namespace gcpp
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 028949c..196f1a4 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -327,6 +327,13 @@ class GemmaAttention {
+      // Self-extend: keys beyond the neighbor window are rope-encoded with
+      // their grouped position, so apply the grouping BEFORE the encoding.
+      if constexpr (TConfig::kSelfExtend) {
+        if (pos > ngb_size) {
+          pos /= grp_size;
+        }
+      }
       PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f, kv);
 
       // If MHA, also copy V into KVCache.
       if (is_mha_) {
         hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
                        kv + layer_config_.qkv_dim,
@@ -417,4 +424,14 @@ class GemmaAttention {
         // Apply rope and scaling to Q.
-        const size_t pos = queries_pos_[query_idx] + batch_idx;
+        size_t pos = queries_pos_[query_idx] + batch_idx;
+        // Self-extend: queries beyond the neighbor window use the grouped
+        // position shifted so in-window relative distances are preserved.
+        if constexpr (TConfig::kSelfExtend) {
+          if (pos > ngb_size) {
+            const size_t grp_pos = pos / grp_size;
+            const size_t shift = ngb_size - ngb_size / grp_size;
+            const size_t shifted_grouped_pos = grp_pos + shift;
+            pos = shifted_grouped_pos;
+          }
+        }
         PositionalEncodingQK(q, pos, layer_, query_scale, q);
         const size_t start_pos = StartPos(pos, layer_);