mirror of https://github.com/google/gemma.cpp.git
1.07x speedup: merge MQA parallel sections as suggested by @veluca93
PiperOrigin-RevId: 621772392
This commit is contained in:
parent
ede337f876
commit
44e6274e99
6
gemma.cc
6
gemma.cc
|
|
@ -405,18 +405,14 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// Multi-Query Attention
|
// Multi-Query Attention
|
||||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
|
||||||
ProjQ(head, head * kQKVDim * kModelDim);
|
|
||||||
});
|
|
||||||
|
|
||||||
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
|
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
|
||||||
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
|
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
|
||||||
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
|
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
|
||||||
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
|
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
|
||||||
|
|
||||||
ProjKV(k_offset, v_offset, kv_offset);
|
ProjKV(k_offset, v_offset, kv_offset);
|
||||||
|
|
||||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||||
|
ProjQ(head, head * kQKVDim * kModelDim);
|
||||||
Attn(head, 0);
|
Attn(head, 0);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue