diff --git a/ops.h b/ops.h index 481e1d7..7aa7b62 100644 --- a/ops.h +++ b/ops.h @@ -341,20 +341,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( const float* HWY_RESTRICT a, size_t size) { const hn::ScalableTag d; + using V = hn::Vec; const size_t N = hn::Lanes(d); HWY_DASSERT(size >= 2 * N); HWY_DASSERT(size % (2 * N) == 0); - auto sum0 = hn::Zero(d); - auto sum1 = hn::Zero(d); + V sum0 = hn::Zero(d); + V sum1 = hn::Zero(d); for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { - const auto a0 = LoadU(d, a + i); - sum0 = MulAdd(a0, a0, sum0); - const auto a1 = LoadU(d, a + i + N); - sum1 = MulAdd(a1, a1, sum1); + const V a0 = hn::LoadU(d, a + i); + sum0 = hn::MulAdd(a0, a0, sum0); + const V a1 = hn::LoadU(d, a + i + N); + sum1 = hn::MulAdd(a1, a1, sum1); } - return ReduceSum(d, Add(sum0, sum1)); + return hn::ReduceSum(d, hn::Add(sum0, sum1)); } static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(