mirror of https://github.com/google/gemma.cpp.git
Change to FastExpMinusOrZero
PiperOrigin-RevId: 880761639
This commit is contained in:
parent
029cfd0b33
commit
968a5aa87f
|
|
@ -489,7 +489,7 @@ HWY_INLINE float SingleFlashAttentionRowVector(DF df, size_t start_pos,
|
|||
}
|
||||
float m = hn::ReduceMax(df, x);
|
||||
m = std::max(m, old_max);
|
||||
x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
|
||||
x = hn::FastExpMinusOrZero(df, hn::Sub(x, hn::Set(df, m)));
|
||||
float scale = old_d * std::exp(old_max - m);
|
||||
old_d = hn::ReduceSum(df, x) + scale;
|
||||
old_max = m;
|
||||
|
|
@ -538,8 +538,8 @@ HWY_INLINE float DoubleFlashAttentionRowVector(DF df, size_t start_pos,
|
|||
float m = hn::ReduceMax(df, x_max);
|
||||
m = std::max(m, old_max);
|
||||
VF m_vec = hn::Set(df, m);
|
||||
x0 = hn::Exp(df, hn::Sub(x0, m_vec));
|
||||
x1 = hn::Exp(df, hn::Sub(x1, m_vec));
|
||||
x0 = hn::FastExpMinusOrZero(df, hn::Sub(x0, m_vec));
|
||||
x1 = hn::FastExpMinusOrZero(df, hn::Sub(x1, m_vec));
|
||||
float scale = old_d * std::exp(old_max - m);
|
||||
VF x_sum = hn::Add(x0, x1);
|
||||
old_d = hn::ReduceSum(df, x_sum) + scale;
|
||||
|
|
@ -672,7 +672,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
|
|||
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
|
||||
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
|
||||
}
|
||||
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
|
||||
VF4 scale = hn::Mul(
|
||||
old_d_vf, hn::FastExpMinusOrZero(df4, hn::Sub(old_max_vf, new_max)));
|
||||
old_d_vf = hn::Add(scale, x_sum);
|
||||
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
|
||||
const VF zero = hn::Zero(df);
|
||||
|
|
@ -810,7 +811,8 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
|
|||
x_6_sum, x_7_sum,
|
||||
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
|
||||
}
|
||||
VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max)));
|
||||
VF8 scale = hn::Mul(
|
||||
old_d_vf, hn::FastExpMinusOrZero(df8, hn::Sub(old_max_vf, new_max)));
|
||||
old_d_vf = hn::Add(scale, x_sum);
|
||||
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f));
|
||||
const VF zero = hn::Zero(df);
|
||||
|
|
|
|||
Loading…
Reference in New Issue