From 15135f5b3d581cf92788eaf58d7508d3e01c59b4 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Tue, 18 Jun 2024 05:07:37 -0700
Subject: [PATCH] Simplify Attention. Shared kMHA, reuse from Activations,
 inline Attn lambda, use QDim as the stride between successive Q.

PiperOrigin-RevId: 644343854
---
 gemma/gemma.cc | 173 +++++++++++++++++++++++--------------------------
 1 file changed, 81 insertions(+), 92 deletions(-)

diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 8b470e6..db601c9 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -30,7 +30,6 @@
 #ifndef GEMMA_ONCE
 #define GEMMA_ONCE
 
-#include <math.h>  // sqrtf
 #include 
 #include 
 #include 
@@ -38,7 +37,6 @@
 #include 
 #include 
-#include <cmath>
 #include 
 #include 
 #include 
@@ -73,11 +71,14 @@ struct Activations {
   static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
       TConfig::kGemmaLayers * kCacheLayerSize;
-  static constexpr size_t kQDim = kHeads == kKVHeads ? kQKVDim * 3 : kQKVDim;
+  static constexpr bool kIsMHA = kHeads == kKVHeads;  // Multi-Head Attention
+  // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim,
+  // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V.
+  static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);
 
   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
-  std::array<float, kBatchSize * kHeads * kQDim> q;  // query vector
+  std::array<float, kBatchSize * kHeads * kQStride> q;  // query vector
 
   std::array<float, kBatchSize * kHeads * kSeqLen> att;      // attention vector
   std::array<float, kBatchSize * kHeads * kQKVDim> att_out;  // attention output
@@ -242,7 +243,7 @@ HWY_NOINLINE void GriffinRecurrent(
   using D = hn::ScalableTag<float>;
   HWY_DASSERT(num_tokens <= kBatchSize);
   static constexpr size_t kModelDim =
-      gcpp::Activations<TConfig, kBatchSize>::kModelDim;
+      Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr size_t kHeads = TConfig::kHeads;
@@ -370,71 +371,29 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Gen.Attention");
   HWY_DASSERT(num_tokens <= kBatchSize);
-  static constexpr size_t kQKVDim = gcpp::Activations<TConfig, kBatchSize>::kQKVDim;
-  static constexpr size_t kCachePosSize =
-      gcpp::Activations<TConfig, kBatchSize>::kCachePosSize;
-  static constexpr size_t kCacheLayerSize =
-      gcpp::Activations<TConfig, kBatchSize>::kCacheLayerSize;
-  static constexpr size_t kModelDim =
-      gcpp::Activations<TConfig, kBatchSize>::kModelDim;
-  static constexpr size_t kHeads = TConfig::kHeads;
-  static constexpr size_t kKVHeads = TConfig::kKVHeads;
-  static constexpr size_t kSeqLen = TConfig::kSeqLen;
-  static const float kQueryScale =
-      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
+  using TActivations = Activations<TConfig, kBatchSize>;
+  constexpr size_t kQKVDim = TActivations::kQKVDim;
+  constexpr size_t kQStride = TActivations::kQStride;
+  constexpr size_t kCachePosSize = TActivations::kCachePosSize;
+  constexpr size_t kCacheLayerSize = TActivations::kCacheLayerSize;
+  constexpr size_t kModelDim = TActivations::kModelDim;
+  constexpr size_t kHeads = TConfig::kHeads;
+  constexpr size_t kKVHeads = TConfig::kKVHeads;
+  constexpr size_t kSeqLen = TConfig::kSeqLen;
+  GEMMA_CONSTEXPR_SQRT const float kQueryScale =
+      1.0f / Sqrt(static_cast<float>(kQKVDim));
+  constexpr bool kIsMHA = TActivations::kIsMHA;  // Multi-Head Attention
 
-  auto Attn = [&](float* q, uint64_t head, size_t head_offset, size_t batch_idx,
-                  size_t thread) HWY_ATTR {
-    const size_t pos = batch_start + batch_idx;
-    // Calculate scores
-    float* HWY_RESTRICT head_att = activations.att.data() +
-                                   head * kSeqLen +
-                                   batch_idx * kHeads * kSeqLen;
-
-    Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-    MulByConst(kQueryScale, q, kQKVDim);
-
-    // Compute Q dot K scores
-    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
-    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
-      const size_t kv_offset = cache_pos * kCachePosSize +
-                               layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
-      const float score = Dot(q, k2, kQKVDim);
-      head_att[pos2 % kSeqLen] = score;
-    }
-    Softmax(head_att, std::min(pos + 1, kSeqLen));
-
-    // Weighted summation
-    float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
-                                  batch_idx * kHeads * kQKVDim;
-    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
-    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
-      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
-      const size_t kv_offset = cache_pos * kCachePosSize +
-                               layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
-      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
-    }
-  };
-
-  if constexpr (kHeads == kKVHeads) {
-    // Multi-Head Attention calculates qkv using q as scratch space.
-    static_assert(TConfig::kInterleaveQKV);
-    MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim * 3>(
-        num_tokens, activations.pre_att_rms_out.data(),
-        layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
-  } else {
-    MatMul_4x4_Batch<kModelDim, kHeads * kQKVDim>(
-        num_tokens, activations.pre_att_rms_out.data(),
-        layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
-  }
+  // If MHA, this also computes KV, which we copy to the KV cache below.
+  static_assert(!kIsMHA || TConfig::kInterleaveQKV);  // MHA => interleaved
+  MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
+      num_tokens, activations.pre_att_rms_out.data(),
+      layer_weights->qkv_einsum_w.data(), activations.q.data(), pool);
 
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
     // QKV projections:
-    if constexpr (kHeads != kKVHeads) {
+    if constexpr (!kIsMHA) {
       const size_t pos = batch_start + batch_idx;
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
       const size_t kv_offset =
@@ -447,37 +406,67 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
     }
   }
 
-  // Positional encodings for k:
-  const size_t num_kv_tasks = kKVHeads * num_tokens;
-  pool.Run(0, num_kv_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
-    const size_t head = task % kKVHeads;
-    const size_t batch_idx = task / kKVHeads;
-    const size_t pos = batch_start + batch_idx;
-    const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
-    const size_t kv_offset = cache_pos * kCachePosSize +
-                             layer * kCacheLayerSize + head * kQKVDim * 2;
-    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-    if constexpr (kHeads == kKVHeads) {
-      // For MHA, copy kv into the KV cache from scratch space (see above).
-      const float* HWY_RESTRICT q =
-          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
-      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
-    }
-    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-  });
+  // Positional encodings for kv:
+  pool.Run(
+      0, kKVHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
+        const size_t head = task % kKVHeads;
+        const size_t batch_idx = task / kKVHeads;
+        const size_t pos = batch_start + batch_idx;
+        const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+        const size_t kv_offset = cache_pos * kCachePosSize +
+                                 layer * kCacheLayerSize + head * kQKVDim * 2;
+        float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+        if constexpr (kIsMHA) {
+          // For MHA, copy kv into the KV cache from scratch space (see above).
+          const float* HWY_RESTRICT q =
+              activations.q.data() + (batch_idx * kHeads + head) * kQStride;
+          // Skip past the Q part of `q`, and copy KV to `kv`.
+          memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
+        }
+        Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      });
 
-  static_assert((TConfig::kHeads % TConfig::kKVHeads) == 0,
+  static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
-  static constexpr size_t kGroupHeads = TConfig::kHeads / TConfig::kKVHeads;
-  static constexpr size_t kQOffsetScale = (kHeads == kKVHeads) ? 3 : 1;
-  const size_t num_q_tasks = kHeads * num_tokens;
-  pool.Run(0, num_q_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+  static constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  pool.Run(0, kHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
     const size_t head = task % kHeads;
     const size_t batch_idx = task / kHeads;
     const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
-    float* HWY_RESTRICT q = activations.q.data() + (batch_idx * kHeads + head) *
-                            kQKVDim * kQOffsetScale;
-    Attn(q, head, head_offset, batch_idx, thread);
+    float* HWY_RESTRICT q =
+        activations.q.data() + (batch_idx * kHeads + head) * kQStride;
+
+    const size_t pos = batch_start + batch_idx;
+    // Calculate scores
+    float* HWY_RESTRICT head_att =
+        activations.att.data() + head * kSeqLen + batch_idx * kHeads * kSeqLen;
+
+    Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+    MulByConst(kQueryScale, q, kQKVDim);
+
+    // Compute Q dot K scores
+    const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset =
+          cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
+      const float score = Dot(q, k2, kQKVDim);
+      head_att[pos2 % kSeqLen] = score;
+    }
+    Softmax(head_att, std::min(pos + 1, kSeqLen));
+
+    // Weighted summation
+    float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim +
+                                  batch_idx * kHeads * kQKVDim;
+    hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
+    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
+      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset =
+          cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + kv_offset + kQKVDim;
+      MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
+    }
   });
 
   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
@@ -1012,7 +1001,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
   HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
   // Both pre-trained and instruction-tuned require BOS as first token.
   if (pos == 0) {
-    tokens.insert(tokens.begin(), gcpp::BOS_ID);
+    tokens.insert(tokens.begin(), BOS_ID);
   }
   return tokens;
 }
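
Note (illustration, not part of the diff): kQStride encodes the memory layout the patch relies on. For MHA the single MatMul writes each head's Q, K and V back-to-back into activations.q, so the stride between successive queries is 3 * kQKVDim; otherwise only Q is stored and the stride is kQKVDim. A minimal standalone sketch of that indexing, with made-up constants standing in for what TConfig/Activations would provide:

// Standalone sketch; constants are illustrative only.
#include <cstddef>
#include <cstdio>

constexpr size_t kHeads = 8;
constexpr size_t kKVHeads = 8;  // equal to kHeads -> MHA
constexpr size_t kQKVDim = 256;
constexpr bool kIsMHA = kHeads == kKVHeads;
constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);

// Start of head `head`'s Q vector for token `batch_idx` within activations.q.
constexpr size_t QOffset(size_t batch_idx, size_t head) {
  return (batch_idx * kHeads + head) * kQStride;
}

int main() {
  const size_t q = QOffset(/*batch_idx=*/0, /*head=*/1);
  std::printf("Q of head 1 starts at %zu\n", q);
  if (kIsMHA) {
    // K and V immediately follow Q for the same head; these are the
    // 2 * kQKVDim floats the patch memcpy()s into the KV cache before RoPE.
    std::printf("K at %zu, V at %zu\n", q + kQKVDim, q + 2 * kQKVDim);
  }
  return 0;
}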
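
A second sketch (also illustrative, not part of the diff): how the flat KV-cache offset used in the patch is assembled from cache_pos, layer and head, and how kGroupHeads maps a query head onto its shared KV head when kHeads != kKVHeads. All constants below are placeholders for values a real TConfig would supply:

// Standalone sketch; constants are placeholders, formulas mirror the patch.
#include <cstddef>
#include <cstdio>

constexpr size_t kGemmaLayers = 4;
constexpr size_t kHeads = 8;
constexpr size_t kKVHeads = 2;            // GQA: 4 query heads share one KV head
constexpr size_t kQKVDim = 256;
constexpr size_t kSeqLen = 1024;
constexpr size_t kPrefillBatchSize = 16;  // extra ring-buffer slots for prefill

// Same definitions as in Activations: K and V per KV head, per layer.
constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
constexpr size_t kCachePosSize = kGemmaLayers * kCacheLayerSize;
constexpr size_t kGroupHeads = kHeads / kKVHeads;

// Start of the K vector for (pos, layer, query head); V follows at +kQKVDim.
size_t KVOffset(size_t pos, size_t layer, size_t head) {
  const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
  const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
  return cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
}

int main() {
  // Query heads 0..3 map to KV head 0; heads 4..7 map to KV head 1.
  std::printf("pos=5 layer=1 head=3 -> K at %zu\n", KVOffset(5, 1, 3));
  std::printf("pos=5 layer=1 head=4 -> K at %zu\n", KVOffset(5, 1, 4));
  return 0;
}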