Add sliding window attention for Gemma 2.

PiperOrigin-RevId: 648778253
This commit is contained in:
Kan Wu 2024-07-02 11:07:24 -07:00 committed by Copybara-Service
parent 09a7e75ead
commit 7e4b20455e
2 changed files with 34 additions and 3 deletions

View File

@ -62,6 +62,16 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
return config; return config;
} }
// Returns an array of kNum attention window sizes, every entry set to
// `window_size`. Used by model configs whose layers all share one window
// (i.e. no per-layer sliding-window pattern).
template <size_t kNum>
constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
    size_t window_size) {
  std::array<size_t, kNum> sizes = {};
  for (size_t i = 0; i < kNum; ++i) {
    sizes[i] = window_size;
  }
  return sizes;
}
template <size_t kNumLayers> template <size_t kNumLayers>
constexpr size_t NumLayersOfTypeBefore( constexpr size_t NumLayersOfTypeBefore(
const std::array<LayerAttentionType, kNumLayers>& layers, const std::array<LayerAttentionType, kNumLayers>& layers,
@ -114,10 +124,16 @@ template <typename TWeight>
struct ConfigGemma27B : public ConfigCapNoSSM { struct ConfigGemma27B : public ConfigCapNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = 256000; static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig = static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
FixedLayerConfig<46>(LayerAttentionType::kGemma); FixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 46> kAttentionWindowSizes = {
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen};
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers; static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608; static constexpr int kModelDim = 4608;
@ -134,10 +150,16 @@ template <typename TWeight>
struct ConfigGemma9B : public ConfigCapNoSSM { struct ConfigGemma9B : public ConfigCapNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = gcpp::kSeqLen; static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = 256000; static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig = static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
FixedLayerConfig<42>(LayerAttentionType::kGemma); FixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 42> kAttentionWindowSizes = {
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
4096, kSeqLen};
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers; static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584; static constexpr int kModelDim = 3584;
@ -158,6 +180,8 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
static constexpr int kVocabSize = 256000; static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig = static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma); FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
FixedAttentionWindowSizes<28>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers; static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3072; static constexpr int kModelDim = 3072;
@ -178,6 +202,8 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
static constexpr int kVocabSize = 256000; static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig = static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma); FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
FixedAttentionWindowSizes<18>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers; static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2048; static constexpr int kModelDim = 2048;
@ -198,6 +224,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
static constexpr int kVocabSize = 64; static constexpr int kVocabSize = 64;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig = static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
FixedLayerConfig<3>(LayerAttentionType::kGemma); FixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
FixedAttentionWindowSizes<3>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers; static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 128; static constexpr int kModelDim = 128;
@ -250,6 +278,8 @@ struct ConfigGriffin2B {
LayerAttentionType::kGriffinRecurrentBlock, LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock, LayerAttentionType::kGriffinRecurrentBlock,
}; };
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
FixedAttentionWindowSizes<26>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size(); static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);

View File

@ -377,7 +377,8 @@ HWY_NOINLINE void Attention(
MulByConst(kQueryScale, q, kQKVDim); MulByConst(kQueryScale, q, kQKVDim);
// Compute Q dot K scores // Compute Q dot K scores
const size_t start_pos = pos - std::min(kSeqLen - 1, pos); const size_t start_pos =
pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
const size_t kv_offset = const size_t kv_offset =