mirror of https://github.com/google/gemma.cpp.git
Generic MHA/MQA/GQA implementation
PiperOrigin-RevId: 636937885
parent 93c0088646
commit 419dc34ed5
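The hunk below merges the separate MHA and MQA code paths into one generic path that also covers GQA, by grouping query heads onto key-value heads. As a rough standalone sketch of that grouping (not part of the commit; kHeads/kKVHeads follow the naming in the diff, and the example head counts are invented):

#include <cstddef>
#include <cstdio>

// Sketch: which KV head a given query head reads from.
// MHA: kHeads == kKVHeads (1:1 mapping).
// MQA: kKVHeads == 1 (every query head shares the single KV head).
// GQA: kHeads is an integer multiple of kKVHeads (heads share in groups),
// which is what the static_assert in the new code enforces.
constexpr size_t KVHeadForQueryHead(size_t head, size_t kHeads,
                                    size_t kKVHeads) {
  const size_t kGroupHeads = kHeads / kKVHeads;  // group size per KV head
  return head / kGroupHeads;  // same quantity the diff's head_offset is built from
}

int main() {
  // Invented GQA configuration: 8 query heads sharing 2 KV heads.
  constexpr size_t kHeads = 8, kKVHeads = 2;
  for (size_t head = 0; head < kHeads; ++head) {
    std::printf("query head %zu -> kv head %zu\n", head,
                KVHeadForQueryHead(head, kHeads, kKVHeads));
  }
  return 0;
}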
@@ -762,69 +762,67 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
     }
   };

-  if constexpr (kHeads == kKVHeads) {
-    // Multi-Head Attention
-    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-      const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-      // QKV projections:
-      static_assert(TConfig::kInterleaveQKV);
-      float* HWY_RESTRICT qkv =
-          activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
-      MatVec<kHeads * kQKVDim * 3, kModelDim>(
-          layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
-          pool);
-    }
-    const size_t num_tasks = kHeads * num_tokens;
-    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
-      const size_t head = task % kHeads;
-      const size_t batch_idx = task / kHeads;
-      const size_t pos = batch_start + batch_idx;
-      float* HWY_RESTRICT q =
-          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
-      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
-      const size_t kv_offset = cache_pos * kCachePosSize +
-                               layer * kCacheLayerSize + head * kQKVDim * 2;
-      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
-      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-    });
-    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
-      const size_t head = task % kHeads;
-      const size_t batch_idx = task / kHeads;
-      float* HWY_RESTRICT q =
-          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
-      Attn(q, head, head * kQKVDim * 2, batch_idx, thread);
-    });
-  } else {
-    // Multi-Query Attention
-    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-      const size_t pos = batch_start + batch_idx;
-      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
-      float* HWY_RESTRICT q =
-          activations.q.data() + batch_idx * kHeads * kQKVDim;
-      MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
-                                          activations.even_odd.data(), q, pool);
-      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
-      const size_t kv_offset = cache_pos * kCachePosSize +
-                               layer * kCacheLayerSize;
-      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-      MatVec<kQKVDim * 2, kModelDim>(layer_weights->qkv_einsum_w,
-                                     kHeads * kQKVDim * kModelDim, x,
-                                     activations.even_odd.data(), kv, pool);
-      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-    }
-    const size_t num_tasks = kHeads * num_tokens;
-    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
-      const size_t head = task % kHeads;
-      const size_t batch_idx = task / kHeads;
-      float* HWY_RESTRICT q =
-          activations.q.data() + batch_idx * kHeads * kQKVDim;
-      Attn(q + head * kQKVDim, head, 0, batch_idx, thread);
-    });
-  }
+  if constexpr (kHeads == kKVHeads) {
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      // Multi-Head Attention calculates qkv using q as scratch space.
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+      static_assert(TConfig::kInterleaveQKV);
+      float* HWY_RESTRICT qkv =
+          activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
+      MatVec<kHeads * kQKVDim * 3, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
+                                              activations.even_odd.data(), qkv,
+                                              pool);
+    }
+  } else {
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      const size_t pos = batch_start + batch_idx;
+      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+      float* HWY_RESTRICT q =
+          activations.q.data() + batch_idx * kHeads * kQKVDim;
+      MatVec<kHeads * kQKVDim, kModelDim>(layer_weights->qkv_einsum_w, 0, x,
+                                          activations.even_odd.data(), q, pool);
+      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+      const size_t kv_offset =
+          cache_pos * kCachePosSize + layer * kCacheLayerSize;
+      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+      MatVec<kKVHeads * kQKVDim * 2, kModelDim>(
+          layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
+          activations.even_odd.data(), kv, pool);
+    }
+  }
+
+  // Positional encodings for k:
+  const size_t num_kv_tasks = kKVHeads * num_tokens;
+  pool.Run(0, num_kv_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+    const size_t head = task % kKVHeads;
+    const size_t batch_idx = task / kKVHeads;
+    const size_t pos = batch_start + batch_idx;
+    const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+    const size_t kv_offset = cache_pos * kCachePosSize +
+                             layer * kCacheLayerSize + head * kQKVDim * 2;
+    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+    if constexpr (kHeads == kKVHeads) {
+      // For MHA, copy kv into the KV cache from scratch space (see above).
+      const float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
+    }
+    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+  });
+
+  static_assert((TConfig::kHeads % TConfig::kKVHeads) == 0,
+                "query heads must be a multiple of key-value heads");
+  static constexpr size_t kGroupHeads = TConfig::kHeads / TConfig::kKVHeads;
+  static constexpr size_t kQOffsetScale = (kHeads == kKVHeads) ? 3 : 1;
+  const size_t num_q_tasks = kHeads * num_tokens;
+  pool.Run(0, num_q_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+    const size_t head = task % kHeads;
+    const size_t batch_idx = task / kHeads;
+    const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
+    float* HWY_RESTRICT q = activations.q.data() + (batch_idx * kHeads + head) *
+                                kQKVDim * kQOffsetScale;
+    Attn(q, head, head_offset, batch_idx, thread);
+  });

   for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
     // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
     // rearranging the weights.
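The generic path leans on a small amount of index arithmetic: the KV cache is a flat ring buffer indexed by (position, layer, KV head), and each query head reads its group's KV entry. A self-contained sketch of those offset formulas follows; the sizes and the kCacheLayerSize/kCachePosSize definitions are hypothetical stand-ins chosen to be consistent with the offsets in the diff, not values taken from the commit.

#include <cstddef>
#include <cstdio>

// Hypothetical model/cache sizes; in gemma.cpp these come from TConfig.
constexpr size_t kSeqLen = 4096;
constexpr size_t kPrefillBatchSize = 16;
constexpr size_t kQKVDim = 256;
constexpr size_t kHeads = 8;
constexpr size_t kKVHeads = 2;
constexpr size_t kLayers = 4;
// Assumed layout: each layer stores K and V (2 * kQKVDim floats) per KV head.
constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim * 2;
constexpr size_t kCachePosSize = kLayers * kCacheLayerSize;

// Ring-buffer slot for an absolute token position (as in the diff).
constexpr size_t CachePos(size_t pos) {
  return pos % (kSeqLen + kPrefillBatchSize);
}

// Offset of the K/V pair for (pos, layer, kv_head) within the flat cache.
constexpr size_t KVOffset(size_t pos, size_t layer, size_t kv_head) {
  return CachePos(pos) * kCachePosSize + layer * kCacheLayerSize +
         kv_head * kQKVDim * 2;
}

// Offset of query head `head` for token `batch_idx` in activations.q.
// For MHA (kHeads == kKVHeads) the buffer holds interleaved q/k/v, hence the
// stride of 3 * kQKVDim per head; otherwise it holds only q.
constexpr size_t kQOffsetScale = (kHeads == kKVHeads) ? 3 : 1;
constexpr size_t QOffset(size_t batch_idx, size_t head) {
  return (batch_idx * kHeads + head) * kQKVDim * kQOffsetScale;
}

int main() {
  std::printf("kv offset (pos=5000, layer=1, kv head=1): %zu\n",
              KVOffset(5000, 1, 1));
  std::printf("q offset (batch_idx=0, head=3): %zu\n", QOffset(0, 3));
  return 0;
}

With kHeads == kKVHeads the same formulas collapse to the MHA case: the per-KV-head rope loop copies k and v out of the interleaved q scratch buffer (the memcpy in the diff), and kQOffsetScale becomes 3.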