From 6721dddf3862d46d76945c394572b5eff7bf1489 Mon Sep 17 00:00:00 2001
From: Nikhil Dev Goyal
Date: Wed, 4 Mar 2026 06:12:14 -0800
Subject: [PATCH] Implement FastSigmoid.

PiperOrigin-RevId: 878453196
---
 ops/fast_ops-inl.h | 167 +++++++++++++++++++++++++++++++++++++++++++++
 ops/ops_test.cc    |  26 +++++++
 2 files changed, 193 insertions(+)

diff --git a/ops/fast_ops-inl.h b/ops/fast_ops-inl.h
index 1120d9a..bf1475d 100644
--- a/ops/fast_ops-inl.h
+++ b/ops/fast_ops-inl.h
@@ -63,6 +63,162 @@ HWY_INLINE hn::Vec<D> FastGelu(D d, hn::Vec<D> v) {
   return hn::Mul(v, cdf);
 }
 
+// Fast approximation of sigmoid(x) = 1 / (1 + exp(-x))
+// Derived from FastTanh by substituting x/2.
+template <class D>
+HWY_INLINE hn::Vec<D> FastSigmoid(D d, hn::Vec<D> val) {
+  using T = hn::TFromD<D>;
+
+  // Abs(val) and preserve sign for later for symmetric rational approximation
+  auto y = hn::Abs(val);
+
+  constexpr size_t kLanes = HWY_MAX_LANES_D(D);
+  hn::Vec<D> a, c, d_coef;
+
+  if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) ||
+                (HWY_HAVE_SCALABLE && sizeof(T) == 4)) {
+    // Coefficients for P(y/2) ~ index using CF algo
+    const auto k0 = hn::Set(d, static_cast<T>(-0.1145426548151546));
+    const auto k1 = hn::Set(d, static_cast<T>(3.4556654973457404));
+    const auto k2 = hn::Set(d, static_cast<T>(-0.6278480784875462));
+    const auto k3 = hn::Set(d, static_cast<T>(0.04331384030062471));
+
+    // Index calculation: idx = P(y/2)
+    // Estrin's scheme
+    // k0 + y * k1 + y^2 * (k2 + y * k3)
+    const auto y2 = hn::Mul(y, y);
+    const auto p01 = hn::MulAdd(k1, y, k0);
+    const auto p23 = hn::MulAdd(k3, y, k2);
+    auto idx_poly = hn::MulAdd(y2, p23, p01);
+
+    // Convert to integer index
+    using DI = hn::RebindToSigned<D>;
+    auto idx_i = hn::ConvertTo(DI(), idx_poly);
+
+    // Clamp index to 7
+    idx_i = hn::Min(idx_i, hn::Set(DI(), 7));
+
+    HWY_ALIGN static constexpr T arr_a[] = {
+        static_cast<T>(-1435.326650329326),
+        static_cast<T>(-96.9456723845743),
+        static_cast<T>(-18.628915468855695),
+        static_cast<T>(-5.90191111348809),
+        static_cast<T>(-2.356433838423728),
+        static_cast<T>(-1.0464246812594584),
+        static_cast<T>(-0.4801959711368016),
+        static_cast<T>(-0.2132727031175401)};
+    HWY_ALIGN static constexpr T arr_c[] = {static_cast<T>(-316.5640994591445),
+                                            static_cast<T>(-49.14374182730444),
+                                            static_cast<T>(-15.69264419046708),
+                                            static_cast<T>(-6.949871926785674),
+                                            static_cast<T>(-3.513259738716989),
+                                            static_cast<T>(-1.839177585570145),
+                                            static_cast<T>(-0.9298342163526662),
+                                            static_cast<T>(-0.426230503963466)};
+    HWY_ALIGN static constexpr T arr_d[] = {
+        static_cast<T>(-5676.517069241468), static_cast<T>(-363.0662559912978),
+        static_cast<T>(-60.61589604370584), static_cast<T>(-14.306713103378062),
+        static_cast<T>(-2.725237489187118), static_cast<T>(0.7890752292798894),
+        static_cast<T>(1.8089988725725492), static_cast<T>(1.9956027601801545)};
+
+    if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) {
+      auto idx = hn::IndicesFromVec(d, idx_i);
+      a = hn::TableLookupLanes(hn::Load(d, arr_a), idx);
+      c = hn::TableLookupLanes(hn::Load(d, arr_c), idx);
+      d_coef = hn::TableLookupLanes(hn::Load(d, arr_d), idx);
+    } else {
+      auto idx = hn::IndicesFromVec(d, idx_i);
+      hn::FixedTag<T, 4> d4;
+      a = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_a),
+                                   hn::Load(d4, arr_a + 4), idx);
+      c = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_c),
+                                   hn::Load(d4, arr_c + 4), idx);
+      d_coef = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_d),
+                                        hn::Load(d4, arr_d + 4), idx);
+    }
+  } else {
+    // --- FALLBACK PATH: Blend Chain ---
+    // Thresholds for intervals
+    const auto t0 = hn::Set(d, static_cast<T>(0.3434497447432422));
+    const auto t1 = hn::Set(d, static_cast<T>(0.6955976007186494));
+    const auto t2 = hn::Set(d, static_cast<T>(1.1068914127668934));
+    const auto t3 = hn::Set(d, static_cast<T>(1.608648163822941));
+    const auto t4 = hn::Set(d, static_cast<T>(2.269039121646492));
+    const auto t5 = hn::Set(d, static_cast<T>(3.288402547357102));
+    const auto t6 = hn::Set(d, static_cast<T>(5.271780018997146));
+
+    // Start with highest index (7)
+    a = hn::Set(d, static_cast<T>(-0.2132727031175401));
+    c = hn::Set(d, static_cast<T>(-0.426230503963466));
+    d_coef = hn::Set(d, static_cast<T>(1.9956027601801545));
+
+    // If y < t6 (idx 6)
+    auto mask = hn::Lt(y, t6);
+    a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.4801959711368016)),
+                       a);
+    c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-0.9298342163526662)),
+                       c);
+    d_coef = hn::IfThenElse(
+        mask, hn::Set(d, static_cast<T>(1.8089988725725492)), d_coef);
+
+    // If y < t5 (idx 5)
+    mask = hn::Lt(y, t5);
+    a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-1.0464246812594584)),
+                       a);
+    c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-1.839177585570145)), c);
+    d_coef = hn::IfThenElse(
+        mask, hn::Set(d, static_cast<T>(0.7890752292798894)), d_coef);
+
+    // If y < t4 (idx 4)
+    mask = hn::Lt(y, t4);
+    a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-2.356433838423728)), a);
+    c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-3.513259738716989)), c);
+    d_coef = hn::IfThenElse(
+        mask, hn::Set(d, static_cast<T>(-2.725237489187118)), d_coef);
+
+    // If y < t3 (idx 3)
+    mask = hn::Lt(y, t3);
+    a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-5.90191111348809)), a);
+    c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-6.949871926785674)), c);
+    d_coef = hn::IfThenElse(
+        mask, hn::Set(d, static_cast<T>(-14.306713103378062)), d_coef);
+
+    // If y < t2 (idx 2)
+    mask = hn::Lt(y, t2);
+    a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-18.628915468855695)),
+                       a);
+    c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-15.69264419046708)), c);
+    d_coef = hn::IfThenElse(
+        mask, hn::Set(d, static_cast<T>(-60.61589604370584)), d_coef);
+
+    // If y < t1 (idx 1)
+    mask = hn::Lt(y, t1);
+    a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-96.9456723845743)), a);
+    c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-49.14374182730444)), c);
+    d_coef = hn::IfThenElse(
+        mask, hn::Set(d, static_cast<T>(-363.0662559912978)), d_coef);
+
+    // If y < t0 (idx 0)
+    mask = hn::Lt(y, t0);
+    a = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-1435.326650329326)), a);
+    c = hn::IfThenElse(mask, hn::Set(d, static_cast<T>(-316.5640994591445)), c);
+    d_coef = hn::IfThenElse(
+        mask, hn::Set(d, static_cast<T>(-5676.517069241468)), d_coef);
+  }
+
+  // Math: 0.5 * tanh(y/2) = (ay + 1.0)/(cy + d_coef)
+  auto num = hn::MulAdd(a, y, hn::Set(d, static_cast<T>(1.0)));
+  auto den = hn::MulAdd(c, y, d_coef);
+
+  auto approx = hn::Div(num, den);
+
+  const auto half = hn::Set(d, static_cast<T>(0.5));
+  // Clamp the approx value to 0.5
+  approx = hn::Min(approx, half);
+  // sigmoid(x) = 0.5 + sign(x) * (0.5 * tanh(|x|/2))
+  return hn::Add(half, hn::CopySign(approx, val));
+}
+
 // Activation already has a profiler zone.
 template <typename T>
 static HWY_NOINLINE HWY_MAYBE_UNUSED void FastGelu(T* HWY_RESTRICT x,
@@ -74,6 +230,17 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void FastGelu(T* HWY_RESTRICT x,
       DF(), x, size, [](DF d, VF v) HWY_ATTR -> VF { return FastGelu(d, v); });
 }
 
