diff --git a/ops.h b/ops.h
index 481e1d7..f52e419 100644
--- a/ops.h
+++ b/ops.h
@@ -551,47 +551,70 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
 }
 
+// Vectorized x[i] += other[i] for i in [0, size).
 static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
-    const float* HWY_RESTRICT other, float* HWY_RESTRICT x, size_t size) {
-  for (size_t i = 0; i < size; ++i) {
-    x[i] += other[i];
-  }
+    const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  using D = hn::ScalableTag<float>;
+  const D d;
+
+  hn::Transform1(d, x, size, other,
+                 [](const auto d, const auto x, const auto other)
+                     HWY_ATTR { return hn::Add(x, other); });
 }
 
+// Vectorized x[i] *= other[i] for i in [0, max_pos); max_pos must be <= size.
 static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
-                               float* HWY_RESTRICT x, size_t size,
-                               size_t max_pos) {
+                               float* HWY_RESTRICT x, const size_t size,
+                               const size_t max_pos) {
   HWY_DASSERT(max_pos <= size);
-  for (size_t i = 0; i < max_pos; ++i) {
-    x[i] *= other[i];
-  }
+  namespace hn = hwy::HWY_NAMESPACE;
+  using D = hn::ScalableTag<float>;
+  const D d;
+
+  hn::Transform1(d, x, max_pos, other,
+                 [](const auto d, const auto x, const auto other)
+                     HWY_ATTR { return hn::Mul(x, other); });
 }
 
 static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
                                               float* HWY_RESTRICT x,
-                                              size_t size) {
+                                              const size_t size) {
   return MulBy(other, x, size, size);
 }
 
-static HWY_NOINLINE void MulByConst(float c, float* HWY_RESTRICT x, size_t size,
-                                    size_t max_pos) {
+// Vectorized x[i] *= c for i in [0, max_pos); max_pos must be <= size.
+static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
+                                    const size_t size, const size_t max_pos) {
   HWY_DASSERT(max_pos <= size);
-  for (size_t i = 0; i < max_pos; ++i) {
-    x[i] *= c;
-  }
+  namespace hn = hwy::HWY_NAMESPACE;
+  using D = hn::ScalableTag<float>;
+  const D d;
+  const auto constant = hn::Set(d, c);
+  hn::Transform(d, x, max_pos,
+                [&constant](const auto d, const auto x)
+                    HWY_ATTR { return hn::Mul(x, constant); });
 }
 
-static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(float c,
+static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
                                                    float* HWY_RESTRICT x,
-                                                   size_t size) {
+                                                   const size_t size) {
   MulByConst(c, x, size, size);
 }
 
-static HWY_NOINLINE void MulByConstAndAdd(float c, const float* HWY_RESTRICT x,
-                                          float* HWY_RESTRICT out, size_t size,
-                                          size_t max_pos) {
-  for (size_t i = 0; i < max_pos; ++i) {
-    out[i] += x[i] * c;
-  }
+// Vectorized out[i] += x[i] * c for i in [0, max_pos), via a multiply-add.
+static HWY_NOINLINE void MulByConstAndAdd(const float c,
+                                          const float* HWY_RESTRICT x,
+                                          float* HWY_RESTRICT out,
+                                          const size_t size,
+                                          const size_t max_pos) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  using D = hn::ScalableTag<float>;
+  const D d;
+  const auto constant = hn::Set(d, c);
+  hn::Transform1(
+      d, out, max_pos, x,
+      [&constant](const auto d, const auto out_element, const auto x_element)
+          HWY_ATTR { return hn::MulAdd(x_element, constant, out_element); });
 }
 
 static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
@@ -693,15 +716,21 @@ template <size_t k>
 static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
 create_distribution(std::array<float, k>& top_k, float temperature) {
-  // re-normalize distribution
-  for (size_t i = 0; i < k; ++i) {
-    top_k[i] = exp(log(top_k[i]) / temperature);
-  }
-  float denominator = 0.0f;
-  for (size_t i = 0; i < k; ++i) {
-    denominator += top_k[i];
-  }
-  denominator = 1.0f / denominator;
-  MulByConst(denominator, top_k.data(), k);
+  // Temperature scaling: p -> exp(log(p) / temperature), done in place.
+  // The scale must be applied inside Exp: multiplying the result instead is a
+  // uniform factor that std::discrete_distribution's own weight normalization
+  // cancels, which would make temperature a no-op. No explicit sum-to-one here.
+  namespace hn = hwy::HWY_NAMESPACE;
+  using D = hn::ScalableTag<float>;
+  const D d;
+
+  const auto one = hn::Set(d, 1.0f);
+  const auto temperature_inv = hn::Div(one, hn::Set(d, temperature));
+
+  hn::Transform(d, top_k.data(), top_k.size(),
+                [&temperature_inv](D d, hn::Vec<D> v) HWY_ATTR {
+                  return hn::Exp(d, hn::Mul(hn::Log(d, v), temperature_inv));
+                });
+
   return std::discrete_distribution<int>(std::begin(top_k),
                                          std::end(top_k));
 }