diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 7f32804..504e1e0 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -762,69 +762,67 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
     }
   };
-  if constexpr (kHeads == kKVHeads) {
-    // Multi-Head Attention
-    static_assert(TConfig::kInterleaveQKV);
-
-    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+    const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
+    // QKV projections:
+    if constexpr (kHeads == kKVHeads) {
+      // Multi-Head Attention calculates qkv using q as scratch space.
+      static_assert(TConfig::kInterleaveQKV);
       float* HWY_RESTRICT qkv =
           activations.q.data() + batch_idx * kHeads * kQKVDim * 3;
-      MatVec(
-          layer_weights->qkv_einsum_w, 0, x, activations.even_odd.data(), qkv,
-          pool);
-    }
-    const size_t num_tasks = kHeads * num_tokens;
-    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
-      const size_t head = task % kHeads;
-      const size_t batch_idx = task / kHeads;
+      MatVec(layer_weights->qkv_einsum_w, 0, x,
+             activations.even_odd.data(), qkv,
+             pool);
+    } else {
       const size_t pos = batch_start + batch_idx;
-      float* HWY_RESTRICT q =
-          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
-      const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
-      const size_t kv_offset = cache_pos * kCachePosSize +
-                               layer * kCacheLayerSize + head * kQKVDim * 2;
-      float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
-      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
-    });
-    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
-      const size_t head = task % kHeads;
-      const size_t batch_idx = task / kHeads;
-      float* HWY_RESTRICT q =
-          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
-      Attn(q, head, head * kQKVDim * 2, batch_idx, thread);
-    });
-  } else {
-    // Multi-Query Attention
-    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-      const size_t pos = batch_start + batch_idx;
-      float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim;
       float* HWY_RESTRICT q = activations.q.data() + batch_idx * kHeads * kQKVDim;
       MatVec(layer_weights->qkv_einsum_w, 0, x,
              activations.even_odd.data(), q, pool);
       const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
-      const size_t kv_offset = cache_pos * kCachePosSize +
-                               layer * kCacheLayerSize;
+      const size_t kv_offset =
+          cache_pos * kCachePosSize + layer * kCacheLayerSize;
       float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-      MatVec(layer_weights->qkv_einsum_w,
-             kHeads * kQKVDim * kModelDim, x,
-             activations.even_odd.data(), kv, pool);
-      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      MatVec(
+          layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
+          activations.even_odd.data(), kv, pool);
     }
-    const size_t num_tasks = kHeads * num_tokens;
-    pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
-      const size_t head = task % kHeads;
-      const size_t batch_idx = task / kHeads;
-      float* HWY_RESTRICT q =
-          activations.q.data() + batch_idx * kHeads * kQKVDim;
-      Attn(q + head * kQKVDim, head, 0, batch_idx, thread);
-    });
   }
+  // Positional encodings for k:
+  const size_t num_kv_tasks = kKVHeads * num_tokens;
+  pool.Run(0, num_kv_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+    const size_t head = task % kKVHeads;
+    const size_t batch_idx = task / kKVHeads;
+    const size_t pos = batch_start + batch_idx;
+    const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize);
+    const size_t kv_offset = cache_pos * kCachePosSize +
+                             layer * kCacheLayerSize + head * kQKVDim * 2;
+    float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
+    if constexpr (kHeads == kKVHeads) {
+      // For MHA, copy kv into the KV cache from scratch space (see above).
+      const float* HWY_RESTRICT q =
+          activations.q.data() + (batch_idx * kHeads + head) * kQKVDim * 3;
+      memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
+    }
+    Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+  });
+
+  static_assert((TConfig::kHeads % TConfig::kKVHeads) == 0,
+                "query heads must be a multiple of key-value heads");
+  static constexpr size_t kGroupHeads = TConfig::kHeads / TConfig::kKVHeads;
+  static constexpr size_t kQOffsetScale = (kHeads == kKVHeads) ? 3 : 1;
+  const size_t num_q_tasks = kHeads * num_tokens;
+  pool.Run(0, num_q_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
+    const size_t head = task % kHeads;
+    const size_t batch_idx = task / kHeads;
+    const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2;
+    float* HWY_RESTRICT q = activations.q.data() + (batch_idx * kHeads + head) *
+                                kQKVDim * kQOffsetScale;
+    Attn(q, head, head_offset, batch_idx, thread);
+  });