Use hn::Sub for vector subtraction in flash attention.

PiperOrigin-RevId: 845310753
This commit is contained in:
Phil Culliton 2025-12-16 09:31:00 -08:00 committed by Copybara-Service
parent baa69dfb78
commit 001c356b02
1 changed file with 12 additions and 12 deletions

View File

@ -335,15 +335,15 @@ void TileFlashAttention(
}
VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
m = hn::Max(old_m, m);
x0 = hn::Exp(df, x0 - m);
x1 = hn::Exp(df, x1 - m);
x2 = hn::Exp(df, x2 - m);
x3 = hn::Exp(df, x3 - m);
x4 = hn::Exp(df, x4 - m);
x5 = hn::Exp(df, x5 - m);
x6 = hn::Exp(df, x6 - m);
x7 = hn::Exp(df, x7 - m);
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
x0 = hn::Exp(df, hn::Sub(x0, m));
x1 = hn::Exp(df, hn::Sub(x1, m));
x2 = hn::Exp(df, hn::Sub(x2, m));
x3 = hn::Exp(df, hn::Sub(x3, m));
x4 = hn::Exp(df, hn::Sub(x4, m));
x5 = hn::Exp(df, hn::Sub(x5, m));
x6 = hn::Exp(df, hn::Sub(x6, m));
x7 = hn::Exp(df, hn::Sub(x7, m));
VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m)));
old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
old_d = hn::Add(scale, old_d);
old_m = m;
@ -376,8 +376,8 @@ void TileFlashAttention(
std::numeric_limits<float>::max() / 2.0f);
x0 = hn::Sub(x0, causal_offset);
VF m = hn::Max(old_m, x0);
x0 = hn::Exp(df, x0 - m);
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
x0 = hn::Exp(df, hn::Sub(x0, m));
VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m)));
old_m = m;
old_d = hn::Add(scale, x0);
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
@ -425,7 +425,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
float& old_d) {
float m = hn::ReduceMax(df, x);
m = std::max(m, old_max);
x = hn::Exp(df, x - hn::Set(df, m));
x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale;
old_max = m;