+template <typename T>
+static HWY_NOINLINE HWY_MAYBE_UNUSED void FastSigmoid(T* HWY_RESTRICT x,
+                                                      size_t size) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  using DF = hn::ScalableTag<float>;
+  using VF = hn::Vec<DF>;
+  DecompressAndCompressInplace(DF(), x, size, [](DF d, VF v) HWY_ATTR -> VF {
+    return FastSigmoid(d, v);
+  });
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
diff --git a/ops/ops_test.cc b/ops/ops_test.cc
index 3abb5b8..8d17726 100644
--- a/ops/ops_test.cc
+++ b/ops/ops_test.cc
@@ -441,6 +441,31 @@ static HWY_NOINLINE void TestAllSigmoid() {
   ForeachActivationType1<TestSigmoid>(hn::ScalableTag<float>());
 }
 
+struct TestFastSigmoid {
+  template <typename T, class D>
+  void operator()(T, D) const {
+    std::vector<T> values;
+    for (int i = -150; i <= 150; ++i) {
+      values.push_back(hwy::ConvertScalarTo<T>(.1f * i));
+    }
+    std::vector<T> result = values;
+    gcpp::HWY_NAMESPACE::FastSigmoid(result.data(), result.size());
+
+    for (size_t i = 0; i < values.size(); i++) {
+      const float max_error = IsBF16<T>() ? 0.003f : 0.0004f;
+      const float value = hwy::ConvertScalarTo<float>(values[i]);
+      const float actual = hwy::ConvertScalarTo<float>(result[i]);
+      const float expected = (1 / (1 + std::exp(-value)));
+      EXPECT_NEAR(expected, actual, max_error)
+          << (IsBF16<T>() ? "bf16" : "float");
+    }
+  }
+};
+
+static HWY_NOINLINE void TestAllFastSigmoid() {
+  ForeachActivationType1<TestFastSigmoid>(hn::ScalableTag<float>());
+}
+
 struct TestGelu {
   template <typename T, class D>
   void operator()(T, D) const {
@@ -844,6 +869,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllFastSigmoid);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllFastGelu);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);