mirror of https://github.com/google/gemma.cpp.git
Merge 001c356b02 into baa69dfb78
This commit is contained in:
commit
72ff4b5b82
|
|
@ -335,15 +335,15 @@ void TileFlashAttention(
|
||||||
}
|
}
|
||||||
VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
|
VF m = ElementwiseMaxOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
|
||||||
m = hn::Max(old_m, m);
|
m = hn::Max(old_m, m);
|
||||||
x0 = hn::Exp(df, x0 - m);
|
x0 = hn::Exp(df, hn::Sub(x0, m));
|
||||||
x1 = hn::Exp(df, x1 - m);
|
x1 = hn::Exp(df, hn::Sub(x1, m));
|
||||||
x2 = hn::Exp(df, x2 - m);
|
x2 = hn::Exp(df, hn::Sub(x2, m));
|
||||||
x3 = hn::Exp(df, x3 - m);
|
x3 = hn::Exp(df, hn::Sub(x3, m));
|
||||||
x4 = hn::Exp(df, x4 - m);
|
x4 = hn::Exp(df, hn::Sub(x4, m));
|
||||||
x5 = hn::Exp(df, x5 - m);
|
x5 = hn::Exp(df, hn::Sub(x5, m));
|
||||||
x6 = hn::Exp(df, x6 - m);
|
x6 = hn::Exp(df, hn::Sub(x6, m));
|
||||||
x7 = hn::Exp(df, x7 - m);
|
x7 = hn::Exp(df, hn::Sub(x7, m));
|
||||||
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
|
VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m)));
|
||||||
old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
|
old_d = ElementwiseSumOf8(df, x0, x1, x2, x3, x4, x5, x6, x7);
|
||||||
old_d = hn::Add(scale, old_d);
|
old_d = hn::Add(scale, old_d);
|
||||||
old_m = m;
|
old_m = m;
|
||||||
|
|
@ -376,8 +376,8 @@ void TileFlashAttention(
|
||||||
std::numeric_limits<float>::max() / 2.0f);
|
std::numeric_limits<float>::max() / 2.0f);
|
||||||
x0 = hn::Sub(x0, causal_offset);
|
x0 = hn::Sub(x0, causal_offset);
|
||||||
VF m = hn::Max(old_m, x0);
|
VF m = hn::Max(old_m, x0);
|
||||||
x0 = hn::Exp(df, x0 - m);
|
x0 = hn::Exp(df, hn::Sub(x0, m));
|
||||||
VF scale = hn::Mul(old_d, hn::Exp(df, old_m - m));
|
VF scale = hn::Mul(old_d, hn::Exp(df, hn::Sub(old_m, m)));
|
||||||
old_m = m;
|
old_m = m;
|
||||||
old_d = hn::Add(scale, x0);
|
old_d = hn::Add(scale, x0);
|
||||||
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
|
VF one_over_d = hn::Div(hn::Set(df, 1.0f), old_d);
|
||||||
|
|
@ -425,7 +425,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
|
||||||
float& old_d) {
|
float& old_d) {
|
||||||
float m = hn::ReduceMax(df, x);
|
float m = hn::ReduceMax(df, x);
|
||||||
m = std::max(m, old_max);
|
m = std::max(m, old_max);
|
||||||
x = hn::Exp(df, x - hn::Set(df, m));
|
x = hn::Exp(df, hn::Sub(x, hn::Set(df, m)));
|
||||||
float scale = old_d * std::exp(old_max - m);
|
float scale = old_d * std::exp(old_max - m);
|
||||||
old_d = hn::ReduceSum(df, x) + scale;
|
old_d = hn::ReduceSum(df, x) + scale;
|
||||||
old_max = m;
|
old_max = m;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue