From 4974f248322af872d99102eda1a79964084eaaf8 Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Tue, 30 Sep 2025 02:17:18 -0700 Subject: [PATCH] Fixed bug with softcap in single flash attention PiperOrigin-RevId: 813164938 --- gemma/flash_attention.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index ba1de3e..c65c57f 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -126,6 +126,10 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, PROFILER_ZONE3(p, worker, zone); const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); float m = Dot(q, k.Row(pos_mod), k.Cols()); + if (float cap = activations.config.att_cap; cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. + m = cap * std::tanh(m / cap); + } float d = 1.0f; // This is just a copy of the first token. MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker);