diff --git a/gemma/gemma.cc b/gemma/gemma.cc index a494b3b..05d1aaf 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -787,11 +787,12 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, ProjQ(head, q_offset); const size_t kv_offset = - cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + cache_pos * kCachePosSize + layer * kCacheLayerSize + + head * kQKVDim * 2; ProjKV(k_offset, v_offset, kv_offset); - Attn(head, head * kQKVDim); + Attn(head, head * kQKVDim * 2); }); } else { // Multi-Query Attention