diff --git a/gemma.cc b/gemma.cc
index 1867fbf..76086d9 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -320,6 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   const size_t batch_offset = batch_idx * kModelDim;
 
+  auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) {
+    TwoOfsMatVecLoop(
+        c_layer->c_qkv_einsum_w, k_offset, v_offset,
+        activations.pre_att_rms_out.data() + batch_offset,
+        kv_cache.key_cache.get() + kv_offset,
+        kv_cache.value_cache.get() + kv_offset);
+
+    Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
+  };
+
   pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     // linear projections to QKV
     constexpr const size_t head_offset =
@@ -339,13 +349,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       const size_t kv_offset =
          pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
 
-      TwoOfsMatVecLoop(
-          c_layer->c_qkv_einsum_w, k_offset, v_offset,
-          activations.pre_att_rms_out.data() + batch_offset,
-          kv_cache.key_cache.get() + kv_offset,
-          kv_cache.value_cache.get() + kv_offset);
-
-      Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
+      ProjKV(k_offset, v_offset, kv_offset);
     }
   });
 
@@ -355,13 +359,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
     const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
 
-    TwoOfsMatVecLoop(
-        c_layer->c_qkv_einsum_w, k_offset, v_offset,
-        activations.pre_att_rms_out.data() + batch_offset,
-        kv_cache.key_cache.get() + kv_offset,
-        kv_cache.value_cache.get() + kv_offset);
-
-    Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
+    ProjKV(k_offset, v_offset, kv_offset);
   }
 
   pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
@@ -376,9 +374,10 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     MulByConst(kQueryScale, q, kQKVDim);
     // Compute Q dot K scores
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-      const size_t cache_offset = kHeads == kKVHeads
-          ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
-          : pos2 * kCachePosSize + layer * kCacheLayerSize;
+      const size_t cache_offset =
+          kHeads == kKVHeads
+              ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
+              : pos2 * kCachePosSize + layer * kCacheLayerSize;
       const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2] = score;
@@ -390,9 +389,10 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
         batch_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-      const size_t cache_offset = kHeads == kKVHeads
-          ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
-          : pos2 * kCachePosSize + layer * kCacheLayerSize;
+      const size_t cache_offset =
+          kHeads == kKVHeads
+              ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
+              : pos2 * kCachePosSize + layer * kCacheLayerSize;
       float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
       MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
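
Note on the refactor: the first three hunks are the classic dedupe-by-local-lambda pattern. The K/V projection (TwoOfsMatVecLoop) followed by Rope on the key appeared verbatim in both the multi-head branch (kHeads == kKVHeads, per-head cache offsets) and the single-KV-head branch, and is hoisted into a ProjKV closure that captures the activations and KV cache by reference. Below is a minimal standalone sketch of the same pattern; ProjectRow and the buffers are hypothetical stand-ins for TwoOfsMatVecLoop and the real caches, not gemma.cc APIs.

    // Minimal sketch of the dedupe-by-lambda refactor from the patch.
    // ProjectRow, the buffers, and the offsets are hypothetical stand-ins.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Toy projection: writes dim outputs from a weight row starting at ofs.
    void ProjectRow(const std::vector<float>& w, size_t ofs,
                    const std::vector<float>& x, float* out, size_t dim) {
      for (size_t i = 0; i < dim; ++i) out[i] = w[ofs + i] * x[i % x.size()];
    }

    int main() {
      constexpr size_t kDim = 4;
      std::vector<float> w(64, 0.5f);
      std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
      std::vector<float> k_cache(kDim), v_cache(kDim);

      // As in the patch: hoist the shared K/V projection into a local lambda
      // that captures the surrounding buffers by reference, then call it from
      // both branches instead of repeating the block.
      auto ProjKV = [&](size_t k_ofs, size_t v_ofs) {
        ProjectRow(w, k_ofs, x, k_cache.data(), kDim);  // K projection
        ProjectRow(w, v_ofs, x, v_cache.data(), kDim);  // V projection
      };

      ProjKV(0, kDim);      // stand-in for the per-head (MHA) call site
      ProjKV(8, 8 + kDim);  // stand-in for the single-KV-head (MQA) call site
      std::printf("k[0]=%g v[0]=%g\n", k_cache[0], v_cache[0]);
      return 0;
    }

Because the lambda captures by reference and is defined in the same scope, compilers typically inline it, so the change reads as a readability cleanup rather than a performance one; the remaining two hunks only reflow the cache_offset ternaries without changing behavior.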