diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index ba1de3e..c65c57f 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -126,6 +126,10 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos, PROFILER_ZONE3(p, worker, zone); const size_t pos_mod = activations.div_seq_len.Remainder(start_pos); float m = Dot(q, k.Row(pos_mod), k.Cols()); + if (float cap = activations.config.att_cap; cap > 0.0f) { + // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x. + m = cap * std::tanh(m / cap); + } float d = 1.0f; // This is just a copy of the first token. MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker);