mirror of https://github.com/google/gemma.cpp.git
Add FastGelu activation function in a newly created created fast_ops-inl.h files.
This replaces the Tanh call with FastTanh call in the Gelu function written in math-inl.h. PiperOrigin-RevId: 876339830
This commit is contained in:
parent
bdba3bfa63
commit
dd268ddbe8
|
|
@ -387,6 +387,7 @@ cc_library(
|
|||
"ops/sum-inl.h",
|
||||
"ops/fp_arith-inl.h",
|
||||
"ops/ops-inl.h",
|
||||
"ops/fast_ops-inl.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Include guard for non-SIMD code.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "ops/ops.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "util/zones.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE) == \
|
||||
defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "compression/compress-inl.h"
|
||||
#include "hwy/contrib/math/fast_math-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// We use the tanh approximation for gelu (also used in training).
|
||||
// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
|
||||
// = 0.5 * x * (1 + tanh(x * (sqrt(2/π) + sqrt(2/π) * 0.044715 * x^2)))
|
||||
// = 0.5 * x * (1 + tanh(x * (0.79788 + 0.035677 * x^2)))
|
||||
// = x * (0.5 + 0.5 * tanh(x * (0.79788 + 0.035677 * x^2))))
|
||||
//
|
||||
// This uses hn::FastTanh from
|
||||
// third_party/highway/hwy/contrib/math/fast_math-inl.h
|
||||
template <class D, HWY_IF_F32_D(D)>
|
||||
HWY_INLINE hn::Vec<D> FastGelu(D d, hn::Vec<D> v) {
|
||||
const hn::Vec<D> kMul = hn::Set(d, 0.03567740813636141f);
|
||||
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
|
||||
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
|
||||
|
||||
const hn::Vec<D> v2 = hn::Mul(v, v);
|
||||
const hn::Vec<D> arg = hn::Mul(v, hn::MulAdd(kMul, v2, kSqrt2OverPi));
|
||||
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::FastTanh(d, arg), kHalf);
|
||||
return hn::Mul(v, cdf);
|
||||
}
|
||||
|
||||
// Activation already has a profiler zone.
|
||||
template <typename T>
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void FastGelu(T* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
using VF = hn::Vec<DF>;
|
||||
DecompressAndCompressInplace(
|
||||
DF(), x, size, [](DF d, VF v) HWY_ATTR -> VF { return FastGelu(d, v); });
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // NOLINT
|
||||
|
|
@ -48,6 +48,7 @@
|
|||
// After highway.h
|
||||
#include "compression/test_util-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "ops/fast_ops-inl.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
|
|
@ -466,6 +467,32 @@ static HWY_NOINLINE void TestAllGelu() {
|
|||
ForeachActivationType1<TestGelu>(hn::ScalableTag<float>());
|
||||
}
|
||||
|
||||
struct TestFastGelu {
|
||||
template <typename T, class D>
|
||||
void operator()(T, D) const {
|
||||
std::vector<T> values;
|
||||
for (int i = -150; i <= 150; ++i) {
|
||||
values.push_back(hwy::ConvertScalarTo<T>(.1f * i));
|
||||
}
|
||||
std::vector<T> result = values;
|
||||
gcpp::HWY_NAMESPACE::FastGelu(result.data(), result.size());
|
||||
|
||||
for (size_t i = 0; i < values.size(); i++) {
|
||||
const float max_error = IsBF16<T>() ? 0.02f : 0.002f;
|
||||
const float x = hwy::ConvertScalarTo<float>(values[i]);
|
||||
const float actual = hwy::ConvertScalarTo<float>(result[i]);
|
||||
const float expected =
|
||||
x * (0.5f + 0.5f * tanh(x * (0.79788f + 0.035677f * x * x)));
|
||||
EXPECT_NEAR(expected, actual, max_error)
|
||||
<< (IsBF16<T>() ? "bf16" : "float");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static HWY_NOINLINE void TestAllFastGelu() {
|
||||
ForeachActivationType1<TestFastGelu>(hn::ScalableTag<float>());
|
||||
}
|
||||
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
||||
const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
|
||||
const float* HWY_RESTRICT inv_timescale, const int pos) {
|
||||
|
|
@ -818,6 +845,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
|
|||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllFastGelu);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNormInplace);
|
||||
|
|
|
|||
Loading…
Reference in New Issue