diff --git a/gemma/configs.h b/gemma/configs.h index 1ffdfc2..b7e2a44 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -62,6 +62,16 @@ constexpr std::array FixedLayerConfig( return config; } +template +constexpr std::array FixedAttentionWindowSizes( + size_t window_size) { + std::array window_size_configs = {}; + for (size_t& l : window_size_configs) { + l = window_size; + } + return window_size_configs; +} + template constexpr size_t NumLayersOfTypeBefore( const std::array& layers, @@ -114,10 +124,16 @@ template struct ConfigGemma27B : public ConfigCapNoSSM { using Weight = TWeight; // make accessible where we only have a TConfig - static constexpr int kSeqLen = gcpp::kSeqLen; + static constexpr int kSeqLen = 8192; static constexpr int kVocabSize = 256000; static constexpr std::array kLayerConfig = FixedLayerConfig<46>(LayerAttentionType::kGemma); + static constexpr std::array kAttentionWindowSizes = { + 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, + 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, + 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, + 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, + 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen}; static constexpr int kLayers = kLayerConfig.size(); static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 4608; @@ -134,10 +150,16 @@ template struct ConfigGemma9B : public ConfigCapNoSSM { using Weight = TWeight; // make accessible where we only have a TConfig - static constexpr int kSeqLen = gcpp::kSeqLen; + static constexpr int kSeqLen = 8192; static constexpr int kVocabSize = 256000; static constexpr std::array kLayerConfig = FixedLayerConfig<42>(LayerAttentionType::kGemma); + static constexpr std::array kAttentionWindowSizes = { + 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, + 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, + 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, + 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, + 4096, kSeqLen}; static constexpr int kLayers = kLayerConfig.size(); static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 3584; @@ -158,6 +180,8 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM { static constexpr int kVocabSize = 256000; static constexpr std::array kLayerConfig = FixedLayerConfig<28>(LayerAttentionType::kGemma); + static constexpr std::array kAttentionWindowSizes = + FixedAttentionWindowSizes<28>(kSeqLen); static constexpr int kLayers = kLayerConfig.size(); static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 3072; @@ -178,6 +202,8 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM { static constexpr int kVocabSize = 256000; static constexpr std::array kLayerConfig = FixedLayerConfig<18>(LayerAttentionType::kGemma); + static constexpr std::array kAttentionWindowSizes = + FixedAttentionWindowSizes<18>(kSeqLen); static constexpr int kLayers = kLayerConfig.size(); static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 2048; @@ -198,6 +224,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM { static constexpr int kVocabSize = 64; static constexpr std::array kLayerConfig = FixedLayerConfig<3>(LayerAttentionType::kGemma); + static constexpr std::array kAttentionWindowSizes = + FixedAttentionWindowSizes<3>(kSeqLen); static constexpr int kLayers = kLayerConfig.size(); static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 128; @@ -250,6 +278,8 @@ struct ConfigGriffin2B { LayerAttentionType::kGriffinRecurrentBlock, LayerAttentionType::kGriffinRecurrentBlock, }; + static constexpr std::array kAttentionWindowSizes = + FixedAttentionWindowSizes<26>(kSeqLen); static constexpr int kLayers = kLayerConfig.size(); static constexpr int kGemmaLayers = NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 99ab616..6735d18 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -377,7 +377,8 @@ HWY_NOINLINE void Attention( MulByConst(kQueryScale, q, kQKVDim); // Compute Q dot K scores - const size_t start_pos = pos - std::min(kSeqLen - 1, pos); + const size_t start_pos = + pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos); for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); const size_t kv_offset =