mirror of https://github.com/google/gemma.cpp.git
Add sliding window attention for Gemma 2.
PiperOrigin-RevId: 648778253
parent 09a7e75ead
commit 7e4b20455e
@@ -62,6 +62,16 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
   return config;
 }
 
+template <size_t kNum>
+constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
+    size_t window_size) {
+  std::array<size_t, kNum> window_size_configs = {};
+  for (size_t& l : window_size_configs) {
+    l = window_size;
+  }
+  return window_size_configs;
+}
+
 template <size_t kNumLayers>
 constexpr size_t NumLayersOfTypeBefore(
     const std::array<LayerAttentionType, kNumLayers>& layers,
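The new FixedAttentionWindowSizes helper mirrors the existing FixedLayerConfig: it fills a per-layer array with a single window size at compile time, which lets every pre-existing model keep full-length attention in each layer. A minimal standalone sketch of how it behaves; the static_asserts are my illustration, not part of the commit:

#include <array>
#include <cstddef>

template <size_t kNum>
constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
    size_t window_size) {
  std::array<size_t, kNum> window_size_configs = {};
  for (size_t& l : window_size_configs) {
    l = window_size;  // same window for every layer
  }
  return window_size_configs;
}

// All layers receive the full window, e.g. for a 3-layer test model:
static_assert(FixedAttentionWindowSizes<3>(8192)[0] == 8192);
static_assert(FixedAttentionWindowSizes<3>(8192)[2] == 8192);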
@@ -114,10 +124,16 @@ template <typename TWeight>
 struct ConfigGemma27B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
-  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kSeqLen = 8192;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
       FixedLayerConfig<46>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 46> kAttentionWindowSizes = {
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen};
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 4608;
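Gemma 2 alternates between local sliding-window attention (4096 tokens, even layer indices) and global attention over the full kSeqLen = 8192 (odd layer indices); the 46-entry initializer list above, and the 42-entry one for the 9B model below, spell this pattern out literally. A hedged sketch of the same pattern; this AlternatingWindowSizes helper is hypothetical and does not appear in the commit:

#include <array>
#include <cstddef>

template <size_t kNum>
constexpr std::array<size_t, kNum> AlternatingWindowSizes(
    size_t local_window, size_t global_window) {
  std::array<size_t, kNum> sizes = {};
  for (size_t i = 0; i < kNum; ++i) {
    // Even layers use the local sliding window, odd layers attend globally.
    sizes[i] = (i % 2 == 0) ? local_window : global_window;
  }
  return sizes;
}

// Agrees with the literal 27B table above: layer 0 local, layer 1 global, ...
static_assert(AlternatingWindowSizes<46>(4096, 8192)[0] == 4096);
static_assert(AlternatingWindowSizes<46>(4096, 8192)[1] == 8192);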
@@ -134,10 +150,16 @@ template <typename TWeight>
 struct ConfigGemma9B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
-  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kSeqLen = 8192;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
       FixedLayerConfig<42>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 42> kAttentionWindowSizes = {
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen};
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 3584;
@@ -158,6 +180,8 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
       FixedLayerConfig<28>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 28> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<28>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 3072;
@@ -178,6 +202,8 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
       FixedLayerConfig<18>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 18> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<18>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 2048;
@@ -198,6 +224,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   static constexpr int kVocabSize = 64;
   static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
       FixedLayerConfig<3>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 3> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<3>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 128;
@@ -250,6 +278,8 @@ struct ConfigGriffin2B {
       LayerAttentionType::kGriffinRecurrentBlock,
       LayerAttentionType::kGriffinRecurrentBlock,
   };
+  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<26>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers =
       NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
@@ -377,7 +377,8 @@ HWY_NOINLINE void Attention(
     MulByConst(kQueryScale, q, kQKVDim);
 
     // Compute Q dot K scores
-    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+    const size_t start_pos =
+        pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
     for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
      const size_t kv_offset =
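With the per-layer table in place, the attention kernel no longer looks back a fixed kSeqLen - 1 tokens: a query at position pos attends to the inclusive range [start_pos, pos], where start_pos is clamped by the current layer's window. On global layers the window equals kSeqLen, so the new expression reduces to the old one. A standalone sketch of the arithmetic; the StartPos wrapper is my illustration, not a function in the repo:

#include <algorithm>
#include <cstddef>

// First position a query at `pos` may attend to for a given window size.
constexpr size_t StartPos(size_t pos, size_t window_size) {
  // std::min guards against underflow while pos is still inside the window.
  return pos - std::min(window_size - 1, pos);
}

static_assert(StartPos(10, 4096) == 0);      // early tokens see the whole prefix
static_assert(StartPos(5000, 4096) == 905);  // 5000 - 4095: sliding window
static_assert(StartPos(5000, 8192) == 0);    // global layer: full prefix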