diff --git a/gemma/configs.h b/gemma/configs.h
index f1327d3..d637372 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -98,6 +98,19 @@ constexpr std::array FixedAttentionWindowSizes(
   return window_size_configs;
 }
 
+// Repeat window_size_pattern for kNum / kPatternSize times.
+template <size_t kNum, size_t kPatternSize>
+constexpr std::array<size_t, kNum> RepeatedAttentionWindowSizes(
+    const std::array<size_t, kPatternSize>& window_size_pattern) {
+  static_assert(kNum % kPatternSize == 0,
+                "kNum must be a multiple of kPatternSize");
+  std::array<size_t, kNum> window_size_configs = {};
+  for (size_t i = 0; i < kNum; ++i) {
+    window_size_configs[i] = window_size_pattern[i % kPatternSize];
+  }
+  return window_size_configs;
+}
+
 template
 constexpr size_t NumLayersOfTypeBefore(
     const std::array& layers,
@@ -160,12 +173,8 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
   static constexpr int kVocabSize = 256000;
   static constexpr std::array kLayerConfig =
       FixedLayerConfig<46>(LayerAttentionType::kGemma);
-  static constexpr std::array kAttentionWindowSizes = {
-      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
-      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
-      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
-      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
-      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen};
+  static constexpr std::array kAttentionWindowSizes =
+      RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 4608;
@@ -185,12 +194,8 @@ struct ConfigGemma9B : public ConfigBaseGemmaV2 {
   static constexpr int kVocabSize = 256000;
   static constexpr std::array kLayerConfig =
       FixedLayerConfig<42>(LayerAttentionType::kGemma);
-  static constexpr std::array kAttentionWindowSizes = {
-      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
-      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
-      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
-      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
-      4096, kSeqLen};
+  static constexpr std::array kAttentionWindowSizes =
+      RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 3584;