From df3fb70802e45ab57dfa60031d04d39ac8aead91 Mon Sep 17 00:00:00 2001 From: "The gemma.cpp Authors" Date: Thu, 11 Jul 2024 09:11:15 -0700 Subject: [PATCH] Improve readability with RepeatedAttentionWindowSizes PiperOrigin-RevId: 651431738 --- gemma/configs.h | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/gemma/configs.h b/gemma/configs.h index f1327d3..d637372 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -98,6 +98,19 @@ constexpr std::array FixedAttentionWindowSizes( return window_size_configs; } +// Repeat window_size_pattern for kNum / kPatternSize times. +template +constexpr std::array RepeatedAttentionWindowSizes( + const std::array& window_size_pattern) { + static_assert(kNum % kPatternSize == 0, + "kNum must be a multiple of kPatternSize"); + std::array window_size_configs = {}; + for (size_t i = 0; i < kNum; ++i) { + window_size_configs[i] = window_size_pattern[i % kPatternSize]; + } + return window_size_configs; +} + template constexpr size_t NumLayersOfTypeBefore( const std::array& layers, @@ -160,12 +173,8 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 { static constexpr int kVocabSize = 256000; static constexpr std::array kLayerConfig = FixedLayerConfig<46>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = { - 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, - 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, - 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, - 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, - 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen}; + static constexpr std::array kAttentionWindowSizes = + RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen}); static constexpr int kLayers = kLayerConfig.size(); static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 4608; @@ -185,12 +194,8 @@ struct ConfigGemma9B : public ConfigBaseGemmaV2 { static constexpr int kVocabSize = 256000; static constexpr std::array kLayerConfig = FixedLayerConfig<42>(LayerAttentionType::kGemma); - static constexpr std::array kAttentionWindowSizes = { - 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, - 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, - 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, - 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, - 4096, kSeqLen}; + static constexpr std::array kAttentionWindowSizes = + RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen}); static constexpr int kLayers = kLayerConfig.size(); static constexpr int kGemmaLayers = kLayers; static constexpr int kModelDim = 3584;