Improve readability with RepeatedAttentionWindowSizes

PiperOrigin-RevId: 651431738
This commit is contained in:
The gemma.cpp Authors 2024-07-11 09:11:15 -07:00 committed by Copybara-Service
parent edaf61b983
commit df3fb70802
1 changed file with 17 additions and 12 deletions

View File

@ -98,6 +98,19 @@ constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
return window_size_configs;
}
// Builds a kNum-entry table by laying copies of window_size_pattern end to
// end: entry i of the result equals window_size_pattern[i % kPatternSize].
// kNum must be an exact multiple of kPatternSize (checked at compile time),
// so the result holds kNum / kPatternSize whole copies of the pattern.
template <size_t kNum, size_t kPatternSize>
constexpr std::array<size_t, kNum> RepeatedAttentionWindowSizes(
    const std::array<size_t, kPatternSize>& window_size_pattern) {
  static_assert(kNum % kPatternSize == 0,
                "kNum must be a multiple of kPatternSize");
  std::array<size_t, kNum> tiled = {};
  // Next slot of `tiled` to fill; advances once per copied pattern element.
  size_t out = 0;
  for (size_t rep = 0; rep < kNum / kPatternSize; ++rep) {
    for (size_t pos = 0; pos < kPatternSize; ++pos) {
      tiled[out] = window_size_pattern[pos];
      ++out;
    }
  }
  return tiled;
}
template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers,
@ -160,12 +173,8 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
FixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 46> kAttentionWindowSizes = {
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen};
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608;
@ -185,12 +194,8 @@ struct ConfigGemma9B : public ConfigBaseGemmaV2 {
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
FixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 42> kAttentionWindowSizes = {
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen};
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584;