Fixed bug with softcap in single flash attention

PiperOrigin-RevId: 813164938
Ray Smith 2025-09-30 02:17:18 -07:00 committed by Copybara-Service
parent 16536996d1
commit 4974f24832
1 changed file with 4 additions and 0 deletions


@@ -126,6 +126,10 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   PROFILER_ZONE3(p, worker, zone);
   const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
   float m = Dot(q, k.Row(pos_mod), k.Cols());
+  if (float cap = activations.config.att_cap; cap > 0.0f) {
+    // Compute cap * tanh(x / cap), i.e. LogitsSoftCap applied to the scalar x.
+    m = cap * std::tanh(m / cap);
+  }
   float d = 1.0f;
   // This is just a copy of the first token.
   MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker);
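
For context, the added guard applies logit soft-capping to the query-key dot product before it seeds the running softmax: the logit x is mapped to cap * tanh(x / cap), which leaves small values nearly unchanged and smoothly bounds large ones within (-cap, cap). Below is a minimal standalone sketch of that scalar transform; the helper name SoftCapLogit and the example cap value are illustrative only and not part of gemma.cpp.

#include <cmath>
#include <cstdio>

// Scalar soft cap: returns cap * tanh(x / cap), which is approximately x when
// |x| << cap and saturates smoothly toward +/-cap for large |x|.
// cap <= 0 means "no capping", mirroring the `cap > 0.0f` guard in the diff.
// NOTE: SoftCapLogit is a hypothetical helper used only for illustration.
float SoftCapLogit(float x, float cap) {
  return cap > 0.0f ? cap * std::tanh(x / cap) : x;
}

int main() {
  const float cap = 30.0f;  // illustrative value; the real cap comes from the model config
  const float logits[] = {0.5f, 10.0f, 100.0f};
  for (float x : logits) {
    std::printf("x = %6.1f -> capped = %8.4f\n", x, SoftCapLogit(x, cap));
  }
  return 0;
}

Bounding attention logits this way is a common numerical-stability measure (used, for example, by Gemma-style attention soft caps), so applying it consistently to every logit, including the first one handled here, keeps the softmax accumulation well conditioned.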