mirror of https://github.com/google/gemma.cpp.git
parent
539d9bb8e7
commit
6721dddf38
|
|
@ -63,6 +63,162 @@ HWY_INLINE hn::Vec<D> FastGelu(D d, hn::Vec<D> v) {
|
|||
return hn::Mul(v, cdf);
|
||||
}
|
||||
|
||||
// Fast approximation of sigmoid(x) = 1 / (1 + exp(-x))
|
||||
// Derived from FastTanh by substituting x/2.
|
||||
template <class D, HWY_IF_F32_D(D)>
|
||||
HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
|
||||
using T = hn::TFromD<D>;
|
||||
|
||||
// Abs(val) and preserve sign for later for symmetric rational approximation
|
||||
auto y = hn::Abs(val);
|
||||
|
||||
constexpr size_t kLanes = HWY_MAX_LANES_D(D);
|
||||
hn::Vec<D> a, c, d_coef;
|
||||
|
||||
if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) ||
|
||||
(HWY_HAVE_SCALABLE && sizeof(T) == 4)) {
|
||||
// Coefficients for P(y/2) ~ index using CF algo
|
||||
const auto k0 = hn::Set(d, static_cast<T>(-0.1145426548151546));
|
||||
const auto k1 = hn::Set(d, static_cast<T>(3.4556654973457404));
|
||||
const auto k2 = hn::Set(d, static_cast<T>(-0.6278480784875462));
|
||||
const auto k3 = hn::Set(d, static_cast<T>(0.04331384030062471));
|
||||
|
||||
// Index calculation: idx = P(y/2)
|
||||
// Estrin's scheme
|
||||
// k0 + y * k1 + y^2 * (k2 + y * k3)
|
||||
const auto y2 = hn::Mul(y, y);
|
||||
const auto p01 = hn::MulAdd(k1, y, k0);
|
||||
const auto p23 = hn::MulAdd(k3, y, k2);
|
||||
auto idx_poly = hn::MulAdd(y2, p23, p01);
|
||||
|
||||
// Convert to integer index
|
||||
using DI = hn::RebindToSigned<D>;
|
||||
auto idx_i = hn::ConvertTo(DI(), idx_poly);
|
||||
|
||||
// Clamp index to 7
|
||||
idx_i = hn::Min(idx_i, hn::Set(DI(), 7));
|
||||
|
||||
HWY_ALIGN static constexpr T arr_a[] = {
|
||||
static_cast<T>(-1435.326650329326),
|
||||
static_cast<T>(-96.9456723845743),
|
||||
static_cast<T>(-18.628915468855695),
|
||||
static_cast<T>(-5.90191111348809),
|
||||
static_cast<T>(-2.356433838423728),
|
||||
static_cast<T>(-1.0464246812594584),
|
||||
static_cast<T>(-0.4801959711368016),
|
||||
static_cast<T>(-0.2132727031175401)};
|
||||
HWY_ALIGN static constexpr T arr_c[] = {static_cast<T>(-316.5640994591445),
|
||||
static_cast<T>(-49.14374182730444),
|
||||
static_cast<T>(-15.69264419046708),
|
||||
static_cast<T>(-6.949871926785674),
|
||||
static_cast<T>(-3.513259738716989),
|
||||
static_cast<T>(-1.839177585570145),
|
||||
static_cast<T>(-0.9298342163526662),
|
||||
static_cast<T>(-0.426230503963466)};
|
||||
HWY_ALIGN static constexpr T arr_d[] = {
|
||||
static_cast<T>(-5676.517069241468), static_cast<T>(-363.0662559912978),
|
||||
static_cast<T>(-60.61589604370584), static_cast<T>(-14.306713103378062),
|
||||
static_cast<T>(-2.725237489187118), static_cast<T>(0.7890752292798894),
|
||||
static_cast<T>(1.8089988725725492), static_cast<T>(1.9956027601801545)};
|
||||
|
||||
if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
|
||||
auto idx = hn::IndicesFromVec(d, idx_i);
|
||||
a = hn::TableLookupLanes(hn::Load(d, arr_a), idx);
|
||||
c = hn::TableLookupLanes(hn::Load(d, arr_c), idx);
|
||||
d_coef = hn::TableLookupLanes(hn::Load(d, arr_d), idx);
|
||||
} else {
|
||||
auto idx = hn::IndicesFromVec(d, idx_i);
|
||||
hn::FixedTag<T, 4> d4;
|
||||
a = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_a),
|
||||
hn::Load(d4, arr_a + 4), idx);
|
||||
c = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_c),
|
||||
hn::Load(d4, arr_c + 4), idx);
|
||||
d_coef = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_d),
|
||||
hn::Load(d4, arr_d + 4), idx);
|
||||
}
|
||||
} else {
|
||||
// --- FALLBACK PATH: Blend Chain ---
|
||||
// Thresholds for intervals
|
||||
const auto t0 = hn::Set(d, static_cast<T>(0.3434497447432422));
|
||||
const auto t1 = hn::Set(d, static_cast<T>(0.6955976007186494));
|
||||
const auto t2 = hn::Set(d, static_cast<T>(1.1068914127668934));
|
||||
const auto t3 = hn::Set(d, static_cast<T>(1.608648163822941));
|
||||
const auto t4 = hn::Set(d, static_cast<T>(2.269039121646492));
|
||||
const auto t5 = hn::Set(d, static_cast<T>(3.288402547357102));
|
||||
const auto t6 = hn::Set(d, static_cast<T>(5.271780018997146));
|
||||
|
||||
// Start with highest index (7)
|
||||
a = hn::Set(d, static_cast<T>(-0.2132727031175401));
|
||||
c = hn::Set(d, static_cast<T>(-0.426230503963466));
|
||||
d_coef = hn::Set(d, static_cast<T>(1.9956027601801545));
|
||||
|
||||
// If y < t6 (idx 6)
|
||||
auto mask = hn::Lt(y, t6);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.4801959711368016)),
|
||||
a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.9298342163526662)),
|
||||
c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(1.8089988725725492)), d_coef);
|
||||
|
||||
// If y < t5 (idx 5)
|
||||
mask = hn::Lt(y, t5);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-1.0464246812594584)),
|
||||
a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-1.839177585570145)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(0.7890752292798894)), d_coef);
|
||||
|
||||
// If y < t4 (idx 4)
|
||||
mask = hn::Lt(y, t4);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-2.356433838423728)), a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-3.513259738716989)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-2.725237489187118)), d_coef);
|
||||
|
||||
// If y < t3 (idx 3)
|
||||
mask = hn::Lt(y, t3);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-5.90191111348809)), a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-6.949871926785674)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-14.306713103378062)), d_coef);
|
||||
|
||||
// If y < t2 (idx 2)
|
||||
mask = hn::Lt(y, t2);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-18.628915468855695)),
|
||||
a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-15.69264419046708)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-60.61589604370584)), d_coef);
|
||||
|
||||
// If y < t1 (idx 1)
|
||||
mask = hn::Lt(y, t1);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-96.9456723845743)), a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-49.14374182730444)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-363.0662559912978)), d_coef);
|
||||
|
||||
// If y < t0 (idx 0)
|
||||
mask = hn::Lt(y, t0);
|
||||
a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-1435.326650329326)), a);
|
||||
c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-316.5640994591445)), c);
|
||||
d_coef = hn::IfThenElse(
|
||||
mask, hn::Set(d, static_cast<T>(-5676.517069241468)), d_coef);
|
||||
}
|
||||
|
||||
// Math: 0.5 * tanh(y/2) = (ay + 1.0)/(cy + d_coef)
|
||||
auto num = hn::MulAdd(a, y, hn::Set(d, static_cast<T>(1.0)));
|
||||
auto den = hn::MulAdd(c, y, d_coef);
|
||||
|
||||
auto approx = hn::Div(num, den);
|
||||
|
||||
const auto half = hn::Set(d, static_cast<T>(0.5));
|
||||
// Clamp the approx value to 0.5
|
||||
approx = hn::Min(approx, half);
|
||||
// sigmoid(x) = 0.5 + sign(x) * (0.5 * tanh(|x|/2))
|
||||
return hn::Add(half, hn::CopySign(approx, val));
|
||||
}
|
||||
|
||||
// Activation already has a profiler zone.
|
||||
template <typename T>
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void FastGelu(T* HWY_RESTRICT x,
|
||||
|
|
@ -74,6 +230,17 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void FastGelu(T* HWY_RESTRICT x,
|
|||
DF(), x, size, [](DF d, VF v) HWY_ATTR -> VF { return FastGelu(d, v); });
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void FastSigmoid(T* HWY_RESTRICT x,
|
||||
size_t size) {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
using VF = hn::Vec<DF>;
|
||||
DecompressAndCompressInplace(DF(), x, size, [](DF d, VF v) HWY_ATTR -> VF {
|
||||
return FastSigmoid(d, v);
|
||||
});
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -441,6 +441,31 @@ static HWY_NOINLINE void TestAllSigmoid() {
|
|||
ForeachActivationType1<TestSigmoid>(hn::ScalableTag<float>());
|
||||
}
|
||||
|
||||
struct TestFastSigmoid {
|
||||
template <typename T, class D>
|
||||
void operator()(T, D) const {
|
||||
std::vector<T> values;
|
||||
for (int i = -150; i <= 150; ++i) {
|
||||
values.push_back(hwy::ConvertScalarTo<T>(.1f * i));
|
||||
}
|
||||
std::vector<T> result = values;
|
||||
gcpp::HWY_NAMESPACE::FastSigmoid(result.data(), result.size());
|
||||
|
||||
for (size_t i = 0; i < values.size(); i++) {
|
||||
const float max_error = IsBF16<T>() ? 0.003f : 0.0004f;
|
||||
const float value = hwy::ConvertScalarTo<float>(values[i]);
|
||||
const float actual = hwy::ConvertScalarTo<float>(result[i]);
|
||||
const float expected = (1 / (1 + std::exp(-value)));
|
||||
EXPECT_NEAR(expected, actual, max_error)
|
||||
<< (IsBF16<T>() ? "bf16" : "float");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static HWY_NOINLINE void TestAllFastSigmoid() {
|
||||
ForeachActivationType1<TestFastSigmoid>(hn::ScalableTag<float>());
|
||||
}
|
||||
|
||||
struct TestGelu {
|
||||
template <typename T, class D>
|
||||
void operator()(T, D) const {
|
||||
|
|
@ -844,6 +869,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
|||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllFastSigmoid);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllFastGelu);
|
||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
|
||||
|
|
|
|||
Loading…
Reference in New Issue