mirror of https://github.com/google/gemma.cpp.git
Improve readability with RepeatedAttentionWindowSizes
PiperOrigin-RevId: 651431738
This commit is contained in:
parent
edaf61b983
commit
df3fb70802
|
|
@ -98,6 +98,19 @@ constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
|
|||
return window_size_configs;
|
||||
}
|
||||
|
||||
// Builds a kNum-entry array of attention window sizes by tiling
// `window_size_pattern` end to end: entry i of the result is
// window_size_pattern[i % kPatternSize].
//
// kNum:         total number of entries to produce.
// kPatternSize: length of the repeating pattern; must be non-zero and must
//               divide kNum evenly (both are enforced at compile time).
// Returns the filled array by value; the function is constexpr.
template <size_t kNum, size_t kPatternSize>
constexpr std::array<size_t, kNum> RepeatedAttentionWindowSizes(
    const std::array<size_t, kPatternSize>& window_size_pattern) {
  // Checked on its own so that an empty pattern produces a readable message
  // instead of a division-by-zero error inside the divisibility check below.
  static_assert(kPatternSize > 0, "kPatternSize must be greater than zero");
  // The `kPatternSize == 0 ||` guard short-circuits the modulo when the
  // assertion above has already failed, avoiding a second, noisier diagnostic.
  static_assert(kPatternSize == 0 || kNum % kPatternSize == 0,
                "kNum must be a multiple of kPatternSize");
  std::array<size_t, kNum> window_size_configs = {};
  for (size_t i = 0; i < kNum; ++i) {
    window_size_configs[i] = window_size_pattern[i % kPatternSize];
  }
  return window_size_configs;
}
|
||||
|
||||
template <size_t kNumLayers>
|
||||
constexpr size_t NumLayersOfTypeBefore(
|
||||
const std::array<LayerAttentionType, kNumLayers>& layers,
|
||||
|
|
@ -160,12 +173,8 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
|
|||
static constexpr int kVocabSize = 256000;
|
||||
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
|
||||
FixedLayerConfig<46>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 46> kAttentionWindowSizes = {
|
||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen};
|
||||
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
|
||||
RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 4608;
|
||||
|
|
@ -185,12 +194,8 @@ struct ConfigGemma9B : public ConfigBaseGemmaV2 {
|
|||
static constexpr int kVocabSize = 256000;
|
||||
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
|
||||
FixedLayerConfig<42>(LayerAttentionType::kGemma);
|
||||
static constexpr std::array<size_t, 42> kAttentionWindowSizes = {
|
||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
||||
4096, kSeqLen};
|
||||
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
|
||||
RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
|
||||
static constexpr int kLayers = kLayerConfig.size();
|
||||
static constexpr int kGemmaLayers = kLayers;
|
||||
static constexpr int kModelDim = 3584;
|
||||
|
|
|
|||
Loading…
Reference in New Issue