mirror of https://github.com/google/gemma.cpp.git
Add sliding window attention for Gemma 2.
PiperOrigin-RevId: 648778253
parent 09a7e75ead
commit 7e4b20455e
@@ -62,6 +62,16 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
   return config;
 }
 
+template <size_t kNum>
+constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
+    size_t window_size) {
+  std::array<size_t, kNum> window_size_configs = {};
+  for (size_t& l : window_size_configs) {
+    l = window_size;
+  }
+  return window_size_configs;
+}
+
 template <size_t kNumLayers>
 constexpr size_t NumLayersOfTypeBefore(
     const std::array<LayerAttentionType, kNumLayers>& layers,
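Note on the hunk above: the new FixedAttentionWindowSizes helper mirrors the existing FixedLayerConfig pattern, filling the same window size for every layer at compile time. A minimal standalone sketch of its behavior, with illustrative values checked at compile time:

#include <array>
#include <cstddef>

// Same definition as added above; repeated here so the sketch is
// self-contained and compilable on its own.
template <size_t kNum>
constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
    size_t window_size) {
  std::array<size_t, kNum> window_size_configs = {};
  for (size_t& l : window_size_configs) {
    l = window_size;
  }
  return window_size_configs;
}

// Every entry equals the requested window size.
static_assert(FixedAttentionWindowSizes<3>(4096)[0] == 4096);
static_assert(FixedAttentionWindowSizes<3>(4096)[2] == 4096);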
@@ -114,10 +124,16 @@ template <typename TWeight>
 struct ConfigGemma27B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
-  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kSeqLen = 8192;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
       FixedLayerConfig<46>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 46> kAttentionWindowSizes = {
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen};
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 4608;
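The alternating entries in ConfigGemma27B are the heart of the change: even-indexed layers use a 4096-token sliding window, odd-indexed layers attend over the full 8192-token sequence. A hypothetical generator (not in this commit, which spells the table out literally) shows the pattern compactly:

#include <array>
#include <cstddef>

// Hypothetical helper, equivalent to the literal 46-entry table above:
// alternate a local sliding window with the full sequence length,
// starting with a local layer.
template <size_t kNum>
constexpr std::array<size_t, kNum> AlternatingWindowSizes(size_t local,
                                                          size_t global) {
  std::array<size_t, kNum> windows = {};
  for (size_t i = 0; i < kNum; ++i) {
    windows[i] = (i % 2 == 0) ? local : global;
  }
  return windows;
}

// Spot-check against the table: {4096, kSeqLen, 4096, kSeqLen, ...}.
static_assert(AlternatingWindowSizes<46>(4096, 8192)[0] == 4096);
static_assert(AlternatingWindowSizes<46>(4096, 8192)[1] == 8192);
static_assert(AlternatingWindowSizes<46>(4096, 8192)[45] == 8192);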
@@ -134,10 +150,16 @@ template <typename TWeight>
 struct ConfigGemma9B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
-  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kSeqLen = 8192;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
       FixedLayerConfig<42>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 42> kAttentionWindowSizes = {
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen};
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 3584;
@@ -158,6 +180,8 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
       FixedLayerConfig<28>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 28> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<28>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 3072;
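For the Gemma 1 configs (7B, 2B, Tiny) and Griffin below, every layer's window is simply kSeqLen, so the per-layer bound introduced in Attention() (last hunk) reduces to the previous full-window behavior. A small compile-time sanity check, assuming the formula from that hunk:

#include <algorithm>
#include <cstddef>

constexpr size_t kSeqLen = 8192;  // illustrative; matches the configs above

// Old bound: window fixed at kSeqLen. New bound: per-layer window.
constexpr size_t OldStartPos(size_t pos) {
  return pos - std::min(kSeqLen - 1, pos);
}
constexpr size_t NewStartPos(size_t pos, size_t window) {
  return pos - std::min(window - 1, pos);
}

// With window == kSeqLen the two agree, before and after the window fills.
static_assert(NewStartPos(100, kSeqLen) == OldStartPos(100));
static_assert(NewStartPos(10000, kSeqLen) == OldStartPos(10000));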
@@ -178,6 +202,8 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
       FixedLayerConfig<18>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 18> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<18>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 2048;
@@ -198,6 +224,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   static constexpr int kVocabSize = 64;
   static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
       FixedLayerConfig<3>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 3> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<3>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 128;
@@ -250,6 +278,8 @@ struct ConfigGriffin2B {
       LayerAttentionType::kGriffinRecurrentBlock,
       LayerAttentionType::kGriffinRecurrentBlock,
   };
+  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<26>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers =
       NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
@@ -377,7 +377,8 @@ HWY_NOINLINE void Attention(
   MulByConst(kQueryScale, q, kQKVDim);
 
   // Compute Q dot K scores
-  const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+  const size_t start_pos =
+      pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
   for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
     const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
     const size_t kv_offset =
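Worked example of the new bound: a query at position pos scores keys in [pos - min(window - 1, pos), pos], so a 4096-window layer sees at most the 4096 most recent positions, while a kSeqLen-window layer still sees the full history. A runnable sketch with illustrative numbers:

#include <algorithm>
#include <cassert>
#include <cstddef>

// Same arithmetic as the new start_pos line above, extracted for clarity.
size_t StartPos(size_t pos, size_t window) {
  return pos - std::min(window - 1, pos);
}

int main() {
  assert(StartPos(5000, 4096) == 905);  // 5000 - 4095: the window slides
  assert(StartPos(100, 4096) == 0);     // early positions see everything
  assert(StartPos(5000, 8192) == 0);    // global layer: full history so far
  return 0;
}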