diff --git a/gemma/attention.cc b/gemma/attention.cc
index 61d76ef..e894981 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -48,9 +48,6 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {
 
-constexpr int kFlagReserved = 1;  // LINTER: unused, reserved for future use.
-constexpr int kUseOldAttention = 2;
-
 // Computes Q.K scores, which are "logits" (or scores) stored to att.
 // `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
 static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
@@ -357,7 +354,7 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
   (void)layer_config;  // only used in HWY_DASSERT
   ComputeQKV(num_tokens, layer_idx, layer, activations, qbatch, flags, env);
 
-  if (flags & kUseOldAttention) {
+  if (flags & kAttentionUseOld) {
     DotSoftmaxWeightedSum(num_tokens, layer_idx, layer, activations, qbatch,
                           env.ctx);
   } else {
diff --git a/gemma/configs.h b/gemma/configs.h
index e4a26b8..e02645b 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -32,8 +32,9 @@
 
 namespace gcpp {
 
-static constexpr size_t kMaxConv1DWidth = 4;
-static constexpr size_t kMaxQKVDim = 1024;
+HWY_INLINE_VAR constexpr int kAttentionUseOld = 2;
+
+HWY_INLINE_VAR constexpr size_t kMaxQKVDim = 1024;
 
 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
 enum class PromptWrapping {
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 785bd87..05583b3 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -73,9 +73,9 @@ void Attention(LayerAttentionType type, const size_t num_tokens,
                const size_t layer_idx, const LayerWeightsPtrs& layer,
                Activations& activations, QBatch& qbatch, MatMulEnv& env) {
   if (type == LayerAttentionType::kGemma) {
+    // TODO: remove flag to enable FlashAttention.
     GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch,
-                   env, kAttentionUseOld);
-                   /*flags=*/0);
+                   env, kAttentionUseOld);
   }
 }
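
Note for reviewers: moving the flag constant from attention.cc into configs.h lets gemma.cc pass it by name instead of a bare /*flags=*/0, which is what the last hunk does. Below is a minimal, self-contained sketch of the dispatch this change wires up. DotSoftmaxWeightedSum and FlashAttentionPath here are hypothetical stand-ins, not the real gemma functions; only kAttentionUseOld and the `flags &` check mirror the diff.

    #include <cstdio>

    // Hypothetical stand-in for the flag this diff moves into gemma/configs.h.
    constexpr int kAttentionUseOld = 2;

    // Hypothetical stand-ins for the two attention paths.
    static void DotSoftmaxWeightedSum() { std::printf("old attention path\n"); }
    static void FlashAttentionPath() { std::printf("flash attention path\n"); }

    // Mirrors the dispatch in GemmaAttention: a set kAttentionUseOld bit picks
    // the pre-existing DotSoftmaxWeightedSum path, a cleared bit the new path.
    static void Attention(int flags) {
      if (flags & kAttentionUseOld) {
        DotSoftmaxWeightedSum();
      } else {
        FlashAttentionPath();
      }
    }

    int main() {
      Attention(kAttentionUseOld);  // what gemma.cc requests after this change
      Attention(/*flags=*/0);       // what it would request once the TODO lands
      return 0;
    }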