Reduce number of operations in Gelu() by one Mul.

About 5% faster Gen.Activation.

PiperOrigin-RevId: 684035719
This commit is contained in:
Daniel Keysers 2024-10-09 07:50:02 -07:00 committed by Copybara-Service
parent 2c28b18eb0
commit a570e3f662
1 changed files with 8 additions and 5 deletions

View File

@@ -65,16 +65,19 @@ StaticCast(From from) noexcept {
} }
} }
// We use the tanh approximation for gelu (also used in training).
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
// = 0.5 * x * (1 + tanh(x * (sqrt(2/π) + sqrt(2/π) * 0.044715 * x^2)))
// = 0.5 * x * (1 + tanh(x * (0.79788 + 0.035677 * x^2)))
// = x * (0.5 + 0.5 * tanh(x * (0.79788 + 0.035677 * x^2)))
template <class D, HWY_IF_F32_D(D)> template <class D, HWY_IF_F32_D(D)>
HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) { HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
const hn::Vec<D> kMul = hn::Set(d, 0.044715f); const hn::Vec<D> kMul = hn::Set(d, 0.03567740813636141f);
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f); const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
const hn::Vec<D> kHalf = hn::Set(d, 0.5f); const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
// tanh approximation matches training. const hn::Vec<D> v2 = hn::Mul(v, v);
const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v); const hn::Vec<D> arg = hn::Mul(v, hn::MulAdd(kMul, v2, kSqrt2OverPi));
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
// 0.5 * (1 + tanh) = MulAdd(0.5, tanh, 0.5).
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf); const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
return hn::Mul(v, cdf); return hn::Mul(v, cdf);
} }