Fix kv offset computation for MHA config.

This commit is contained in:
Zoltan Szabadka 2024-04-30 16:19:14 +00:00
parent 374fd7478a
commit f8ccb8e37c
1 changed file with 3 additions and 2 deletions

View File

@@ -787,11 +787,12 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 ProjQ(head, q_offset);
 const size_t kv_offset =
-cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
+cache_pos * kCachePosSize + layer * kCacheLayerSize +
+head * kQKVDim * 2;
 ProjKV(k_offset, v_offset, kv_offset);
-Attn(head, head * kQKVDim);
+Attn(head, head * kQKVDim * 2);
 });
 } else {
 // Multi-Query Attention