This commit is contained in:
copybara-service[bot] 2025-12-16 17:32:36 +00:00 committed by GitHub
commit 72ff4b5b82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 12 deletions

View File

@ -335,15 +335,15 @@ void TileFlashAttention(
} }
VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
m = hn::Max(old_m, m); m = hn::Max(old_m, m);
x0 = hn::Exp(df, x0 - m); x0 = hn::Exp(df, hn::Sub(x0, m));
x1 = hn::Exp(df, x1 - m); x1 = hn::Exp(df, hn::Sub(x1, m));
x2 = hn::Exp(df, x2 - m); x2 = hn::Exp(df, hn::Sub(x2, m));
x3 = hn::Exp(df, x3 - m); x3 = hn::Exp(df, hn::Sub(x3, m));
x4 = hn::Exp(df, x4 - m); x4 = hn::Exp(df, hn::Sub(x4, m));
x5 = hn::Exp(df, x5 - m); x5 = hn::Exp(df, hn::Sub(x5, m));
x6 = hn::Exp(df, x6 - m); x6 = hn::Exp(df, hn::Sub(x6, m));
x7 = hn::Exp(df, x7 - m); x7 = hn::Exp(df, hn::Sub(x7, m));
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m)); VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m)));
old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
old_d = hn::Add(scale, old_d); old_d = hn::Add(scale, old_d);
old_m = m; old_m = m;
@ -376,8 +376,8 @@ void TileFlashAttention(
std::numeric_limits<float>::max() / 2.0f); std::numeric_limits<float>::max() / 2.0f);
x0 = hn::Sub(x0, causal_offset); x0 = hn::Sub(x0, causal_offset);
VF m = hn::Max(old_m, x0); VF m = hn::Max(old_m, x0);
x0 = hn::Exp(df, x0 - m); x0 = hn::Exp(df, hn::Sub(x0, m));
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m)); VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m)));
old_m = m; old_m = m;
old_d = hn::Add(scale, x0); old_d = hn::Add(scale, x0);
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d); VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
@ -425,7 +425,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
float& old_d) { float& old_d) {
float m = hn::ReduceMax(df, x); float m = hn::ReduceMax(df, x);
m = std::max(m, old_max); m = std::max(m, old_max);
x = hn::Exp(df, x - hn::Set(df, m)); x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
float scale = old_d * std::exp(old_max - m); float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale; old_d = hn::ReduceSum(df, x) + scale;
old_max = m; old_max = m;