Implement scalar version of LayerNorm

PiperOrigin-RevId: 675085495
Author: Daniel Keysers, 2024-09-16 03:53:24 -07:00 (committed by Copybara-Service)
parent 1c8ddcdffe
commit 892f3bbcbe
2 changed files with 124 additions and 0 deletions


@@ -214,6 +214,57 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
  }
}

// Computes mean mu and mean of squares mu2 of a vector. Used in LayerNorm.
template <typename T>
HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, T& mu,
                            T& mu2) {
  HWY_ASSERT(size > 0);
  double sum = 0.0;
  double sum2 = 0.0;
  for (size_t i = 0; i < size; ++i) {
    const float f = hwy::ConvertScalarTo<float>(a[i]);
    sum += f;
    sum2 += f * f;
  }
  mu = sum / size;
  mu2 = sum2 / size;
}
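
// Illustrative usage of ScalarMus, not part of this commit; the values and
// the helper name are hypothetical. It shows how the two means yield the
// biased variance that ScalarLayerNorm below derives as mu2 - mu * mu.
HWY_MAYBE_UNUSED void ScalarMusExample() {
  const float a[] = {1.0f, 2.0f, 3.0f, 4.0f};
  float mu, mu2;
  ScalarMus(a, 4, mu, mu2);
  // mu == 2.5f and mu2 == 7.5f, so the biased variance is
  // mu2 - mu * mu == 7.5f - 6.25f == 1.25f.
}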

// Compare py/flax/linen/normalization.py.
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE void ScalarLayerNorm(const VecT* x,
                                  const WeightT* HWY_RESTRICT scale,
                                  const WeightT* HWY_RESTRICT bias,
                                  OutT* out, size_t size) {
  constexpr float kEps = 1e-6f;
  VecT mu, mu2;
  ScalarMus(x, size, mu, mu2);
  VecT var = mu2 - mu * mu;
  VecT zero = 0.0f;
  // Floating-point rounding can make mu2 - mu * mu slightly negative.
  var = HWY_MAX(var, zero);
  var = 1.0f / sqrtf(var + kEps);  // var now holds rsqrt(var + kEps).
  for (size_t j = 0; j < size; j++) {
    const float v = hwy::ConvertScalarTo<float>(x[j]);
    const float s = hwy::ConvertScalarTo<float>(scale[j]);
    const float b = hwy::ConvertScalarTo<float>(bias[j]);
    out[j] = hwy::ConvertScalarTo<OutT>((v - mu) * s * var + b);
  }
}
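
// Illustrative check, not part of this commit; the function name is
// hypothetical. It mirrors TestLayerNormSimple in the test file below:
// alternating +1/-1 inputs have mean 0 and biased variance 1.
HWY_MAYBE_UNUSED void ScalarLayerNormExample() {
  const float x[] = {1.0f, -1.0f, 1.0f, -1.0f};
  const float scale[] = {1.2f, 1.2f, 1.2f, 1.2f};
  const float bias[] = {0.1f, 0.1f, 0.1f, 0.1f};
  float out[4];
  ScalarLayerNorm(x, scale, bias, out, 4);
  // out[j] == x[j] * 1.2f * rsqrt(1.0f + 1e-6f) + 0.1f, i.e. approximately
  // 1.2999994f for even j and -1.0999994f for odd j.
}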

template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE HWY_MAYBE_UNUSED void LayerNorm(const VecT* x,
                                             const WeightT* HWY_RESTRICT weight,
                                             const WeightT* HWY_RESTRICT bias,
                                             OutT* out, const size_t size) {
  PROFILER_FUNC;
  // For now we only delegate to the scalar version.
  // TODO: implement vectorized version.
  ScalarLayerNorm(x, weight, bias, out, size);
}
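
// Editorial sketch, not part of this commit: one possible shape of the
// TODO'd vectorized path, float-only and assuming `size` is a multiple of
// the vector length. The function name is hypothetical; a production version
// would also need a remainder loop and BF16 support.
HWY_MAYBE_UNUSED void LayerNormVectorizedSketch(const float* HWY_RESTRICT x,
                                                const float* HWY_RESTRICT scale,
                                                const float* HWY_RESTRICT bias,
                                                float* HWY_RESTRICT out,
                                                const size_t size) {
  namespace hn = hwy::HWY_NAMESPACE;
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  HWY_ASSERT(size > 0 && size % N == 0);
  auto sum = hn::Zero(d);
  auto sum2 = hn::Zero(d);
  for (size_t i = 0; i < size; i += N) {
    const auto v = hn::LoadU(d, x + i);
    sum = hn::Add(sum, v);
    sum2 = hn::MulAdd(v, v, sum2);  // sum2 += v * v
  }
  const float mu = hn::ReduceSum(d, sum) / size;
  const float mu2 = hn::ReduceSum(d, sum2) / size;
  const float var = HWY_MAX(mu2 - mu * mu, 0.0f);
  const float inv_std = 1.0f / sqrtf(var + 1e-6f);  // rsqrt(var + epsilon)
  const auto vmu = hn::Set(d, mu);
  const auto vinv_std = hn::Set(d, inv_std);
  for (size_t i = 0; i < size; i += N) {
    const auto v = hn::LoadU(d, x + i);
    const auto s = hn::LoadU(d, scale + i);
    const auto b = hn::LoadU(d, bias + i);
    // (v - mu) * s * inv_std + b, matching ScalarLayerNorm above.
    hn::StoreU(hn::MulAdd(hn::Mul(hn::Sub(v, vmu), s), vinv_std, b), d,
               out + i);
  }
}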

static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
    float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
  PROFILER_ZONE("ops.AddAbsolutePositionalEmbeddings");
@@ -377,6 +428,16 @@ void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
  }
}

template <typename VecT, typename WeightT, typename OutT>
void LayerNormBatched(size_t num_tokens, const VecT* x,
                      const WeightT* HWY_RESTRICT weight,
                      const WeightT* HWY_RESTRICT bias, OutT* out,
                      const size_t size) {
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    LayerNorm(x + token_idx * size, weight, bias, out + token_idx * size, size);
  }
}
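
// Call-site sketch, not part of this commit; names and shapes are
// hypothetical. Activations are stored contiguously as [num_tokens, size];
// weight and bias are shared across all tokens.
HWY_MAYBE_UNUSED void LayerNormBatchedExample() {
  constexpr size_t kTokens = 3;
  constexpr size_t kModelDim = 8;
  std::vector<float> x(kTokens * kModelDim, 1.0f);  // row t = token t
  std::vector<float> weight(kModelDim, 1.0f);
  std::vector<float> bias(kModelDim, 0.0f);
  std::vector<float> out(kTokens * kModelDim);
  LayerNormBatched(kTokens, x.data(), weight.data(), bias.data(), out.data(),
                   kModelDim);
  // Row t of out now holds LayerNorm(row t of x).
}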

static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other,
                                      float* x, const size_t model_dim) {
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {


@@ -498,6 +498,67 @@ void TestAllRMSNorm() {
  TestRMSNorm<BF16, BF16, BF16>(rng);
}

void TestLayerNormSimple() {
  const size_t kSize = 52;
  std::vector<float> values(kSize);
  // Alternating 1.0/-1.0, so mean=0.0, var=1.0, rsqrt(var+epsilon)=0.9999995.
  for (size_t i = 0; i < kSize; ++i) {
    values[i] = (i % 2 == 0) ? 1.0f : -1.0f;
  }
  std::vector<float> scale(kSize, 1.2f);
  std::vector<float> bias(kSize, 0.1f);
  std::vector<float> result(kSize);
  LayerNorm(values.data(), scale.data(), bias.data(), result.data(), kSize);
  for (size_t i = 0; i < kSize; i++) {
    const float max_error = 1e-6f;
    const float value = values[i];
    const float res = result[i];
    // out = (x - 0.0) * 1.2 * 0.9999995 + 0.1 = 1.2999994 or -1.0999994.
    const float expected = (i % 2 == 0) ? 1.2999994f : -1.0999994f;
    EXPECT_NEAR(res, expected, max_error) << "Input: " << value;
  }
}

// Note: there is no vectorized implementation of LayerNorm yet, so this test
// currently only checks that the scalar version can be called with the
// combinations of float/BF16 inputs and outputs below.
template <typename VecT, typename WeightT, typename OutT>
void TestLayerNorm(hwy::RandomState& rng) {
  constexpr size_t kSize = 128;
  VecT vec[kSize];
  WeightT weight[kSize];
  WeightT bias[kSize];
  OutT expected[kSize];
  OutT actual[kSize];
  for (size_t i = 0; i < kSize; ++i) {
    vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
    weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
    bias[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
  }
  ScalarLayerNorm(vec, weight, bias, expected, kSize);
  LayerNorm(vec, weight, bias, actual, kSize);
  for (size_t i = 0; i < kSize; i++) {
    const float e = hwy::ConvertScalarTo<float>(expected[i]);
    const float a = hwy::ConvertScalarTo<float>(actual[i]);
    if (!IsNear(e, a, 1e-5f)) {
      HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(),
                TypeName<WeightT>(), TypeName<OutT>(), i, e, a);
    }
  }
}

void TestAllLayerNorm() {
  hwy::RandomState rng;
  TestLayerNorm<float, float, float>(rng);
  TestLayerNorm<float, float, BF16>(rng);
  TestLayerNorm<float, BF16, float>(rng);
  TestLayerNorm<float, BF16, BF16>(rng);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
@@ -516,6 +577,8 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
HWY_AFTER_TEST();
} // namespace gcpp