From 130e1f678fbc5cfb6c17a7f15e7c882ff21a4d46 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Tue, 19 Mar 2024 22:00:52 +0800 Subject: [PATCH 1/6] Adjust vocab size to be the same as gemma_pytorch --- configs.h | 4 ++-- util/convert_weights.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/configs.h b/configs.h index 7b420b5..58c053f 100644 --- a/configs.h +++ b/configs.h @@ -37,7 +37,7 @@ static constexpr size_t kTopK = GEMMA_TOPK; struct ConfigGemma7B { static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = 256128; + static constexpr int kVocabSize = 256000; static constexpr int kLayers = 28; static constexpr int kModelDim = 3072; static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 @@ -49,7 +49,7 @@ struct ConfigGemma7B { struct ConfigGemma2B { static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = 256128; + static constexpr int kVocabSize = 256000; static constexpr int kLayers = 18; static constexpr int kModelDim = 2048; static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 diff --git a/util/convert_weights.py b/util/convert_weights.py index bd6750a..6552d89 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -90,7 +90,7 @@ TRANSFORMATIONS = { "2b":defaultdict( lambda: lambda x: x, { - "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0), + "embedder.weight": lambda x: x, "self_attn.qkv_proj.weight": expand_qkv, "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], @@ -101,7 +101,7 @@ TRANSFORMATIONS = { "7b":defaultdict( lambda: lambda x: x, { - "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0), + "embedder.weight": lambda x: x, "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]), "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], @@ -113,7 +113,7 @@ TRANSFORMATIONS = { VALIDATIONS = { "2b": { - "embedder.weight": lambda x: x.shape == (256128, 2048), + "embedder.weight": lambda x: x.shape == (256000, 2048), "model.norm.weight": lambda x: x.shape == (2048,), "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), @@ -124,7 +124,7 @@ VALIDATIONS = { "post_attention_layernorm.weight": lambda x: x.shape == (2048,), }, "7b": { - "embedder.weight": lambda x: x.shape == (256128, 3072), + "embedder.weight": lambda x: x.shape == (256000, 3072), "model.norm.weight": lambda x: x.shape == (3072,), "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072), "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256), From 6923aec853f2d8df5038855e1a52ab56ee99d06f Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Wed, 20 Mar 2024 18:14:09 +0800 Subject: [PATCH 2/6] Add MQA support --- configs.h | 2 +- gemma.cc | 56 ++++++++++++++++++++++++++++++----------- util/convert_weights.py | 18 ++----------- 3 files changed, 44 insertions(+), 32 deletions(-) diff --git a/configs.h b/configs.h index 58c053f..e704664 100644 --- a/configs.h +++ b/configs.h @@ -54,7 +54,7 @@ struct ConfigGemma2B { static constexpr int kModelDim = 2048; static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kHeads = 8; - static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support + static constexpr int kKVHeads = 1; static constexpr int 
kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; }; diff --git a/gemma.cc b/gemma.cc index 7c9d187..1867fbf 100644 --- a/gemma.cc +++ b/gemma.cc @@ -70,12 +70,13 @@ template struct Layer { Layer() = default; static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kKVHeads = TConfig::kKVHeads; static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim; - // 3x for (query, key, value) - static constexpr size_t kQKVEinsumWSize = 3 * kHeads * kQKVDim * kModelDim; + static constexpr size_t kQKVEinsumWSize = + (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim; // 2x for (gelu gating vector, gated vector) static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; @@ -313,28 +314,46 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, static constexpr size_t kModelDim = gcpp::Activations::kModelDim; static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kKVHeads = TConfig::kKVHeads; static const float kQueryScale = static_cast(1.0 / sqrt(static_cast(kQKVDim))); + const size_t batch_offset = batch_idx * kModelDim; + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // linear projections to QKV - const size_t head_offset = - 3 * kQKVDim * kModelDim; // 3x for QKV dimensions + constexpr const size_t head_offset = + kHeads == kKVHeads ? 3 * kQKVDim * kModelDim : kQKVDim * kModelDim; const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim; - const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim; - const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim; float* HWY_RESTRICT q = activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; - const size_t batch_offset = batch_idx * kModelDim; - MatVecLoop( c_layer->c_qkv_einsum_w, q_offset, activations.pre_att_rms_out.data() + batch_offset, q); - const size_t kv_offset = - pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + if constexpr (kHeads == kKVHeads) { + const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim; + const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim; + const size_t kv_offset = + pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + + TwoOfsMatVecLoop( + c_layer->c_qkv_einsum_w, k_offset, v_offset, + activations.pre_att_rms_out.data() + batch_offset, + kv_cache.key_cache.get() + kv_offset, + kv_cache.value_cache.get() + kv_offset); + + Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + } + }); + + if constexpr (kHeads != kKVHeads) { + constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; + constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; + constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; + const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; TwoOfsMatVecLoop( c_layer->c_qkv_einsum_w, k_offset, v_offset, @@ -342,18 +361,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, kv_cache.key_cache.get() + kv_offset, kv_cache.value_cache.get() + kv_offset); + Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + } + + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // Calculate scores + float* HWY_RESTRICT q = + activations.q.data() + head * kQKVDim + batch_idx * 
kHeads * kQKVDim; float* HWY_RESTRICT head_att = activations.att.data() + head * TConfig::kSeqLen + batch_idx * kHeads * kQKVDim; Rope(q, kQKVDim, pos); - Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); // Compute Q dot K scores for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t cache_offset = - pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + const size_t cache_offset = kHeads == kKVHeads + ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim + : pos2 * kCachePosSize + layer * kCacheLayerSize; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; @@ -365,8 +390,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, batch_idx * kHeads * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t cache_offset = - pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + const size_t cache_offset = kHeads == kKVHeads + ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim + : pos2 * kCachePosSize + layer * kCacheLayerSize; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); } diff --git a/util/convert_weights.py b/util/convert_weights.py index 6552d89..0211c01 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -72,26 +72,12 @@ parser.add_argument( args = parser.parse_args() -def expand_qkv(qkv_proj: np.array) -> np.array: - """This won't be needed anymore when MQA is implemented""" - assert qkv_proj.shape == (2560, 2048) - qkv = qkv_proj.reshape((10, 256, 2048)) - - q_proj = qkv[:8].reshape((1,8,256,2048)) - kv_proj = qkv[8:] - kv_proj = kv_proj[:, np.newaxis, :, :] - kv_proj = np.repeat(kv_proj, 8, axis=1) - - qkv = np.concatenate([q_proj, kv_proj]) - qkv = np.transpose(qkv, axes=[1,0,2,3]) - return qkv - TRANSFORMATIONS = { "2b":defaultdict( lambda: lambda x: x, { "embedder.weight": lambda x: x, - "self_attn.qkv_proj.weight": expand_qkv, + "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)), "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], @@ -115,7 +101,7 @@ VALIDATIONS = { "2b": { "embedder.weight": lambda x: x.shape == (256000, 2048), "model.norm.weight": lambda x: x.shape == (2048,), - "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), + "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048), "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), From ce32f4db81f9ac91ac18fff42516e3c1a3f12b24 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Wed, 20 Mar 2024 22:39:31 +0800 Subject: [PATCH 3/6] Streamline the implementation --- gemma.cc | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/gemma.cc b/gemma.cc index 1867fbf..76086d9 100644 --- a/gemma.cc +++ b/gemma.cc @@ -320,6 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, const size_t batch_offset = batch_idx * kModelDim; + auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) { + TwoOfsMatVecLoop( + c_layer->c_qkv_einsum_w, k_offset, v_offset, + 
activations.pre_att_rms_out.data() + batch_offset, + kv_cache.key_cache.get() + kv_offset, + kv_cache.value_cache.get() + kv_offset); + + Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + }; + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // linear projections to QKV constexpr const size_t head_offset = @@ -339,13 +349,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; - TwoOfsMatVecLoop( - c_layer->c_qkv_einsum_w, k_offset, v_offset, - activations.pre_att_rms_out.data() + batch_offset, - kv_cache.key_cache.get() + kv_offset, - kv_cache.value_cache.get() + kv_offset); - - Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + ProjKV(k_offset, v_offset, kv_offset); } }); @@ -355,13 +359,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; - TwoOfsMatVecLoop( - c_layer->c_qkv_einsum_w, k_offset, v_offset, - activations.pre_att_rms_out.data() + batch_offset, - kv_cache.key_cache.get() + kv_offset, - kv_cache.value_cache.get() + kv_offset); - - Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + ProjKV(k_offset, v_offset, kv_offset); } pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { @@ -376,9 +374,10 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, MulByConst(kQueryScale, q, kQKVDim); // Compute Q dot K scores for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t cache_offset = kHeads == kKVHeads - ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + const size_t cache_offset = + kHeads == kKVHeads + ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim + : pos2 * kCachePosSize + layer * kCacheLayerSize; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; @@ -390,9 +389,10 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, batch_idx * kHeads * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t cache_offset = kHeads == kKVHeads - ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + const size_t cache_offset = + kHeads == kKVHeads + ? 
pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim + : pos2 * kCachePosSize + layer * kCacheLayerSize; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); } From c75d2eb63549fe844c61ee6a80f968f3af34f995 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Wed, 20 Mar 2024 23:21:43 +0800 Subject: [PATCH 4/6] Add the missing `HWY_ATTR` of `ProjKV` --- gemma.cc | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index 76086d9..877a3dc 100644 --- a/gemma.cc +++ b/gemma.cc @@ -320,15 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, const size_t batch_offset = batch_idx * kModelDim; - auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) { - TwoOfsMatVecLoop( - c_layer->c_qkv_einsum_w, k_offset, v_offset, - activations.pre_att_rms_out.data() + batch_offset, - kv_cache.key_cache.get() + kv_offset, - kv_cache.value_cache.get() + kv_offset); + auto ProjKV = + [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR { + TwoOfsMatVecLoop( + c_layer->c_qkv_einsum_w, k_offset, v_offset, + activations.pre_att_rms_out.data() + batch_offset, + kv_cache.key_cache.get() + kv_offset, + kv_cache.value_cache.get() + kv_offset); - Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); - }; + Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + }; pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // linear projections to QKV From 8fc6959950df9a2e9b5fa0d95096d1b6d511e7b5 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Wed, 20 Mar 2024 23:50:14 +0800 Subject: [PATCH 5/6] Move conditional branch out of `pos2` loop --- gemma.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/gemma.cc b/gemma.cc index 877a3dc..9baccd7 100644 --- a/gemma.cc +++ b/gemma.cc @@ -373,12 +373,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, Rope(q, kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); + + const size_t head_offset = kHeads == kKVHeads ? head * kQKVDim : 0; + // Compute Q dot K scores for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t cache_offset = - kHeads == kKVHeads - ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; @@ -391,9 +392,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t cache_offset = - kHeads == kKVHeads - ? 
pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); } From 90b0e9fd7ac3dbe73d63abd3ed7eeb14ad6d013b Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Thu, 21 Mar 2024 14:40:56 +0800 Subject: [PATCH 6/6] Refactor the implementation of `Attention` --- gemma.cc | 81 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 36 deletions(-) diff --git a/gemma.cc b/gemma.cc index 9baccd7..533854c 100644 --- a/gemma.cc +++ b/gemma.cc @@ -320,6 +320,15 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, const size_t batch_offset = batch_idx * kModelDim; + auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR { + float* HWY_RESTRICT q = + activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; + + MatVecLoop( + c_layer->c_qkv_einsum_w, head_offset + 0 * kQKVDim * kModelDim, + activations.pre_att_rms_out.data() + batch_offset, q); + }; + auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR { TwoOfsMatVecLoop( @@ -331,39 +340,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); }; - pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { - // linear projections to QKV - constexpr const size_t head_offset = - kHeads == kKVHeads ? 3 * kQKVDim * kModelDim : kQKVDim * kModelDim; - const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim; - - float* HWY_RESTRICT q = - activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; - - MatVecLoop( - c_layer->c_qkv_einsum_w, q_offset, - activations.pre_att_rms_out.data() + batch_offset, q); - - if constexpr (kHeads == kKVHeads) { - const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim; - const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim; - const size_t kv_offset = - pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; - - ProjKV(k_offset, v_offset, kv_offset); - } - }); - - if constexpr (kHeads != kKVHeads) { - constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; - constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; - constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; - const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; - - ProjKV(k_offset, v_offset, kv_offset); - } - - pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR { // Calculate scores float* HWY_RESTRICT q = activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; @@ -374,8 +351,6 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, Rope(q, kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); - const size_t head_offset = kHeads == kKVHeads ? 
head * kQKVDim : 0; - // Compute Q dot K scores for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t cache_offset = @@ -405,7 +380,41 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, MatVecLoop(c_layer->c_attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out, head_out); - }); + }; + + if constexpr (kHeads == kKVHeads) { + // Multi-Head Attention + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + const size_t head_offset = head * 3 * kQKVDim * kModelDim; + + ProjQ(head, head_offset); + + const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim; + const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim; + const size_t kv_offset = + pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + + ProjKV(k_offset, v_offset, kv_offset); + + Attn(head, head * kQKVDim); + }); + } else { + // Multi-Query Attention + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + ProjQ(head, head * kQKVDim * kModelDim); + }); + + constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; + constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; + constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; + const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; + + ProjKV(k_offset, v_offset, kv_offset); + + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + Attn(head, 0); + }); + } // accumulate output across all heads into att_post2. head 0 already wrote // directly to att_post2.