mirror of https://github.com/google/gemma.cpp.git
Fixed bug with softcap in single flash attention
PiperOrigin-RevId: 813164938
parent 16536996d1
commit 4974f24832
@@ -126,6 +126,10 @@ void SingleFlashAttention(const size_t start_pos, const size_t last_pos,
   PROFILER_ZONE3(p, worker, zone);
   const size_t pos_mod = activations.div_seq_len.Remainder(start_pos);
   float m = Dot(q, k.Row(pos_mod), k.Cols());
+  if (float cap = activations.config.att_cap; cap > 0.0f) {
+    // Compute tanh(x / cap) * cap, being LogitsSoftCap on the scalar x.
+    m = cap * std::tanh(m / cap);
+  }
   float d = 1.0f;
   // This is just a copy of the first token.
   MulByConstTo(d, v.Row(pos_mod), att_out, v.Cols(), p, worker);
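The four added lines apply the usual logit soft-cap, cap * tanh(m / cap), to the single query-key dot product m before it is used as the attention logit for the first processed token. This bounds the logit to the open interval (-cap, cap) while leaving small values nearly unchanged, and a cap of 0 (or less) skips the transform entirely. Below is a minimal standalone sketch of the same scalar transform, not code from the commit; the helper name ScalarSoftCap and the example cap value are illustrative only, since gemma.cpp computes this inline as shown in the diff.

#include <cmath>
#include <cstdio>

// Hypothetical standalone helper mirroring the inline soft-cap in the diff:
// maps any logit x into the open interval (-cap, cap). For |x| much smaller
// than cap, tanh(x / cap) is approximately x / cap, so the result is close to x.
static float ScalarSoftCap(float x, float cap) {
  return (cap > 0.0f) ? cap * std::tanh(x / cap) : x;  // cap <= 0 disables capping
}

int main() {
  const float cap = 50.0f;  // illustrative value, not taken from the model config
  const float logits[] = {1.0f, 30.0f, 200.0f};
  for (float x : logits) {
    std::printf("softcap(%g) = %g\n", x, ScalarSoftCap(x, cap));
  }
  return 0;
}

With d initialized to 1.0f, the following MulByConstTo writes 1 * v.Row(pos_mod) into att_out, i.e. a verbatim copy of the first token's value row, matching the comment in the diff; this is the usual starting state for an online-softmax (flash-attention) accumulation, since a single processed key has weight exp(m - m) = 1.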