diff --git a/gemma.cc b/gemma.cc index 877a3dc..9baccd7 100644 --- a/gemma.cc +++ b/gemma.cc @@ -373,12 +373,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, Rope(q, kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); + + const size_t head_offset = kHeads == kKVHeads ? head * kQKVDim : 0; + // Compute Q dot K scores for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t cache_offset = - kHeads == kKVHeads - ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; @@ -391,9 +392,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t cache_offset = - kHeads == kKVHeads - ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); }