mirror of https://github.com/google/gemma.cpp.git
Fix kv offset computation for MHA config.
This commit is contained in:
parent
374fd7478a
commit
f8ccb8e37c
|
|
@@ -787,11 +787,12 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
ProjQ(head, q_offset);
|
ProjQ(head, q_offset);
|
||||||
|
|
||||||
const size_t kv_offset =
|
const size_t kv_offset =
|
||||||
cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
|
cache_pos * kCachePosSize + layer * kCacheLayerSize +
|
||||||
|
head * kQKVDim * 2;
|
||||||
|
|
||||||
ProjKV(k_offset, v_offset, kv_offset);
|
ProjKV(k_offset, v_offset, kv_offset);
|
||||||
|
|
||||||
Attn(head, head * kQKVDim);
|
Attn(head, head * kQKVDim * 2);
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// Multi-Query Attention
|
// Multi-Query Attention
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue