mirror of https://github.com/google/gemma.cpp.git
Refactor the implementation of `Attention`
parent 8fc6959950
commit 90b0e9fd7a
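This factors the body of `Attention` into three lambdas: `ProjQ` projects the per-head query, `ProjKV` projects key and value straight into the KV cache, and `Attn` computes attention for one head. The head loop then dispatches once, at compile time, between Multi-Head Attention (`kHeads == kKVHeads`, per-head K/V) and Multi-Query Attention (a single K/V shared by all heads), instead of mixing both cases inside one `pool.Run` body.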
gemma.cc | 81
@@ -320,6 +320,15 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   const size_t batch_offset = batch_idx * kModelDim;
 
+  auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR {
+    float* HWY_RESTRICT q =
+        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
+
+    MatVecLoop<kQKVDim, kModelDim>(
+        c_layer->c_qkv_einsum_w, head_offset + 0 * kQKVDim * kModelDim,
+        activations.pre_att_rms_out.data() + batch_offset, q);
+  };
+
   auto ProjKV =
       [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR {
         TwoOfsMatVecLoop<kQKVDim, kModelDim>(
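For readers outside the codebase: `MatVecLoop<kOuter, kInner>(w, w_ofs, x, out)` computes a matrix-vector product over a row-major block starting at offset `w_ofs` inside a packed weight tensor, and `TwoOfsMatVecLoop` (used by `ProjKV`) does two such products over the same input. The following scalar sketch of those assumed semantics is illustrative only; the real routines are Highway-vectorized and read compressed weights, and the `*Sketch` names are made up here.

#include <cstddef>

// Scalar sketch of the assumed semantics of MatVecLoop<kOuter, kInner>:
// out[i] = dot(row i of the kOuter x kInner block starting at w_ofs, x).
template <size_t kOuter, size_t kInner>
void MatVecLoopSketch(const float* w, size_t w_ofs, const float* x,
                      float* out) {
  for (size_t i = 0; i < kOuter; ++i) {
    float sum = 0.0f;
    for (size_t j = 0; j < kInner; ++j) {
      sum += w[w_ofs + i * kInner + j] * x[j];
    }
    out[i] = sum;
  }
}

// TwoOfsMatVecLoop runs two such products over the same input vector
// (here K and V). The real routine can amortize reading x across both
// blocks; this sketch simply calls the single-block version twice.
template <size_t kOuter, size_t kInner>
void TwoOfsMatVecLoopSketch(const float* w, size_t ofs0, size_t ofs1,
                            const float* x, float* out0, float* out1) {
  MatVecLoopSketch<kOuter, kInner>(w, ofs0, x, out0);
  MatVecLoopSketch<kOuter, kInner>(w, ofs1, x, out1);
}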
@@ -331,39 +340,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
         Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
       };
 
-  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
-    // linear projections to QKV
-    constexpr const size_t head_offset =
-        kHeads == kKVHeads ? 3 * kQKVDim * kModelDim : kQKVDim * kModelDim;
-    const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim;
-
-    float* HWY_RESTRICT q =
-        activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
-
-    MatVecLoop<kQKVDim, kModelDim>(
-        c_layer->c_qkv_einsum_w, q_offset,
-        activations.pre_att_rms_out.data() + batch_offset, q);
-
-    if constexpr (kHeads == kKVHeads) {
-      const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim;
-      const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim;
-      const size_t kv_offset =
-          pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
-
-      ProjKV(k_offset, v_offset, kv_offset);
-    }
-  });
-
-  if constexpr (kHeads != kKVHeads) {
-    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
-    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
-    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
-    const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
-
-    ProjKV(k_offset, v_offset, kv_offset);
-  }
-
-  pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+  auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR {
     // Calculate scores
     float* HWY_RESTRICT q =
         activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim;
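The body of `Attn` (only partially shown in this hunk) follows the standard causal-attention pattern: rotate and scale the query, dot it against every cached key up to `pos`, softmax the scores, and accumulate the weighted values. A self-contained scalar sketch of that pattern, as a stand-in for the elided parts of the lambda rather than a copy of the gemma.cpp code:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Generic causal single-head attention over a cache of pos+1 key/value
// entries of qkv_dim floats each. AttnSketch is a hypothetical name.
void AttnSketch(const float* q, const float* k_cache, const float* v_cache,
                size_t pos, size_t qkv_dim, float* out) {
  // Q dot K scores for every cached position up to and including pos.
  std::vector<float> scores(pos + 1);
  for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
    float dot = 0.0f;
    for (size_t d = 0; d < qkv_dim; ++d) {
      dot += q[d] * k_cache[pos2 * qkv_dim + d];
    }
    scores[pos2] = dot;
  }
  // Softmax over the scores (subtract the max for numerical stability).
  float max_score = scores[0];
  for (float s : scores) max_score = std::max(max_score, s);
  float denom = 0.0f;
  for (float& s : scores) {
    s = std::exp(s - max_score);
    denom += s;
  }
  // Weighted sum of the cached values.
  for (size_t d = 0; d < qkv_dim; ++d) out[d] = 0.0f;
  for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
    const float w = scores[pos2] / denom;
    for (size_t d = 0; d < qkv_dim; ++d) {
      out[d] += w * v_cache[pos2 * qkv_dim + d];
    }
  }
}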
@@ -374,8 +351,6 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     Rope(q, kQKVDim, pos);
     MulByConst(kQueryScale, q, kQKVDim);
 
-    const size_t head_offset = kHeads == kKVHeads ? head * kQKVDim : 0;
-
     // Compute Q dot K scores
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
       const size_t cache_offset =
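Note the `head_offset` selection removed here (`head * kQKVDim` for Multi-Head, `0` for Multi-Query) is not gone: it becomes the second argument of `Attn`, passed explicitly at each call site in the final hunk. As background, `Rope(q, kQKVDim, pos)` applies rotary position embeddings before the scores are computed. One common RoPE formulation, offered as an assumption about what `Rope` computes rather than code taken from gemma.cpp:

#include <cmath>
#include <cstddef>

// Pair dimension i with i + dim/2 and rotate each pair by an angle
// pos * 10000^(-2i/dim). RopeSketch is a hypothetical name; the exact
// pairing and base used by gemma.cpp may differ.
void RopeSketch(float* x, size_t dim, size_t pos) {
  const size_t half = dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float inv_freq =
        std::pow(10000.0f, -2.0f * static_cast<float>(i) / dim);
    const float theta = static_cast<float>(pos) * inv_freq;
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    const float x0 = x[i];
    const float x1 = x[i + half];
    x[i] = x0 * c - x1 * s;
    x[i + half] = x0 * s + x1 * c;
  }
}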
@@ -405,7 +380,41 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     MatVecLoop<kModelDim, kQKVDim>(c_layer->c_attn_vec_einsum_w,
                                    head * kModelDim * kQKVDim, att_out,
                                    head_out);
-  });
+  };
+
+  if constexpr (kHeads == kKVHeads) {
+    // Multi-Head Attention
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+      const size_t head_offset = head * 3 * kQKVDim * kModelDim;
+
+      ProjQ(head, head_offset);
+
+      const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim;
+      const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim;
+      const size_t kv_offset =
+          pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
+
+      ProjKV(k_offset, v_offset, kv_offset);
+
+      Attn(head, head * kQKVDim);
+    });
+  } else {
+    // Multi-Query Attention
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+      ProjQ(head, head * kQKVDim * kModelDim);
+    });
+
+    constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
+    constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
+    constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
+    const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
+
+    ProjKV(k_offset, v_offset, kv_offset);
+
+    pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
+      Attn(head, 0);
+    });
+  }
 
   // accumulate output across all heads into att_post2. head 0 already wrote
   // directly to att_post2.
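The two branches encode different packings of `c_qkv_einsum_w`. In the Multi-Head case each head owns a contiguous [Q K V] triple of `kQKVDim * kModelDim` blocks; in the Multi-Query case all Q blocks come first, followed by one K and one V block shared by every head. A toy program that prints the offsets the diff computes (the dimension values are made-up assumptions, not the real model config):

#include <cstddef>
#include <cstdio>

constexpr size_t kModelDim = 8;  // assumption: toy value
constexpr size_t kQKVDim = 4;    // assumption: toy value
constexpr size_t kHeads = 2;     // assumption: toy value

int main() {
  // Multi-Head packing, matching the kHeads == kKVHeads branch.
  for (size_t head = 0; head < kHeads; ++head) {
    const size_t head_offset = head * 3 * kQKVDim * kModelDim;
    std::printf("MHA head %zu: q=%zu k=%zu v=%zu\n", head,
                head_offset + 0 * kQKVDim * kModelDim,
                head_offset + 1 * kQKVDim * kModelDim,
                head_offset + 2 * kQKVDim * kModelDim);
  }
  // Multi-Query packing, matching the else branch: per-head Q blocks,
  // then a single shared K and V.
  constexpr size_t q_offset = kHeads * kQKVDim * kModelDim;
  for (size_t head = 0; head < kHeads; ++head) {
    std::printf("MQA head %zu: q=%zu\n", head, head * kQKVDim * kModelDim);
  }
  std::printf("MQA shared: k=%zu v=%zu\n",
              q_offset + 0 * kQKVDim * kModelDim,
              q_offset + 1 * kQKVDim * kModelDim);
  return 0;
}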