1.07x speedup: merge MQA parallel sections as suggested by @veluca93

PiperOrigin-RevId: 621772392
This commit is contained in:
Jan Wassenberg 2024-04-04 01:12:25 -07:00 committed by Copybara-Service
parent ede337f876
commit 44e6274e99
1 changed file with 1 addition and 5 deletions

View File

@@ -405,18 +405,14 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
});
} else {
// Multi-Query Attention
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
ProjQ(head, head * kQKVDim * kModelDim);
});
constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
ProjKV(k_offset, v_offset, kv_offset);
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
ProjQ(head, head * kQKVDim * kModelDim);
Attn(head, 0);
});
}