mirror of https://github.com/google/gemma.cpp.git
Improve readability with RepeatedAttentionWindowSizes
PiperOrigin-RevId: 651431738
This commit is contained in:
parent
edaf61b983
commit
df3fb70802
|
|
@ -98,6 +98,19 @@ constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
|
||||||
return window_size_configs;
|
return window_size_configs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a kNum-entry array of attention window sizes by laying down
// window_size_pattern back to back, so entry i equals
// window_size_pattern[i % kPatternSize]. kNum must be a whole multiple of
// kPatternSize; this is enforced at compile time.
template <size_t kNum, size_t kPatternSize>
constexpr std::array<size_t, kNum> RepeatedAttentionWindowSizes(
    const std::array<size_t, kPatternSize>& window_size_pattern) {
  static_assert(kNum % kPatternSize == 0,
                "kNum must be a multiple of kPatternSize");
  std::array<size_t, kNum> repeated = {};
  // The static_assert above guarantees the pattern fits a whole number of
  // times, so copying it block by block fills every entry exactly once.
  constexpr size_t kRepetitions = kNum / kPatternSize;
  for (size_t block = 0; block < kRepetitions; ++block) {
    const size_t base = block * kPatternSize;
    for (size_t offset = 0; offset < kPatternSize; ++offset) {
      repeated[base + offset] = window_size_pattern[offset];
    }
  }
  return repeated;
}
|
||||||
|
|
||||||
template <size_t kNumLayers>
|
template <size_t kNumLayers>
|
||||||
constexpr size_t NumLayersOfTypeBefore(
|
constexpr size_t NumLayersOfTypeBefore(
|
||||||
const std::array<LayerAttentionType, kNumLayers>& layers,
|
const std::array<LayerAttentionType, kNumLayers>& layers,
|
||||||
|
|
@ -160,12 +173,8 @@ struct ConfigGemma27B : public ConfigBaseGemmaV2 {
|
||||||
static constexpr int kVocabSize = 256000;
|
static constexpr int kVocabSize = 256000;
|
||||||
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
|
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
|
||||||
FixedLayerConfig<46>(LayerAttentionType::kGemma);
|
FixedLayerConfig<46>(LayerAttentionType::kGemma);
|
||||||
static constexpr std::array<size_t, 46> kAttentionWindowSizes = {
|
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
|
||||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
|
||||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
|
||||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
|
||||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
|
||||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen};
|
|
||||||
static constexpr int kLayers = kLayerConfig.size();
|
static constexpr int kLayers = kLayerConfig.size();
|
||||||
static constexpr int kGemmaLayers = kLayers;
|
static constexpr int kGemmaLayers = kLayers;
|
||||||
static constexpr int kModelDim = 4608;
|
static constexpr int kModelDim = 4608;
|
||||||
|
|
@ -185,12 +194,8 @@ struct ConfigGemma9B : public ConfigBaseGemmaV2 {
|
||||||
static constexpr int kVocabSize = 256000;
|
static constexpr int kVocabSize = 256000;
|
||||||
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
|
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
|
||||||
FixedLayerConfig<42>(LayerAttentionType::kGemma);
|
FixedLayerConfig<42>(LayerAttentionType::kGemma);
|
||||||
static constexpr std::array<size_t, 42> kAttentionWindowSizes = {
|
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
|
||||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
|
||||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
|
||||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
|
||||||
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
|
|
||||||
4096, kSeqLen};
|
|
||||||
static constexpr int kLayers = kLayerConfig.size();
|
static constexpr int kLayers = kLayerConfig.size();
|
||||||
static constexpr int kGemmaLayers = kLayers;
|
static constexpr int kGemmaLayers = kLayers;
|
||||||
static constexpr int kModelDim = 3584;
|
static constexpr int kModelDim = 3584;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue