Minor polishing: adding comments, renaming variables.

PiperOrigin-RevId: 655235006
Daniel Keysers 2024-07-23 11:17:06 -07:00 committed by Copybara-Service
parent 33334ad454
commit 2346b5a434
1 changed file with 13 additions and 8 deletions


@@ -262,7 +262,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
       0, kKVHeads * num_interleaved,
-      [&](uint64_t task, size_t thread) HWY_ATTR {
+      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kKVHeads;
         const size_t interleaved_idx = task / kKVHeads;
         const size_t query_idx = interleaved_idx % num_queries;
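The flattened task index in the lambda above encodes both the KV head and the interleaved (query, batch) position. A minimal standalone sketch of that decomposition, with made-up sizes standing in for the real TConfig constants:

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative sizes only; the real values come from the model config.
  const size_t kKVHeads = 4;
  const size_t num_queries = 2;
  const size_t num_tokens = 3;
  const size_t num_interleaved = num_tokens * num_queries;

  // Same decomposition as in the pool.Run lambda:
  // task = interleaved_idx * kKVHeads + head.
  for (size_t task = 0; task < kKVHeads * num_interleaved; ++task) {
    const size_t head = task % kKVHeads;
    const size_t interleaved_idx = task / kKVHeads;
    const size_t query_idx = interleaved_idx % num_queries;
    const size_t batch_idx = interleaved_idx / num_queries;
    std::printf("task=%zu head=%zu query=%zu batch=%zu\n", task, head,
                query_idx, batch_idx);
  }
  return 0;
}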
@@ -283,17 +283,20 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         PostQK<TConfig>(kv, pos, layer);
       });
+  // A "head group" in the context of GQA refers to a collection of query heads
+  // that share the same key and value heads.
   static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
-  constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  constexpr size_t kHeadGroups = kHeads / kKVHeads;
   // For each head (token, query), compute Q.K, softmax, and weighted V.
   pool.Run(
-      0, kHeads * num_interleaved, [&](uint64_t task, size_t thread) HWY_ATTR {
+      0, kHeads * num_interleaved,
+      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kHeads;
         const size_t interleaved_idx = task / kHeads;
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
-        const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
+        const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
         KVCache& kv_cache = *kv_caches[query_idx];
         float* HWY_RESTRICT q =
             activations.q.Batch(interleaved_idx) + head * kQStride;
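The new comment describes GQA head groups: with kHeads query heads and kKVHeads key-value heads, every block of kHeadGroups consecutive query heads shares one KV head, which is what head_offset selects. A small sketch of that mapping, using illustrative sizes rather than the actual model configuration:

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative sizes; the real values come from TConfig.
  const size_t kHeads = 8;     // query heads
  const size_t kKVHeads = 2;   // key-value heads
  const size_t kQKVDim = 256;  // per-head dimension
  const size_t kHeadGroups = kHeads / kKVHeads;  // query heads per KV head

  for (size_t head = 0; head < kHeads; ++head) {
    const size_t kv_head = head / kHeadGroups;
    // K and V are stored per KV head, K followed by V, hence the factor of 2.
    const size_t head_offset = kv_head * kQKVDim * 2;
    std::printf("query head %zu -> kv head %zu (offset %zu)\n", head, kv_head,
                head_offset);
  }
  return 0;
}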
@@ -306,14 +309,16 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         // Compute Q.K scores, yielding "logits" (or scores) in head_att.
         float* HWY_RESTRICT head_att =
             activations.att.Batch(interleaved_idx) + head * kSeqLen;
+        // Usually start_pos is 0, unless pos is larger than the attention
+        // window size, then it is pos - window_size + 1.
         const size_t start_pos =
             pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
           const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-          const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
-          const float score = Dot(q, k2, kQKVDim);
+          const float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+          const float score = Dot(q, k, kQKVDim);
           head_att[pos2 % kSeqLen] = score;
         }
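The added comment explains start_pos: it clamps the scan to the layer's attention window, while cache_pos wraps positions into the fixed-size KV cache. A standalone sketch of both index computations, with placeholder sizes instead of the real kSeqLen, kPrefillBatchSize, and window values:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Placeholder sizes for illustration.
  const size_t window_size = 4;  // stands in for TConfig::kAttentionWindowSizes[layer]
  const size_t kSeqLen = 8;
  const size_t kPrefillBatchSize = 2;

  for (size_t pos = 0; pos < 10; ++pos) {
    // start_pos is 0 until pos exceeds the window, then pos - window_size + 1.
    const size_t start_pos = pos - std::min(window_size - 1, pos);
    for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
      // The cache holds kSeqLen + kPrefillBatchSize slots, indexed modulo its size.
      const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
      std::printf("pos=%zu attends to pos2=%zu (cache slot %zu)\n", pos, pos2,
                  cache_pos);
    }
  }
  return 0;
}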
@@ -335,9 +340,9 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
           const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-          float* HWY_RESTRICT v2 =
+          float* HWY_RESTRICT v =
               kv_cache.kv_cache.get() + kv_offset + kQKVDim;
-          MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
+          MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim);
         }
       });
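The loop accumulates softmax-weighted V rows into att_out via MulByConstAndAdd. A scalar sketch of that accumulation; the helper below is only a stand-in for the vectorized library routine, not its actual implementation:

#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for the vectorized helper: out[i] += c * x[i] for i in [0, size).
static void MulByConstAndAdd(float c, const float* x, float* out, size_t size) {
  for (size_t i = 0; i < size; ++i) out[i] += c * x[i];
}

int main() {
  const size_t kQKVDim = 4;  // illustrative per-head dimension
  // Two cached V rows and their (already softmaxed) attention weights.
  const std::vector<float> v0 = {1.0f, 0.0f, 2.0f, 0.0f};
  const std::vector<float> v1 = {0.0f, 1.0f, 0.0f, 2.0f};
  const float w0 = 0.25f, w1 = 0.75f;

  std::vector<float> att_out(kQKVDim, 0.0f);
  MulByConstAndAdd(w0, v0.data(), att_out.data(), kQKVDim);
  MulByConstAndAdd(w1, v1.data(), att_out.data(), kQKVDim);

  for (size_t i = 0; i < kQKVDim; ++i) {
    std::printf("att_out[%zu] = %g\n", i, att_out[i]);
  }
  return 0;
}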