diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index a92d835..6d80741 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -741,22 +741,6 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
       MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
-    // linear projection from kQKVDim back to kModelDim, sum projections
-    // across heads
-    float* HWY_RESTRICT head_out =
-        head == 0
-            ? activations.att_post2.data() + batch_idx * kModelDim
-            : activations.att_post1.data() + head * kBatchSize * kModelDim;
-    float* even_odd = activations.even_odd.data() + thread * kQKVDim;
-    if (head == 0) {
-      MatVecAddLoop(
-          layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
-          layer_weights->attention_output_biases.data(), even_odd, head_out);
-    } else {
-      MatVecLoop(layer_weights->attn_vec_einsum_w,
-                 head * kModelDim * kQKVDim, att_out,
-                 even_odd, head_out);
-    }
   };
 
   if constexpr (kHeads == kKVHeads) {
@@ -810,11 +794,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     });
   }
 
-  // accumulate output across all heads into att_post2. head 0 already wrote
-  // directly to att_post2.
-  for (size_t head = 1; head < kHeads; ++head) {
-    AddFrom(activations.att_post1.data() + head * kBatchSize * kModelDim,
-            activations.att_post2.data() + batch_idx * kModelDim, kModelDim);
-  }
+  // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
+  // rearranging the weights.
+  float* HWY_RESTRICT att_out =
+      activations.att_out.data() + batch_idx * kHeads * kQKVDim;
+  float* HWY_RESTRICT layer_out =
+      activations.att_post2.data() + batch_idx * kModelDim;
+  MatVecAdd(
+      layer_weights->attn_vec_einsum_w, 0, att_out,
+      layer_weights->attention_output_biases.data(),
+      activations.even_odd.data(), layer_out, pool);
+  for (size_t head = 1; head < kHeads; ++head) {
+    float* HWY_RESTRICT head_out =
+        activations.att_post1.data() + head * kBatchSize * kModelDim;
+    MatVec(
+        layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
+        att_out + head * kQKVDim,
+        activations.even_odd.data(), head_out, pool);
+    AddFrom(head_out, layer_out, kModelDim);
+  }
 }
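Note on the TODO(szabadka) above: the remaining per-head MatVec calls plus AddFrom collapse into the single MatVecAdd once attn_vec_einsum_w is rearranged from the per-head [kHeads][kModelDim][kQKVDim] layout (implied by the head * kModelDim * kQKVDim offsets in the diff) into one kModelDim x (kHeads * kQKVDim) matrix applied to the concatenated att_out vector. Below is a minimal standalone sketch of that idea under those layout assumptions; the helper names, dimensions, and scalar loops are illustrative, not gemma.cpp's actual API.

    #include <cstddef>
    #include <vector>

    constexpr size_t kHeads = 8, kQKVDim = 256, kModelDim = 2048;

    // Old layout: w[head][row][col], one kModelDim x kQKVDim block per head
    // (head stride kModelDim * kQKVDim, as in the diff's offsets).
    // New layout: w2[row][head * kQKVDim + col], a single
    // kModelDim x (kHeads * kQKVDim) matrix.
    std::vector<float> RearrangeAttVecEinsumW(const std::vector<float>& w) {
      std::vector<float> w2(kModelDim * kHeads * kQKVDim);
      for (size_t head = 0; head < kHeads; ++head) {
        for (size_t row = 0; row < kModelDim; ++row) {
          for (size_t col = 0; col < kQKVDim; ++col) {
            w2[row * kHeads * kQKVDim + head * kQKVDim + col] =
                w[(head * kModelDim + row) * kQKVDim + col];
          }
        }
      }
      return w2;
    }

    // With the rearranged weights, the MatVecAdd plus the per-head
    // MatVec/AddFrom loop in the diff reduces to one matrix-vector product
    // over the concatenated att_out, with the bias added once.
    void SingleProjection(const std::vector<float>& w2,
                          const float* att_out,  // kHeads * kQKVDim
                          const float* bias,     // kModelDim
                          float* layer_out) {    // kModelDim
      for (size_t row = 0; row < kModelDim; ++row) {
        float sum = bias[row];
        for (size_t i = 0; i < kHeads * kQKVDim; ++i) {
          sum += w2[row * kHeads * kQKVDim + i] * att_out[i];
        }
        layer_out[row] = sum;
      }
    }

Summing per-head projections W_h * att_out_h is the same as multiplying the concatenated att_out by the horizontally stacked [W_0 | W_1 | ... | W_{kHeads-1}], which is why a one-time weight rearrangement lets a single (pooled) MatVecAdd replace the per-head loop.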