diff --git a/ops.h b/ops.h index 86dc54c..481e1d7 100644 --- a/ops.h +++ b/ops.h @@ -340,11 +340,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, // = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT. static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( const float* HWY_RESTRICT a, size_t size) { - float total = 0.f; - for (size_t i = 0; i < size; ++i) { - total += a[i] * a[i]; + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + HWY_DASSERT(size >= 2 * N); + HWY_DASSERT(size % (2 * N) == 0); + + auto sum0 = hn::Zero(d); + auto sum1 = hn::Zero(d); + for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { + const auto a0 = LoadU(d, a + i); + sum0 = MulAdd(a0, a0, sum0); + const auto a1 = LoadU(d, a + i + N); + sum1 = MulAdd(a1, a1, sum1); } - return total; + + return ReduceSum(d, Add(sum0, sum1)); } static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(