diff --git a/ops.h b/ops.h index db2ae4f..7619b44 100644 --- a/ops.h +++ b/ops.h @@ -214,7 +214,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, size_t size) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; - hn::Transform(D(), x, size, [](D d, hn::Vec v) { return Gelu(d, v); }); + hn::Transform(D(), x, size, + [](D d, hn::Vec v) HWY_ATTR { return Gelu(d, v); }); } // out[i] = BF(mul[i] * Gelu(gelu_in[i])) @@ -567,22 +568,41 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size, namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const D d; - using V = hn::Vec; + const size_t N = hn::Lanes(d); - // Find max so we can subtract it below. - const V vmin = hn::Set(d, hwy::LowestValue()); - V max = vmin; - hn::Foreach(d, x, mask_pos, vmin, - [&max](D d, V v) { max = hn::Max(max, v); }); - max = hn::MaxOfLanes(d, max); // broadcast + // Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors + // cannot be lambda-captured. + // TODO(janwas): could be replaced with an hn::Accumulate algo. + const hn::Vec vmin = hn::Set(d, hwy::LowestValue()); + hn::Vec vmax = vmin; + size_t idx = 0; + if (mask_pos >= N) { + for (; idx <= mask_pos - N; idx += N) { + vmax = hn::Max(vmax, LoadU(d, x + idx)); + } + } + vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx)); + vmax = hn::MaxOfLanes(d, vmax); // broadcast // Subtract max (avoid precision loss for large exponents) and exponentiate. - V sum = hn::Zero(d); - hn::Transform(d, x, mask_pos, [&sum, max](D d, V v) { - const V out = hn::Exp(d, hn::Sub(v, max)); + // Also avoid hn::Transform because the additional `sum` output vector cannot + // be captured by a lambda. + hn::Vec sum = hn::Zero(d); + idx = 0; + if (mask_pos >= N) { + for (; idx <= mask_pos - N; idx += N) { + const hn::Vec out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax)); + sum = hn::Add(sum, out); + hn::StoreU(out, d, x + idx); + } + } + if (mask_pos > idx) { + const size_t remaining = mask_pos - idx; + const hn::Vec out = + hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax)); sum = hn::Add(sum, out); - return out; - }); + hn::StoreN(out, d, x + idx, remaining); + } // Normalize to probability distribution const float mul = 1.0f / hn::ReduceSum(d, sum); @@ -601,13 +621,12 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const D d; - using V = hn::Vec; - const V inv_cap = hn::Set(d, 1.0f / cap); - const V vcap = hn::Set(d, cap); + const float inv_cap = 1.0f / cap; - hn::Transform(d, x, size, [vcap, inv_cap](D d, hn::Vec v) { - return hn::Mul(vcap, hn::Tanh(d, hn::Mul(inv_cap, v))); + hn::Transform(d, x, size, [cap, inv_cap](D d, hn::Vec v) HWY_ATTR { + return hn::Mul(hn::Set(d, cap), + hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap)))); }); }