// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Include guard for non-SIMD code. #ifndef THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_INL_H_ #define THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_INL_H_ #include #include "ops/ops.h" #include "util/threading_context.h" #include "util/zones.h" #include "hwy/base.h" #endif // THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_INL_H_ // Include guard for (potentially) SIMD code. #if defined(THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE) == \ defined(HWY_TARGET_TOGGLE) #ifdef THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE #undef THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE #else #define THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE #endif #include "compression/compress-inl.h" #include "hwy/contrib/math/fast_math-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; // We use the tanh approximation for gelu (also used in training). // gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3))) // = 0.5 * x * (1 + tanh(x * (sqrt(2/π) + sqrt(2/π) * 0.044715 * x^2))) // = 0.5 * x * (1 + tanh(x * (0.79788 + 0.035677 * x^2))) // = x * (0.5 + 0.5 * tanh(x * (0.79788 + 0.035677 * x^2)))) // // This uses hn::FastTanh from // third_party/highway/hwy/contrib/math/fast_math-inl.h template HWY_INLINE hn::Vec FastGelu(D d, hn::Vec v) { const hn::Vec kMul = hn::Set(d, 0.03567740813636141f); const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); const hn::Vec kHalf = hn::Set(d, 0.5f); const hn::Vec v2 = hn::Mul(v, v); const hn::Vec arg = hn::Mul(v, hn::MulAdd(kMul, v2, kSqrt2OverPi)); const hn::Vec cdf = hn::MulAdd(kHalf, hn::FastTanh(d, arg), kHalf); return hn::Mul(v, cdf); } // Activation already has a profiler zone. template static HWY_NOINLINE HWY_MAYBE_UNUSED void FastGelu(T* HWY_RESTRICT x, size_t size) { namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; DecompressAndCompressInplace( DF(), x, size, [](DF d, VF v) HWY_ATTR -> VF { return FastGelu(d, v); }); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); #endif // NOLINT