From afaca4efa81c87e8b0b19c0f19fe1c3862bc1634 Mon Sep 17 00:00:00 2001
From: Zoltan Szabadka
Date: Tue, 30 Apr 2024 13:10:14 +0000
Subject: [PATCH] Use more parallelism in the QKV projections in MQA mode.

Instead of MatVecLoop, we use MatVec, and we store k and v in one
vector of length 2 * kQKVDim so that the K and V projections can be
computed with a single MatVec operation.

Benchmark results (summarization with 1600 tokens for prefill and
essay writing with 500 tokens for generation):

```
              Prefill speed          Generation speed
Num threads   BEFORE      AFTER      BEFORE      AFTER
     4        9.81 t/s    9.96 t/s   8.39 t/s    8.46 t/s
    18        31.50 t/s   36.67 t/s  23.10 t/s   25.83 t/s
    32        45.36 t/s   58.91 t/s  27.60 t/s   31.25 t/s
    64        57.72 t/s   80.64 t/s  35.40 t/s   39.76 t/s
```
---
 gemma/gemma.cc | 34 ++++++++++++++++++----------------
 gemma/gemma.h  |  4 +---
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 434576d..a494b3b 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -399,9 +399,9 @@ struct Activations {
   static constexpr size_t kQKVDim = TConfig::kQKVDim;
   static constexpr size_t kHeads = TConfig::kHeads;
   static constexpr size_t kKVHeads = TConfig::kKVHeads;
+  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
   static constexpr size_t kCachePosSize =
-      TConfig::kGemmaLayers * kKVHeads * kQKVDim;
-  static constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;
+      TConfig::kGemmaLayers * kCacheLayerSize;
 
   std::array<float, kBatchSize * kModelDim> x;  // input
   std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
@@ -714,8 +714,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   auto ProjKV = [&](size_t k_offset, size_t v_offset,
                     size_t kv_offset) HWY_ATTR {
-    float* HWY_RESTRICT k = kv_cache.key_cache.get() + kv_offset;
-    float* HWY_RESTRICT v = kv_cache.value_cache.get() + kv_offset;
+    float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+    float* HWY_RESTRICT v = k + kQKVDim;
 
     TwoOfsMatVecLoop<kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, k_offset,
                                          v_offset, x, k, v);
@@ -738,7 +738,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
+      const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + cache_offset;
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2] = score;
     }
@@ -751,7 +751,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     for (size_t pos2 = 0; pos2 < cache_num; ++pos2) {
       const size_t cache_offset =
           pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
-      float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
+      float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
       MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
     // linear projection from kQKVDim back to kModelDim, sum projections
@@ -795,16 +795,19 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     });
   } else {
     // Multi-Query Attention
-    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
-    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
-    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
-    const size_t kv_offset =
-        cache_pos * kCachePosSize + layer * kCacheLayerSize;
+    float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
+    MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x, q,
+                                        pool);
 
-    ProjKV(k_offset, v_offset, kv_offset);
+    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() +
+                             cache_pos * kCachePosSize +
+                             layer * kCacheLayerSize;
+    MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
+                                   kHeads * kQKVDim * kModelDim, x, kv, pool);
+
+    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
 
     pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-      ProjQ(head, head * kQKVDim * kModelDim);
       Attn(head, 0);
     });
   }
@@ -1465,9 +1468,8 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
                       size_t conv1d_cache_size, size_t rglru_cache_size) {
   KVCache kv_cache = {};
   if (size_cache_pos != 0) {
-    kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-    kv_cache.value_cache =
-        hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+    kv_cache.kv_cache =
+        hwy::AllocateAligned<float>(seq_len * size_cache_pos * 2);
   }
   if (conv1d_cache_size != 0) {
     kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
diff --git a/gemma/gemma.h b/gemma/gemma.h
index e1689d9..822f75b 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -44,9 +44,7 @@ constexpr bool kSystemPrompt = false;
 
 struct KVCache {
   hwy::AlignedFreeUniquePtr<float[]>
-      key_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
-  hwy::AlignedFreeUniquePtr<float[]>
-      value_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim
+      kv_cache;  // kSeqLen * kGemmaLayers * kKVHeads * kQKVDim * 2
   hwy::AlignedFreeUniquePtr<float[]>
       conv1d_cache;  // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
   hwy::AlignedFreeUniquePtr<float[]>
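
Note for readers of this patch: the speedup comes from giving the thread pool
more independent rows per MatVec call. In MQA mode the K rows and V rows of
qkv_einsum_w are adjacent, so one matrix-vector product over 2 * kQKVDim rows
writes k and v back to back, which is exactly the interleaved kv_cache layout
(k at cache_offset, v at cache_offset + kQKVDim). The standalone sketch below
is illustrative only: `Project`, the toy dimensions, and the BEFORE/AFTER
split are hypothetical stand-ins, not gemma.cpp's templated MatVec or real
model sizes.

```
// Minimal sketch of the fused K/V projection (not gemma.cpp's MatVec).
#include <cstddef>
#include <vector>

// Hypothetical toy dimensions, much smaller than any real model.
constexpr size_t kModelDim = 8;
constexpr size_t kQKVDim = 4;
constexpr size_t kHeads = 2;

// Toy row-major mat-vec: out[r] = dot(w[row_ofs + r], x). The per-row
// independence is what the real MatVec's thread pool exploits; fusing K
// and V doubles the number of rows available to parallelize.
void Project(const std::vector<float>& w, size_t row_ofs, size_t rows,
             const std::vector<float>& x, float* out) {
  for (size_t r = 0; r < rows; ++r) {
    float acc = 0.0f;
    for (size_t c = 0; c < kModelDim; ++c) {
      acc += w[(row_ofs + r) * kModelDim + c] * x[c];
    }
    out[r] = acc;
  }
}

int main() {
  // Weight layout assumed by the patch: kHeads * kQKVDim rows of Q
  // weights, then kQKVDim rows of K weights, then kQKVDim rows of V.
  std::vector<float> w((kHeads * kQKVDim + 2 * kQKVDim) * kModelDim, 0.01f);
  std::vector<float> x(kModelDim, 1.0f);  // pre-attention activations

  // BEFORE (conceptually): two separate projections of kQKVDim rows each.
  std::vector<float> k(kQKVDim), v(kQKVDim);
  Project(w, kHeads * kQKVDim, kQKVDim, x, k.data());
  Project(w, (kHeads + 1) * kQKVDim, kQKVDim, x, v.data());

  // AFTER: one projection over 2 * kQKVDim rows; k lands in
  // kv[0, kQKVDim) and v in kv[kQKVDim, 2 * kQKVDim), matching the
  // fused cache entry written by the patched code.
  std::vector<float> kv(2 * kQKVDim);
  Project(w, kHeads * kQKVDim, 2 * kQKVDim, x, kv.data());
  return 0;
}
```

The same row adjacency is why `v = k + kQKVDim` works in ProjKV and why the
value read in the attention loop becomes `cache_offset + kQKVDim`.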