diff --git a/gemma.cc b/gemma.cc
index 9baccd7..533854c 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -320,6 +320,15 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   const size_t batch_offset = batch_idx * kModelDim;
 
+  auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
+    float* HWY_RESTRICT q =
+        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
+
+    MatVecLoop<kQKVDim, kModelDim>(
+        c_layer->c_qkv_einsum_w, head_offset + 0 * kQKVDim * kModelDim,
+        activations.pre_att_rms_out.data() + batch_offset, q);
+  };
+
   auto ProjKV = [&](size_t k_offset, size_t v_offset,
                     size_t kv_offset) HWY_ATTR {
     TwoOfsMatVecLoop<kQKVDim, kModelDim>(
@@ -331,39 +340,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
   };
 
-  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-    // linear projections to QKV
-    constexpr const size_t head_offset =
-        kHeads == kKVHeads ? 3 * kQKVDim * kModelDim : kQKVDim * kModelDim;
-    const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim;
-
-    float* HWY_RESTRICT q =
-        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-
-    MatVecLoop<kQKVDim, kModelDim>(
-        c_layer->c_qkv_einsum_w, q_offset,
-        activations.pre_att_rms_out.data() + batch_offset, q);
-
-    if constexpr (kHeads == kKVHeads) {
-      const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim;
-      const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim;
-      const size_t kv_offset =
-          pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
-
-      ProjKV(k_offset, v_offset, kv_offset);
-    }
-  });
-
-  if constexpr (kHeads != kKVHeads) {
-    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
-    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
-    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
-    const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
-
-    ProjKV(k_offset, v_offset, kv_offset);
-  }
-
-  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+  auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
     // Calculate scores
     float* HWY_RESTRICT q =
         activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
@@ -374,8 +351,6 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     Rope(q, kQKVDim, pos);
     MulByConst(kQueryScale, q, kQKVDim);
 
-    const size_t head_offset = kHeads == kKVHeads ? head * kQKVDim : 0;
-
     // Compute Q dot K scores
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
       const size_t cache_offset =
@@ -405,7 +380,41 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     MatVecLoop<kModelDim, kQKVDim>(c_layer->c_attn_vec_einsum_w,
                                    head * kModelDim * kQKVDim, att_out,
                                    head_out);
-  });
+  };
+
+  if constexpr (kHeads == kKVHeads) {
+    // Multi-Head Attention
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+      const size_t head_offset = head * 3 * kQKVDim * kModelDim;
+
+      ProjQ(head, head_offset);
+
+      const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim;
+      const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim;
+      const size_t kv_offset =
+          pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
+
+      ProjKV(k_offset, v_offset, kv_offset);
+
+      Attn(head, head * kQKVDim);
+    });
+  } else {
+    // Multi-Query Attention
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+      ProjQ(head, head * kQKVDim * kModelDim);
+    });
+
+    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
+    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
+    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
+    const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
+
+    ProjKV(k_offset, v_offset, kv_offset);
+
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+      Attn(head, 0);
+    });
+  }
 
   // accumulate output across all heads into att_post2. head 0 already wrote
   // directly to att_post2.
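
Note (illustration, not part of the patch): the offset arithmetic above encodes
the layout of c_qkv_einsum_w. In the MHA case (kHeads == kKVHeads) each head
owns a contiguous [Q; K; V] block, so the per-head stride is
3 * kQKVDim * kModelDim and the KV cache is indexed per head. In the MQA case
the kHeads query blocks come first, followed by a single K and V shared by all
heads, which is why k_offset and v_offset are compile-time constants and Attn
receives head_offset 0. The standalone sketch below prints both layouts;
kModelDim, kQKVDim, and kHeads are made-up illustrative values here, not the
real model constants from TConfig.

// layout_sketch.cc: hypothetical standalone illustration of the two weight
// layouts assumed by the patch; the constants are invented for the demo.
#include <cstddef>
#include <cstdio>

constexpr size_t kModelDim = 8;  // made-up model width
constexpr size_t kQKVDim = 4;    // made-up per-head dimension
constexpr size_t kHeads = 3;     // made-up number of query heads

// MHA layout: [Q0 K0 V0 | Q1 K1 V1 | ...]. Each head owns a contiguous
// Q/K/V block, so the per-head stride is 3 * kQKVDim * kModelDim.
size_t MhaOffset(size_t head, size_t qkv) {  // qkv: 0 = Q, 1 = K, 2 = V
  return head * 3 * kQKVDim * kModelDim + qkv * kQKVDim * kModelDim;
}

// MQA layout: [Q0 Q1 ... Q(kHeads-1) | K | V]. All query blocks first,
// then one K and one V shared by every head, so the K/V offsets are
// compile-time constants.
size_t MqaQOffset(size_t head) { return head * kQKVDim * kModelDim; }
constexpr size_t kMqaKOffset = kHeads * kQKVDim * kModelDim;
constexpr size_t kMqaVOffset = kMqaKOffset + kQKVDim * kModelDim;

int main() {
  for (size_t head = 0; head < kHeads; ++head) {
    std::printf("MHA head %zu: q=%zu k=%zu v=%zu\n", head, MhaOffset(head, 0),
                MhaOffset(head, 1), MhaOffset(head, 2));
  }
  for (size_t head = 0; head < kHeads; ++head) {
    std::printf("MQA head %zu: q=%zu (shared k=%zu v=%zu)\n", head,
                MqaQOffset(head), kMqaKOffset, kMqaVOffset);
  }
  return 0;
}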