From 50144738f19cd6cd2c4a340901c5ca3f3a62c971 Mon Sep 17 00:00:00 2001 From: Nikhil Dev Goyal Date: Thu, 19 Mar 2026 07:36:28 -0700 Subject: [PATCH] Change calculation from (ax+b)/(cx+d) to (x + b')/(c'x + d'). This replaces a MulAdd with an Add, reducing port contention on modern CPUs and thus increasing throughput. It also removes the need for one register to hold the constant b = 1.0. PiperOrigin-RevId: 886170146 --- ops/fast_ops-inl.h | 112 ++++++++++++++++++++++----------------------- 1 file changed, 56 insertions(+), 56 deletions(-) diff --git a/ops/fast_ops-inl.h b/ops/fast_ops-inl.h index cc69321..a3c2051 100644 --- a/ops/fast_ops-inl.h +++ b/ops/fast_ops-inl.h @@ -73,7 +73,7 @@ HWY_INLINE hn::Vec FastSigmoid(D d, hn::Vec val) { auto y = hn::Abs(val); constexpr size_t kLanes = HWY_MAX_LANES_D(D); - hn::Vec a, c, d_coef; + hn::Vec b, c, d_coef; if constexpr ((kLanes >= 4 && !HWY_HAVE_SCALABLE) || (HWY_HAVE_SCALABLE && sizeof(T) == 4)) { @@ -98,41 +98,38 @@ HWY_INLINE hn::Vec FastSigmoid(D d, hn::Vec val) { // Clamp index to 7 idx_i = hn::Min(idx_i, hn::Set(DI(), 7)); - HWY_ALIGN static constexpr T arr_a[8] = { - static_cast(-1435.326650329326), - static_cast(-96.9456723845743), - static_cast(-18.628915468855695), - static_cast(-5.90191111348809), - static_cast(-2.356433838423728), - static_cast(-1.0464246812594584), - static_cast(-0.4801959711368016), - static_cast(-0.2132727031175401)}; - HWY_ALIGN static constexpr T arr_c[8] = {static_cast(-316.5640994591445), - static_cast(-49.14374182730444), - static_cast(-15.69264419046708), - static_cast(-6.949871926785674), - static_cast(-3.513259738716989), - static_cast(-1.839177585570145), - static_cast(-0.9298342163526662), - static_cast(-0.426230503963466)}; + HWY_ALIGN static constexpr T arr_b[8] = { + static_cast(-0.0006967055197996615), + static_cast(-0.010315055591476996), + static_cast(-0.05367999021047822), + static_cast(-0.16943664192343108), + static_cast(-0.42437007298661206), + static_cast(-0.9556349519550872), + 
static_cast(-2.0824831112860647), + static_cast(-4.688832585616333)}; + HWY_ALIGN static constexpr T arr_c[8] = { + static_cast(0.220551955463595), static_cast(0.5069204289218385), + static_cast(0.8423809865207907), static_cast(1.1775629610724903), + static_cast(1.4909222917402543), static_cast(1.757582383623199), + static_cast(1.9363640518503402), static_cast(1.9985234759675707)}; HWY_ALIGN static constexpr T arr_d[8] = { - static_cast(-5676.517069241468), static_cast(-363.0662559912978), - static_cast(-60.61589604370584), static_cast(-14.306713103378062), - static_cast(-2.725237489187118), static_cast(0.7890752292798894), - static_cast(1.8089988725725492), static_cast(1.9956027601801545)}; + static_cast(3.9548607753775276), static_cast(3.7450486139396544), + static_cast(3.253860706225495), static_cast(2.4240814251983283), + static_cast(1.1565092321921886), static_cast(-0.7540678688218365), + static_cast(-3.767209600467866), static_cast(-9.357047249878605)}; if constexpr (kLanes >= 8 && !HWY_HAVE_SCALABLE) { auto idx = hn::IndicesFromVec(d, idx_i); hn::CappedTag d8; - a = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_a)), idx); + b = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_b)), idx); c = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_c)), idx); d_coef = hn::TableLookupLanes(hn::ResizeBitCast(d, hn::Load(d8, arr_d)), idx); } else { auto idx = hn::IndicesFromVec(d, idx_i); hn::FixedTag d4; - a = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_a), - hn::Load(d4, arr_a + 4), idx); + b = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_b), + hn::Load(d4, arr_b + 4), idx); c = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_c), hn::Load(d4, arr_c + 4), idx); d_coef = hn::TwoTablesLookupLanes(d, hn::Load(d4, arr_d), @@ -150,66 +147,69 @@ HWY_INLINE hn::Vec FastSigmoid(D d, hn::Vec val) { const auto t6 = hn::Set(d, static_cast(5.271780018997146)); // Start with highest index (7) - a = hn::Set(d, static_cast(-0.2132727031175401)); - c = 
hn::Set(d, static_cast(-0.426230503963466)); - d_coef = hn::Set(d, static_cast(1.9956027601801545)); + b = hn::Set(d, static_cast(-4.688832585616333)); + c = hn::Set(d, static_cast(1.9985234759675707)); + d_coef = hn::Set(d, static_cast(-9.357047249878605)); // If y < t6 (idx 6) auto mask = hn::Lt(y, t6); - a = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.4801959711368016)), - a); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.9298342163526662)), - c); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-2.0824831112860647)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.9363640518503402)), c); d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(1.8089988725725492)), d_coef); + mask, hn::Set(d, static_cast(-3.767209600467866)), d_coef); // If y < t5 (idx 5) mask = hn::Lt(y, t5); - a = hn::IfThenElse(mask, hn::Set(d, static_cast(-1.0464246812594584)), - a); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(-1.839177585570145)), c); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.9556349519550872)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.757582383623199)), c); d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(0.7890752292798894)), d_coef); + mask, hn::Set(d, static_cast(-0.7540678688218365)), d_coef); // If y < t4 (idx 4) mask = hn::Lt(y, t4); - a = hn::IfThenElse(mask, hn::Set(d, static_cast(-2.356433838423728)), a); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(-3.513259738716989)), c); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.42437007298661206)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.4909222917402543)), c); d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(-2.725237489187118)), d_coef); + mask, hn::Set(d, static_cast(1.1565092321921886)), d_coef); // If y < t3 (idx 3) mask = hn::Lt(y, t3); - a = hn::IfThenElse(mask, hn::Set(d, static_cast(-5.90191111348809)), a); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(-6.949871926785674)), c); + b = hn::IfThenElse(mask, 
hn::Set(d, static_cast(-0.16943664192343108)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(1.1775629610724903)), c); d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(-14.306713103378062)), d_coef); + mask, hn::Set(d, static_cast(2.4240814251983283)), d_coef); // If y < t2 (idx 2) mask = hn::Lt(y, t2); - a = hn::IfThenElse(mask, hn::Set(d, static_cast(-18.628915468855695)), - a); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(-15.69264419046708)), c); - d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(-60.61589604370584)), d_coef); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.05367999021047822)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(0.8423809865207907)), c); + d_coef = hn::IfThenElse(mask, hn::Set(d, static_cast(3.253860706225495)), + d_coef); // If y < t1 (idx 1) mask = hn::Lt(y, t1); - a = hn::IfThenElse(mask, hn::Set(d, static_cast(-96.9456723845743)), a); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(-49.14374182730444)), c); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.010315055591476996)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(0.5069204289218385)), c); d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(-363.0662559912978)), d_coef); + mask, hn::Set(d, static_cast(3.7450486139396544)), d_coef); // If y < t0 (idx 0) mask = hn::Lt(y, t0); - a = hn::IfThenElse(mask, hn::Set(d, static_cast(-1435.326650329326)), a); - c = hn::IfThenElse(mask, hn::Set(d, static_cast(-316.5640994591445)), c); + b = hn::IfThenElse(mask, hn::Set(d, static_cast(-0.0006967055197996615)), + b); + c = hn::IfThenElse(mask, hn::Set(d, static_cast(0.220551955463595)), c); d_coef = hn::IfThenElse( - mask, hn::Set(d, static_cast(-5676.517069241468)), d_coef); + mask, hn::Set(d, static_cast(3.9548607753775276)), d_coef); } - // Math: 0.5 * tanh(y/2) = (ay + 1.0)/(cy + d_coef) - auto num = hn::MulAdd(a, y, hn::Set(d, static_cast(1.0))); + // Math: 0.5 * tanh(y/2) = (y + b)/(cy + d_coef) + auto num = 
hn::Add(y, b); auto den = hn::MulAdd(c, y, d_coef); auto approx = hn::Div(num, den);