From 5f016fb433ff9ee9fbdc9834697de9256a05b84c Mon Sep 17 00:00:00 2001 From: enum-class Date: Tue, 5 Mar 2024 17:53:52 +0800 Subject: [PATCH 1/2] use hwy/simd for RMSNorm(f, bf, f) calculation --- ops.h | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/ops.h b/ops.h index 8f92d82..1919ac9 100644 --- a/ops.h +++ b/ops.h @@ -362,12 +362,30 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT out, size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + constexpr float eps = 1e-6f; - float ss = SquaredL2(x, size); - ss = 1.0f / sqrtf(ss / StaticCast(size) + eps); - for (size_t j = 0; j < size; j++) { - // Note 1.0f centering here - out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]); + constexpr size_t unroll_size = 2; + + const hn::ScalableTag dbf; + const hn::Repartition df32; + const size_t N32 = hn::Lanes(df32); + + const float ss = SquaredL2(x, size); + const auto vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + eps)); + + HWY_DASSERT(size % (unroll_size * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += unroll_size * N32) { + const hn::Vec w16 = hn::LoadU(dbf, weight + i); + const auto w0 = hn::PromoteLowerTo(df32, w16); + const auto w1 = hn::PromoteUpperTo(df32, w16); + const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); + const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); + + // (1+weight) * m = m + weight*m = one FMA. + hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i); + hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32); } } From bc845515b7261284fd001fd0f481a1593e29ff97 Mon Sep 17 00:00:00 2001 From: enum-class Date: Tue, 5 Mar 2024 20:45:30 +0800 Subject: [PATCH 2/2] fix style, add kCamelCase style for constexpr in clang-tidy --- .clang-tidy | 4 ++++ ops.h | 10 +++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index abcd9d7..497c2e3 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,4 +1,5 @@ FormatStyle: file +WarningsAsErrors: "*" Checks: "-*,\ abseil-*,\ -abseil-string-find-startswith,\ @@ -204,3 +205,6 @@ Checks: "-*,\ -readability-uppercase-literal-suffix,\ -readability-use-anyofallof " +CheckOptions: + - { key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase } + - { key: readability-identifier-naming.ConstexprVariablePrefix, value: k } diff --git a/ops.h b/ops.h index 1919ac9..5d34b3a 100644 --- a/ops.h +++ b/ops.h @@ -364,8 +364,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( float* HWY_RESTRICT out, size_t size) { namespace hn = hwy::HWY_NAMESPACE; - constexpr float eps = 1e-6f; - constexpr size_t unroll_size = 2; + constexpr float kEps = 1e-6f; + constexpr size_t kUnrollSize = 2; const hn::ScalableTag dbf; const hn::Repartition df32; @@ -373,10 +373,10 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( const float ss = SquaredL2(x, size); const auto vss = - hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + eps)); + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); - HWY_DASSERT(size % (unroll_size * MaxLanes(df32)) == 0); - for (size_t i = 0; i < size; i += unroll_size * N32) { + HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += kUnrollSize * N32) { const hn::Vec w16 = hn::LoadU(dbf, weight + i); const auto w0 = hn::PromoteLowerTo(df32, w16); const auto w1 = hn::PromoteUpperTo(df32, w16);