diff --git a/ops.h b/ops.h index f52e419..1305b3e 100644 --- a/ops.h +++ b/ops.h @@ -619,49 +619,29 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( MulByConstAndAdd(c, x, out, size, size); } -static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size, - size_t mask_pos) { +static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, + const size_t mask_pos) { HWY_DASSERT(size != 0); HWY_DASSERT(mask_pos <= size); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const D d; - const size_t N = hn::Lanes(d); - // Find max so we can subtract it below. Avoid hn::Foreach because SVE vectors - // cannot be lambda-captured. - // TODO(janwas): could be replaced with an hn::Accumulate algo. - const hn::Vec vmin = hn::Set(d, hwy::LowestValue()); - hn::Vec vmax = vmin; - size_t idx = 0; - if (mask_pos >= N) { - for (; idx <= mask_pos - N; idx += N) { - vmax = hn::Max(vmax, LoadU(d, x + idx)); - } - } - vmax = hn::Max(vmax, LoadNOr(vmin, d, x + idx, mask_pos - idx)); - vmax = hn::MaxOfLanes(d, vmax); // broadcast + const auto vmin = hn::Set(d, hwy::LowestValue()); + auto vmax = vmin; + Foreach(d, x, mask_pos, vmin, + [&vmax](const auto d, const auto value) + HWY_ATTR { vmax = hn::Max(vmax, value); }); // Subtract max (avoid precision loss for large exponents) and exponentiate. - // Also avoid hn::Transform because the additional `sum` output vector cannot - // be captured by a lambda. - hn::Vec sum = hn::Zero(d); - idx = 0; - if (mask_pos >= N) { - for (; idx <= mask_pos - N; idx += N) { - const hn::Vec out = hn::Exp(d, hn::Sub(hn::LoadU(d, x + idx), vmax)); - sum = hn::Add(sum, out); - hn::StoreU(out, d, x + idx); - } - } - if (mask_pos > idx) { - const size_t remaining = mask_pos - idx; - const hn::Vec out = - hn::Exp(d, hn::Sub(hn::LoadN(d, x + idx, remaining), vmax)); - sum = hn::Add(sum, out); - hn::StoreN(out, d, x + idx, remaining); - } + auto sum = hn::Zero(d); + hn::Transform(d, x, mask_pos, + [&sum, &vmax](const auto d, const auto value) HWY_ATTR { + const auto out = hn::Exp(d, hn::Sub(value, vmax)); + sum = hn::Add(sum, out); + return out; + }); // Normalize to probability distribution const float mul = 1.0f / hn::ReduceSum(d, sum); @@ -669,29 +649,30 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, size_t size, } static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x, - size_t size) { + const size_t size) { Softmax(x, size, size); } static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, - size_t size, size_t max_pos) { + const size_t size, + const size_t max_pos) { HWY_DASSERT(max_pos <= size); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const D d; - const float inv_cap = 1.0f / cap; + const auto vcap = hn::Set(d, cap); + const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap); - hn::Transform(d, x, size, [cap, inv_cap](D d, hn::Vec v) HWY_ATTR { - return hn::Mul(hn::Set(d, cap), - hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap)))); + hn::Transform(d, x, size, [&vcap, &vinv_cap](D d, hn::Vec v) HWY_ATTR { + return hn::Mul(vcap, hn::Tanh(d, hn::Mul(v, vinv_cap))); }); } static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, - size_t size) { + const size_t size) { LogitsSoftCap(cap, x, size, size); } @@ -716,8 +697,8 @@ create_distribution(std::array& top_k, float temperature) { using D = hn::ScalableTag; const D d; - const auto one = hn::Set(d, 1.0f); - const auto temperature_inv = hn::Div(one, hn::Set(d, temperature)); + const auto temperature_inv = + hn::Div(hn::Set(d, 1.0f), hn::Set(d, temperature)); hn::Transform(d, top_k.data(), top_k.size(), [&temperature_inv](D d, hn::Vec v) HWY_ATTR {