mirror of https://github.com/google/gemma.cpp.git
Minor polishing: adding comments, renaming variables.
PiperOrigin-RevId: 655235006
This commit is contained in:
parent 33334ad454
commit 2346b5a434
@@ -262,7 +262,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
       0, kKVHeads * num_interleaved,
-      [&](uint64_t task, size_t thread) HWY_ATTR {
+      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kKVHeads;
         const size_t interleaved_idx = task / kKVHeads;
         const size_t query_idx = interleaved_idx % num_queries;
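For readers of the first hunk: the flat pool.Run task index interleaves KV heads with (query, token) pairs. A minimal standalone sketch of that decomposition, using made-up sizes for kKVHeads, num_queries and num_tokens rather than values from any real config:

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical example sizes; the real values come from TConfig and callers.
  constexpr std::size_t kKVHeads = 2;
  constexpr std::size_t num_queries = 3;
  constexpr std::size_t num_tokens = 2;
  const std::size_t num_interleaved = num_tokens * num_queries;
  for (std::size_t task = 0; task < kKVHeads * num_interleaved; ++task) {
    const std::size_t head = task % kKVHeads;             // KV head index
    const std::size_t interleaved_idx = task / kKVHeads;  // which (query, token)
    const std::size_t query_idx = interleaved_idx % num_queries;
    const std::size_t batch_idx = interleaved_idx / num_queries;
    std::printf("task=%zu head=%zu query=%zu batch=%zu\n", task, head,
                query_idx, batch_idx);
  }
  return 0;
}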
@@ -283,17 +283,20 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         PostQK<TConfig>(kv, pos, layer);
       });

+  // A "head group" in the context of GQA refers to a collection of query heads
+  // that share the same key and value heads.
   static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
-  constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  constexpr size_t kHeadGroups = kHeads / kKVHeads;
   // For each head (token, query), compute Q.K, softmax, and weighted V.
   pool.Run(
-      0, kHeads * num_interleaved, [&](uint64_t task, size_t thread) HWY_ATTR {
+      0, kHeads * num_interleaved,
+      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kHeads;
         const size_t interleaved_idx = task / kHeads;
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
-        const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
+        const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
         KVCache& kv_cache = *kv_caches[query_idx];
         float* HWY_RESTRICT q =
             activations.q.Batch(interleaved_idx) + head * kQStride;
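The new comment above documents GQA head groups. As an illustrative sketch (not part of the patch, with made-up values for kHeads, kKVHeads and kQKVDim), this shows how kHeadGroups consecutive query heads map onto one KV head and therefore the same cache offset:

#include <cstddef>
#include <cstdio>

int main() {
  // Made-up example values; real configs define these elsewhere.
  constexpr std::size_t kHeads = 8;
  constexpr std::size_t kKVHeads = 2;
  constexpr std::size_t kQKVDim = 4;
  static_assert((kHeads % kKVHeads) == 0,
                "query heads must be a multiple of key-value heads");
  constexpr std::size_t kHeadGroups = kHeads / kKVHeads;
  for (std::size_t head = 0; head < kHeads; ++head) {
    // Consecutive groups of kHeadGroups query heads share one KV head, so
    // they read K and V from the same offset within a cache row.
    const std::size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
    std::printf("query head %zu -> kv head %zu (offset %zu)\n", head,
                head / kHeadGroups, head_offset);
  }
  return 0;
}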
@@ -306,14 +309,16 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         // Compute Q.K scores, yielding "logits" (or scores) in head_att.
         float* HWY_RESTRICT head_att =
             activations.att.Batch(interleaved_idx) + head * kSeqLen;
+        // Usually start_pos is 0, unless pos is larger than the attention
+        // window size, then it is pos - window_size + 1.
         const size_t start_pos =
             pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
           const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-          const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
-          const float score = Dot(q, k2, kQKVDim);
+          const float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+          const float score = Dot(q, k, kQKVDim);
           head_att[pos2 % kSeqLen] = score;
         }

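The added comment describes start_pos. A small self-contained sketch, with a made-up window size and cache capacity, showing that the std::min form clamps to 0 without unsigned underflow and how cache_pos wraps into the rolling KV cache:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Made-up sizes; the real ones come from TConfig.
  constexpr std::size_t kWindowSize = 4;        // attention window
  constexpr std::size_t kSeqLen = 8;            // cache capacity
  constexpr std::size_t kPrefillBatchSize = 2;  // extra cache slots
  for (std::size_t pos = 0; pos < 12; ++pos) {
    // Usually 0; once pos exceeds the window it becomes pos - window_size + 1.
    const std::size_t start_pos = pos - std::min(kWindowSize - 1, pos);
    const std::size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
    std::printf("pos=%2zu start_pos=%2zu cache_pos=%2zu\n", pos, start_pos,
                cache_pos);
  }
  return 0;
}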
@@ -335,9 +340,9 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
           const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
           const size_t kv_offset =
               cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset;
-          float* HWY_RESTRICT v2 =
+          float* HWY_RESTRICT v =
               kv_cache.kv_cache.get() + kv_offset + kQKVDim;
-          MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
+          MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim);
         }
       });

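For reference, scalar stand-ins (my own sketch, not the vectorized implementations in the repository) for the two primitives whose operands are renamed here: Dot yields the Q.K score and MulByConstAndAdd accumulates the attention-weighted V row into att_out:

#include <cstddef>

// Scalar stand-in for Dot: sum of elementwise products of two length-n rows.
float DotScalar(const float* a, const float* b, std::size_t n) {
  float sum = 0.0f;
  for (std::size_t i = 0; i < n; ++i) sum += a[i] * b[i];
  return sum;
}

// Scalar stand-in for MulByConstAndAdd: out[i] += c * v[i] for i in [0, n).
void MulByConstAndAddScalar(float c, const float* v, float* out,
                            std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) out[i] += c * v[i];
}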