Merge pull request #78 from enum-class:rmsnorm2

PiperOrigin-RevId: 614480854
This commit is contained in:
Copybara-Service 2024-03-10 16:14:44 -07:00
commit e577198fc0
2 changed files with 28 additions and 6 deletions

View File

@ -1,4 +1,5 @@
FormatStyle: file FormatStyle: file
WarningsAsErrors: "*"
Checks: "-*,\ Checks: "-*,\
abseil-*,\ abseil-*,\
-abseil-string-find-startswith,\ -abseil-string-find-startswith,\
@ -204,3 +205,6 @@ Checks: "-*,\
-readability-uppercase-literal-suffix,\ -readability-uppercase-literal-suffix,\
-readability-use-anyofallof -readability-use-anyofallof
" "
CheckOptions:
- { key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase }
- { key: readability-identifier-naming.ConstexprVariablePrefix, value: k }

30
ops.h
View File

@ -362,12 +362,30 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
float* HWY_RESTRICT out, size_t size) { float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f; namespace hn = hwy::HWY_NAMESPACE;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps); constexpr float kEps = 1e-6f;
for (size_t j = 0; j < size; j++) { constexpr size_t kUnrollSize = 2;
// Note 1.0f centering here
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]); const hn::ScalableTag<hwy::bfloat16_t> dbf;
const hn::Repartition<float, decltype(dbf)> df32;
const size_t N32 = hn::Lanes(df32);
const float ss = SquaredL2(x, size);
const auto vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += kUnrollSize * N32) {
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
const auto w0 = hn::PromoteLowerTo(df32, w16);
const auto w1 = hn::PromoteUpperTo(df32, w16);
const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
// (1+weight) * m = m + weight*m = one FMA.
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i);
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32);
} }
} }