From 001c356b02a2f6bfdbbed6df1cae77112c216592 Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Tue, 16 Dec 2025 09:31:00 -0800 Subject: [PATCH] Use hn::Sub for vector subtraction in flash attention. PiperOrigin-RevId: 845310753 --- gemma/flash_attention.cc | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index 671efb4..815d76f 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -335,15 +335,15 @@ void TileFlashAttention( } VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); m = hn::Max(old_m, m); - x0 = hn::Exp(df, x0 - m); - x1 = hn::Exp(df, x1 - m); - x2 = hn::Exp(df, x2 - m); - x3 = hn::Exp(df, x3 - m); - x4 = hn::Exp(df, x4 - m); - x5 = hn::Exp(df, x5 - m); - x6 = hn::Exp(df, x6 - m); - x7 = hn::Exp(df, x7 - m); - VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m)); + x0 = hn::Exp(df, hn::Sub(x0, m)); + x1 = hn::Exp(df, hn::Sub(x1, m)); + x2 = hn::Exp(df, hn::Sub(x2, m)); + x3 = hn::Exp(df, hn::Sub(x3, m)); + x4 = hn::Exp(df, hn::Sub(x4, m)); + x5 = hn::Exp(df, hn::Sub(x5, m)); + x6 = hn::Exp(df, hn::Sub(x6, m)); + x7 = hn::Exp(df, hn::Sub(x7, m)); + VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m))); old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7); old_d = hn::Add(scale, old_d); old_m = m; @@ -376,8 +376,8 @@ void TileFlashAttention( std::numeric_limits::max() / 2.0f); x0 = hn::Sub(x0, causal_offset); VF m = hn::Max(old_m, x0); - x0 = hn::Exp(df, x0 - m); - VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m)); + x0 = hn::Exp(df, hn::Sub(x0, m)); + VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m))); old_m = m; old_d = hn::Add(scale, x0); VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d); @@ -425,7 +425,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max, float& old_d) { float m = hn::ReduceMax(df, x); m = std::max(m, old_max); - x = hn::Exp(df, x - hn::Set(df, m)); + x = hn::Exp(df, hn::Sub(x, hn::Set(df, m))); float scale = old_d * std::exp(old_max - m); old_d = hn::ReduceSum(df, x) + scale; old_max = m;