diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 347a4cd..8b470e6 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -419,28 +419,28 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, } }; + if constexpr (kHeads == kKVHeads) { + // Multi-Head Attention calculates qkv using q as scratch space. + static_assert(TConfig::kInterleaveQKV); + MatMul_4x4_Batch( + num_tokens, activations.pre_att_rms_out.data(), + layer_weights->qkv_einsum_w.data(), activations.q.data(), pool); + } else { + MatMul_4x4_Batch( + num_tokens, activations.pre_att_rms_out.data(), + layer_weights->qkv_einsum_w.data(), activations.q.data(), pool); + } + for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { const float* x = activations.pre_att_rms_out.data() + batch_idx * kModelDim; // QKV projections: - if constexpr (kHeads == kKVHeads) { - // Multi-Head Attention calculates qkv using q as scratch space. - static_assert(TConfig::kInterleaveQKV); - float* HWY_RESTRICT qkv = - activations.q.data() + batch_idx * kHeads * kQKVDim * 3; - MatVec(layer_weights->qkv_einsum_w, 0, x, - activations.even_odd.data(), qkv, - pool); - } else { + if constexpr (kHeads != kKVHeads) { const size_t pos = batch_start + batch_idx; - float* HWY_RESTRICT q = - activations.q.data() + batch_idx * kHeads * kQKVDim; - MatVec(layer_weights->qkv_einsum_w, 0, x, - activations.even_odd.data(), q, pool); - const size_t cache_pos = pos % (kSeqLen + kPrefillBatchSize); const size_t kv_offset = cache_pos * kCachePosSize + layer * kCacheLayerSize; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; + // TODO: requires MatMul support for offsets. MatVec( layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, activations.even_odd.data(), kv, pool); @@ -494,6 +494,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, for (size_t head = 1; head < kHeads; ++head) { float* HWY_RESTRICT head_out = activations.att_post1.data() + head * kBatchSize * kModelDim; + // TODO: requires MatMul support for offsets. MatVec( layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out + head * kQKVDim, diff --git a/gemma/ops.h b/gemma/ops.h index 318af52..18b638b 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -881,6 +881,7 @@ template