From 2346b5a4343693d1bd6c8a3b1c9daafb09192f9d Mon Sep 17 00:00:00 2001
From: Daniel Keysers
Date: Tue, 23 Jul 2024 11:17:06 -0700
Subject: [PATCH] Minor polishing: adding comments, renaming variables.

PiperOrigin-RevId: 655235006
---
 gemma/gemma-inl.h | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index f36730f..62ebcb0 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -262,7 +262,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   // Apply positional encodings for K (and copy KV to cache if MHA).
   pool.Run(
       0, kKVHeads * num_interleaved,
-      [&](uint64_t task, size_t thread) HWY_ATTR {
+      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kKVHeads;
         const size_t interleaved_idx = task / kKVHeads;
         const size_t query_idx = interleaved_idx % num_queries;
@@ -283,17 +283,20 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         PostQK(kv, pos, layer);
       });
 
+  // A "head group" in the context of GQA refers to a collection of query heads
+  // that share the same key and value heads.
   static_assert((kHeads % kKVHeads) == 0,
                 "query heads must be a multiple of key-value heads");
-  constexpr size_t kGroupHeads = kHeads / kKVHeads;
+  constexpr size_t kHeadGroups = kHeads / kKVHeads;
   // For each head (token, query), compute Q.K, softmax, and weighted V.
   pool.Run(
-      0, kHeads * num_interleaved, [&](uint64_t task, size_t thread) HWY_ATTR {
+      0, kHeads * num_interleaved,
+      [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
         const size_t head = task % kHeads;
         const size_t interleaved_idx = task / kHeads;
         const size_t query_idx = interleaved_idx % num_queries;
         const size_t batch_idx = interleaved_idx / num_queries;
-        const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
+        const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
         KVCache& kv_cache = *kv_caches[query_idx];
         float* HWY_RESTRICT q =
             activations.q.Batch(interleaved_idx) + head * kQStride;
@@ -306,14 +309,16 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         // Compute Q.K scores, yielding "logits" (or scores) in head_att.
         float* HWY_RESTRICT head_att =
             activations.att.Batch(interleaved_idx) + head * kSeqLen;
+        // Usually start_pos is 0, unless pos is larger than the attention
+        // window size, then it is pos - window_size + 1.
         const size_t start_pos =
             pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
         for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
           const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
           const size_t kv_offset = cache_pos * kCachePosSize +
                                    layer * kCacheLayerSize + head_offset;
-          const float* HWY_RESTRICT k2 = kv_cache.kv_cache.get() + kv_offset;
-          const float score = Dot(q, k2, kQKVDim);
+          const float* HWY_RESTRICT k = kv_cache.kv_cache.get() + kv_offset;
+          const float score = Dot(q, k, kQKVDim);
           head_att[pos2 % kSeqLen] = score;
         }
 
@@ -335,9 +340,9 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
           const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
           const size_t kv_offset = cache_pos * kCachePosSize +
                                    layer * kCacheLayerSize + head_offset;
-          float* HWY_RESTRICT v2 =
+          float* HWY_RESTRICT v =
               kv_cache.kv_cache.get() + kv_offset + kQKVDim;
-          MulByConstAndAdd(head_att[pos2 % kSeqLen], v2, att_out, kQKVDim);
+          MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim);
         }
       });
 
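
Note on the renamed kHeadGroups constant: in grouped-query attention (GQA), kHeads / kKVHeads
query heads form one "head group" that reads the same key/value pair, and K and V are stored
adjacently in the cache, hence the factor of 2 in head_offset. The standalone sketch below
illustrates only that index mapping; the constants are made-up stand-ins for the TConfig
values, and main() exists purely to print the mapping.

  #include <cstddef>
  #include <cstdio>

  // Hypothetical stand-ins for the TConfig constants; real values differ.
  constexpr size_t kHeads = 8;    // query heads
  constexpr size_t kKVHeads = 2;  // key-value heads
  constexpr size_t kQKVDim = 4;   // per-head vector length (tiny, for printing)

  static_assert((kHeads % kKVHeads) == 0,
                "query heads must be a multiple of key-value heads");
  constexpr size_t kHeadGroups = kHeads / kKVHeads;

  int main() {
    for (size_t head = 0; head < kHeads; ++head) {
      // All query heads in one group share a KV head; each KV head owns
      // 2 * kQKVDim floats in the cache because K and V sit side by side.
      const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2;
      std::printf("query head %zu -> kv head %zu, head_offset %zu\n", head,
                  head / kHeadGroups, head_offset);
    }
    return 0;
  }

With these toy values, query heads 0-3 map to kv head 0 (offset 0) and heads 4-7 to kv head 1
(offset 8), which is the same arithmetic the old kGroupHeads name expressed less clearly.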
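
Note on the new start_pos comment: it documents sliding-window (local) attention. Each
position attends to at most the last TConfig::kAttentionWindowSizes[layer] positions, and
cached entries are addressed modulo the cache length, i.e. as a ring buffer. A minimal
sketch, assuming made-up values for the window size and cache length (the real code uses
kAttentionWindowSizes[layer] and kSeqLen + kPrefillBatchSize):

  #include <algorithm>
  #include <cstddef>
  #include <cstdio>

  constexpr size_t kWindowSize = 4;  // stand-in for kAttentionWindowSizes[layer]
  constexpr size_t kCacheLen = 6;    // stand-in for kSeqLen + kPrefillBatchSize

  int main() {
    for (size_t pos = 0; pos < 8; ++pos) {
      // start_pos stays 0 until pos reaches the window size; afterwards it is
      // pos - kWindowSize + 1, so [start_pos, pos] spans kWindowSize entries.
      const size_t start_pos = pos - std::min(kWindowSize - 1, pos);
      std::printf("pos %zu attends to [%zu..%zu], cache slots:", pos, start_pos, pos);
      for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
        std::printf(" %zu", pos2 % kCacheLen);  // ring-buffer slot in the KV cache
      }
      std::printf("\n");
    }
    return 0;
  }

The std::min keeps the unsigned subtraction from wrapping while pos is still inside the
first window, which is why the patch comment says start_pos is "usually 0".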
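
Note on the renamed k and v pointers: the two hot loops need only a dot product (the Q.K
score per cached position) and a scaled accumulate (weighting V by the softmaxed attention).
The real Dot and MulByConstAndAdd are vectorized Highway kernels; the scalar functions below
are a sketch of the arithmetic they perform, not the actual implementation.

  #include <cstddef>

  // Dot(q, k, n): the Q.K attention score for one cached position.
  float Dot(const float* a, const float* b, size_t n) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) sum += a[i] * b[i];
    return sum;
  }

  // MulByConstAndAdd(c, v, out, n): out += c * v, accumulating one value
  // vector into att_out with its attention weight c.
  void MulByConstAndAdd(float c, const float* v, float* out, size_t n) {
    for (size_t i = 0; i < n; ++i) out[i] += c * v[i];
  }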