Change to FastExpMinusOrZero

PiperOrigin-RevId: 880761639
This commit is contained in:
Nikhil Dev Goyal 2026-03-09 03:25:28 -07:00 committed by Copybara-Service
parent 029cfd0b33
commit 968a5aa87f
1 changed files with 7 additions and 5 deletions

View File

@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos,
}
float m = hn::ReduceMax(df, x);
m = std::max(m, old_max);
x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
x = hn::FastExpMinusOrZero(df, hn::Sub(x, hn::Set(df, m)));
float scale = old_d * std::exp(old_max - m);
old_d = hn::ReduceSum(df, x) + scale;
old_max = m;
@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
float m = hn::ReduceMax(df, x_max);
m = std::max(m, old_max);
VF m_vec = hn::Set(df, m);
x0 = hn::Exp(df, hn::Sub(x0, m_vec));
x1 = hn::Exp(df, hn::Sub(x1, m_vec));
x0 = hn::FastExpMinusOrZero(df, hn::Sub(x0, m_vec));
x1 = hn::FastExpMinusOrZero(df, hn::Sub(x1, m_vec));
float scale = old_d * std::exp(old_max - m);
VF x_sum = hn::Add(x0, x1);
old_d = hn::ReduceSum(df, x_sum) + scale;
@ -672,7 +672,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
VF4 scale = hn::Mul(
old_d_vf, hn::FastExpMinusOrZero(df4, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
const VF zero = hn::Zero(df);
@ -810,7 +811,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
x_6_sum, x_7_sum,
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
}
VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max)));
VF8 scale = hn::Mul(
old_d_vf, hn::FastExpMinusOrZero(df8, hn::Sub(old_max_vf, new_max)));
old_d_vf = hn::Add(scale, x_sum);
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f));
const VF zero = hn::Zero(df);