diff --git a/gemma.cc b/gemma.cc index edc5dfd..ae92713 100644 --- a/gemma.cc +++ b/gemma.cc @@ -405,18 +405,14 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, }); } else { // Multi-Query Attention - pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { - ProjQ(head, head * kQKVDim * kModelDim); - }); - constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; - ProjKV(k_offset, v_offset, kv_offset); pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + ProjQ(head, head * kQKVDim * kModelDim); Attn(head, 0); }); }