use hwy/simd for SquaredL2 calculation

This commit is contained in:
enum-class 2024-03-05 17:37:09 +08:00
parent bb9b023502
commit 507d64e3e6
1 changed files with 14 additions and 4 deletions

18
ops.h
View File

@ -340,11 +340,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT. // = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
const float* HWY_RESTRICT a, size_t size) { const float* HWY_RESTRICT a, size_t size) {
float total = 0.f; const hn::ScalableTag<float> d;
for (size_t i = 0; i < size; ++i) { const size_t N = hn::Lanes(d);
total += a[i] * a[i]; HWY_DASSERT(size >= N);
HWY_DASSERT(size % (2 * N) == 0);
auto sum0 = hn::Zero(d);
auto sum1 = hn::Zero(d);
for (size_t i = 0; i + 2 * N <= size; i += 2 * N) {
const auto a0 = LoadU(d, a + i);
sum0 = MulAdd(a0, a0, sum0);
const auto a1 = LoadU(d, a + i + N);
sum1 = MulAdd(a1, a1, sum1);
} }
return total;
return ReduceSum(d, Add(sum0, sum1));
} }
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(