From 44e6274e996d5ee365aa34d11479b6cbb8fa9f93 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 4 Apr 2024 01:12:25 -0700 Subject: [PATCH] 1.07x speedup: merge MQA parallel sections as suggested by @veluca93 PiperOrigin-RevId: 621772392 --- gemma.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/gemma.cc b/gemma.cc index edc5dfd..ae92713 100644 --- a/gemma.cc +++ b/gemma.cc @@ -405,18 +405,14 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, }); } else { // Multi-Query Attention - pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { - ProjQ(head, head * kQKVDim * kModelDim); - }); - constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; - ProjKV(k_offset, v_offset, kv_offset); pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + ProjQ(head, head * kQKVDim * kModelDim); Attn(head, 0); }); }