mirror of https://github.com/google/gemma.cpp.git
Reduce number of operations in Gelu() by one Mul.
About 5% faster Gen.Activation. PiperOrigin-RevId: 684035719
This commit is contained in:
parent
2c28b18eb0
commit
a570e3f662
|
|
@ -65,16 +65,19 @@ StaticCast(From from) noexcept {
|
|||
}
|
||||
}
|
||||
|
||||
// We use the tanh approximation for gelu (also used in training).
|
||||
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
|
||||
// = 0.5 * x * (1 + tanh(x * (sqrt(2/π) + sqrt(2/π) * 0.044715 * x^2)))
|
||||
// = 0.5 * x * (1 + tanh(x * (0.79788 + 0.035677 * x^2)))
|
||||
// = x * (0.5 + 0.5 * tanh(x * (0.79788 + 0.035677 * x^2))))
|
||||
template <class D, HWY_IF_F32_D(D)>
|
||||
HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
||||
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
|
||||
const hn::Vec<D> kMul = hn::Set(d, 0.03567740813636141f);
|
||||
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
|
||||
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
|
||||
|
||||
// tanh approximation matches training.
|
||||
const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v);
|
||||
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
|
||||
// 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5).
|
||||
const hn::Vec<D> v2 = hn::Mul(v, v);
|
||||
const hn::Vec<D> arg = hn::Mul(v, hn::MulAdd(kMul, v2, kSqrt2OverPi));
|
||||
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
|
||||
return hn::Mul(v, cdf);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue