Refactor the implementation of `Attention`

RangerUFO 2024-03-21 14:40:56 +08:00
parent 8fc6959950
commit 90b0e9fd7a
1 changed file with 45 additions and 36 deletions
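The diff below reorganizes Attention around three lambdas (ProjQ, ProjKV, Attn) and dispatches them differently for multi-head versus multi-query attention. The offsets it uses imply a layout of c_qkv_einsum_w in which, for the multi-head case, every head owns an interleaved [Q, K, V] weight block, while for the multi-query case the Q blocks of all heads come first, followed by one shared K block and one shared V block. A minimal, self-contained sketch of that offset arithmetic, for illustration only (the struct and function names here are not part of the commit):

#include <cstddef>

// Illustrative only: mirrors the element offsets used in the diff below.
struct QkvOffsets {
  std::size_t q, k, v;
};

// Multi-Head Attention (kHeads == kKVHeads): each head owns an interleaved
// [Q, K, V] block of 3 * qkv_dim * model_dim weights.
QkvOffsets MhaOffsets(std::size_t head, std::size_t qkv_dim,
                      std::size_t model_dim) {
  const std::size_t head_offset = head * 3 * qkv_dim * model_dim;
  return {head_offset + 0 * qkv_dim * model_dim,
          head_offset + 1 * qkv_dim * model_dim,
          head_offset + 2 * qkv_dim * model_dim};
}

// Multi-Query Attention (kHeads != kKVHeads): all per-head Q blocks come
// first, then one K block and one V block shared by every head.
QkvOffsets MqaOffsets(std::size_t head, std::size_t num_heads,
                      std::size_t qkv_dim, std::size_t model_dim) {
  const std::size_t shared = num_heads * qkv_dim * model_dim;
  return {head * qkv_dim * model_dim,
          shared + 0 * qkv_dim * model_dim,
          shared + 1 * qkv_dim * model_dim};
}

Read this way, the multi-head path can derive k_offset and v_offset from a single per-head head_offset, which is what lets the commit fold the projections into ProjQ/ProjKV and run projection and attention for each head inside one pool.Run loop.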


@@ -320,6 +320,15 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
  const size_t batch_offset = batch_idx * kModelDim;
  auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
    float* HWY_RESTRICT q =
        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
    MatVecLoop<kQKVDim, kModelDim>(
        c_layer->c_qkv_einsum_w, head_offset + 0 * kQKVDim * kModelDim,
        activations.pre_att_rms_out.data() + batch_offset, q);
  };
  auto ProjKV =
      [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR {
        TwoOfsMatVecLoop<kQKVDim, kModelDim>(
@@ -331,39 +340,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
        Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
      };
  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
    // linear projections to QKV
    constexpr const size_t head_offset =
        kHeads == kKVHeads ? 3 * kQKVDim * kModelDim : kQKVDim * kModelDim;
    const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim;
    float* HWY_RESTRICT q =
        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
    MatVecLoop<kQKVDim, kModelDim>(
        c_layer->c_qkv_einsum_w, q_offset,
        activations.pre_att_rms_out.data() + batch_offset, q);
    if constexpr (kHeads == kKVHeads) {
      const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim;
      const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim;
      const size_t kv_offset =
          pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
      ProjKV(k_offset, v_offset, kv_offset);
    }
  });
  if constexpr (kHeads != kKVHeads) {
    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
    const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
    ProjKV(k_offset, v_offset, kv_offset);
  }
  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
  auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
    // Calculate scores
    float* HWY_RESTRICT q =
        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
@@ -374,8 +351,6 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
    Rope(q, kQKVDim, pos);
    MulByConst(kQueryScale, q, kQKVDim);
    const size_t head_offset = kHeads == kKVHeads ? head * kQKVDim : 0;
    // Compute Q dot K scores
    for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
      const size_t cache_offset =
@@ -405,7 +380,41 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
    MatVecLoop<kModelDim, kQKVDim>(c_layer->c_attn_vec_einsum_w,
                                   head * kModelDim * kQKVDim, att_out,
                                   head_out);
  });
  };
  if constexpr (kHeads == kKVHeads) {
    // Multi-Head Attention
    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
      const size_t head_offset = head * 3 * kQKVDim * kModelDim;
      ProjQ(head, head_offset);
      const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim;
      const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim;
      const size_t kv_offset =
          pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
      ProjKV(k_offset, v_offset, kv_offset);
      Attn(head, head * kQKVDim);
    });
  } else {
    // Multi-Query Attention
    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
      ProjQ(head, head * kQKVDim * kModelDim);
    });
    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
    const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
    ProjKV(k_offset, v_offset, kv_offset);
    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
      Attn(head, 0);
    });
  }
  // accumulate output across all heads into att_post2. head 0 already wrote
  // directly to att_post2.
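The comment above describes the reduction that follows the per-head work: head 0's MatVecLoop writes its kModelDim-sized projection directly into att_post2, and the projections of the remaining heads are summed on top of it. A rough, standalone sketch of that accumulation pattern, with hypothetical parameter names (only att_post2's role comes from the comment; everything else here is assumed for illustration):

#include <cstddef>

// Hypothetical illustration: `out` already holds head 0's projection;
// `per_head` holds the projections of heads 1..num_heads-1, one block of
// `model_dim` floats per head, which are accumulated into `out`.
void AccumulateHeads(float* out, const float* per_head,
                     std::size_t num_heads, std::size_t model_dim) {
  for (std::size_t head = 1; head < num_heads; ++head) {
    const float* src = per_head + (head - 1) * model_dim;
    for (std::size_t i = 0; i < model_dim; ++i) {
      out[i] += src[i];
    }
  }
}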