diff --git a/gemma.cc b/gemma.cc index 76086d9..877a3dc 100644 --- a/gemma.cc +++ b/gemma.cc @@ -320,15 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, const size_t batch_offset = batch_idx * kModelDim; - auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) { - TwoOfsMatVecLoop( - c_layer->c_qkv_einsum_w, k_offset, v_offset, - activations.pre_att_rms_out.data() + batch_offset, - kv_cache.key_cache.get() + kv_offset, - kv_cache.value_cache.get() + kv_offset); + auto ProjKV = + [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR { + TwoOfsMatVecLoop( + c_layer->c_qkv_einsum_w, k_offset, v_offset, + activations.pre_att_rms_out.data() + batch_offset, + kv_cache.key_cache.get() + kv_offset, + kv_cache.value_cache.get() + kv_offset); - Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); - }; + Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + }; pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // linear projections to QKV