From 14a9ecf21da9970efa8b6c79491fe44bfdbcabf5 Mon Sep 17 00:00:00 2001
From: Martin Stolle
Date: Tue, 9 Dec 2025 02:22:46 -0800
Subject: [PATCH] Factor out SumHeads

PiperOrigin-RevId: 842138081
---
 gemma/attention.cc | 20 +-------------------
 gemma/gemma-inl.h  | 19 +++++++++++++++++++
 2 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/gemma/attention.cc b/gemma/attention.cc
index 5880b50..deb85c0 100644
--- a/gemma/attention.cc
+++ b/gemma/attention.cc
@@ -43,6 +43,7 @@
 // After highway.h
 #include "compression/compress-inl.h"
 #include "gemma/flash_attention.h"
+#include "gemma/gemma-inl.h"
 #include "ops/ops-inl.h"

 HWY_BEFORE_NAMESPACE();
@@ -319,25 +320,6 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
   });
 }

-// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
-// head_dim (`qkv_dim`) into output (`layer_out`).
-static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
-                                AttentionActivationsPtrs& activations,
-                                MatMulEnv& env) {
-  GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
-  const LayerConfig& layer_config = layer.layer_config;
-  (void)layer_config;  // For HWY_DASSERT
-  // att_weights and att_out are concatenated heads, each of length
-  // layer_config.qkv_dim. Thus the [num_interleaved,
-  // layer_config.model_dim] matmul output is the sum over heads. Compare
-  // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
-  // encoded)
-  HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
-              layer_config.qkv_dim != 0);
-  CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
-             activations.att_sums);
-}
-
 void GemmaAttention(size_t num_tokens, const size_t layer_idx,
                     const LayerWeightsPtrs& layer,
                     AttentionActivationsPtrs& activations, QBatch& qbatch,
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 93f8928..2c55ef4 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -183,6 +183,25 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
   CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
 }

+// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
+// head_dim (`qkv_dim`) into output (`layer_out`).
+static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer,
+                                AttentionActivationsPtrs& activations,
+                                MatMulEnv& env) {
+  GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads);
+  const LayerConfig& layer_config = layer.layer_config;
+  (void)layer_config;  // For HWY_DASSERT
+  // att_weights and att_out are concatenated heads, each of length
+  // layer_config.qkv_dim. Thus the [num_interleaved,
+  // layer_config.model_dim] matmul output is the sum over heads. Compare
+  // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
+  // encoded)
+  HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 &&
+              layer_config.qkv_dim != 0);
+  CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env,
+             activations.att_sums);
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
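
Note (illustration only, not part of the patch): the comment in SumHeads states that because att_weights and att_out are concatenated heads, a single matmul over the flattened [heads * qkv_dim] activation equals the per-head projection summed over heads, i.e. the attn_vec_einsum('BTNH,NHD->BTD', encoded) in gemma/modules.py. The standalone C++ sketch below demonstrates that equivalence with made-up toy sizes and data; the names kHeads, kQKVDim, and kModelDim are hypothetical and do not reference the gemma.cpp API.

// Toy, self-contained check that one matmul over concatenated heads equals
// the sum of per-head matmuls (the 'BTNH,NHD->BTD' einsum view).
// All sizes, names, and values are made up for illustration.
#include <cstdio>
#include <vector>

int main() {
  const int kHeads = 2, kQKVDim = 3, kModelDim = 4;  // hypothetical toy sizes

  // Per-head attention outputs, laid out as concatenated heads: [kHeads * kQKVDim].
  const std::vector<double> att_out = {0.5, -1.0, 2.0,    // head 0
                                       1.5, 0.25, -0.5};  // head 1
  // Stacked per-head projection weights: w[(n * kQKVDim + h) * kModelDim + d].
  std::vector<double> w(kHeads * kQKVDim * kModelDim);
  for (size_t i = 0; i < w.size(); ++i) w[i] = 0.01 * static_cast<double>(i);

  // (1) Einsum view: project each head separately, then sum over heads.
  std::vector<double> per_head_sum(kModelDim, 0.0);
  for (int n = 0; n < kHeads; ++n) {
    std::vector<double> head_proj(kModelDim, 0.0);  // att_out[n, :] * W[n, :, :]
    for (int h = 0; h < kQKVDim; ++h)
      for (int d = 0; d < kModelDim; ++d)
        head_proj[d] +=
            att_out[n * kQKVDim + h] * w[(n * kQKVDim + h) * kModelDim + d];
    for (int d = 0; d < kModelDim; ++d) per_head_sum[d] += head_proj[d];
  }

  // (2) Concatenated view: one matmul over the flattened [kHeads * kQKVDim] vector.
  std::vector<double> concat(kModelDim, 0.0);
  for (int k = 0; k < kHeads * kQKVDim; ++k)
    for (int d = 0; d < kModelDim; ++d)
      concat[d] += att_out[k] * w[k * kModelDim + d];

  // Both views produce identical outputs of length kModelDim.
  for (int d = 0; d < kModelDim; ++d)
    std::printf("d=%d  per-head sum=%.6f  concatenated=%.6f\n", d,
                per_head_sum[d], concat[d]);
  return 0;
}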