diff --git a/ops/ops-inl.h b/ops/ops-inl.h
index 129e2f5..4183572 100644
--- a/ops/ops-inl.h
+++ b/ops/ops-inl.h
@@ -214,6 +214,57 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
   }
 }
 
+// Computes mean mu and mean of squares mu2 of a vector. Used in LayerNorm.
+template <typename T>
+HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, T& mu,
+                            T& mu2) {
+  HWY_ASSERT(size > 0);
+  double sum = 0.0;
+  double sum2 = 0.0;
+  for (size_t i = 0; i < size; ++i) {
+    const float f = hwy::ConvertScalarTo<float>(a[i]);
+    sum += f;
+    sum2 += f * f;
+  }
+  mu = sum / size;
+  mu2 = sum2 / size;
+}
+
+// Compare py/flax/linen/normalization.py.
+// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
+template <typename VecT, typename WeightT, typename OutT>
+HWY_NOINLINE void ScalarLayerNorm(const VecT* x,
+                                  const WeightT* HWY_RESTRICT scale,
+                                  const WeightT* HWY_RESTRICT bias,
+                                  OutT* out,
+                                  size_t size) {
+  constexpr float kEps = 1e-6f;
+  VecT mu, mu2;
+  ScalarMus(x, size, mu, mu2);
+  VecT var = mu2 - mu * mu;
+  VecT zero = 0.0f;
+  var = HWY_MAX(var, zero);
+  var = 1.0f / sqrtf(var + kEps);
+  for (size_t j = 0; j < size; j++) {
+    const float v = hwy::ConvertScalarTo<float>(x[j]);
+    const float s = hwy::ConvertScalarTo<float>(scale[j]);
+    const float b = hwy::ConvertScalarTo<float>(bias[j]);
+    out[j] = hwy::ConvertScalarTo<OutT>((v - mu) * s * var + b);
+  }
+}
+
+template <typename VecT, typename WeightT, typename OutT>
+HWY_NOINLINE HWY_MAYBE_UNUSED void LayerNorm(const VecT* x,
+                                             const WeightT* HWY_RESTRICT weight,
+                                             const WeightT* HWY_RESTRICT bias,
+                                             OutT* out,
+                                             const size_t size) {
+  PROFILER_FUNC;
+  // For now we only delegate to the scalar version.
+  // TODO: implement a vectorized version.
+  ScalarLayerNorm(x, weight, bias, out, size);
+}
+
 static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
     float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
   PROFILER_ZONE("ops.AddAbsolutePositionalEmbeddings");
@@ -377,6 +428,16 @@ void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
   }
 }
 
+template <typename VecT, typename WeightT, typename OutT>
+void LayerNormBatched(size_t num_tokens, const VecT* x,
+                      const WeightT* HWY_RESTRICT weight,
+                      const WeightT* HWY_RESTRICT bias, OutT* out,
+                      const size_t size) {
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    LayerNorm(x + token_idx * size, weight, bias, out + token_idx * size, size);
+  }
+}
+
 static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other,
                                       float* x, const size_t model_dim) {
   for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
diff --git a/ops/ops_test.cc b/ops/ops_test.cc
index aebbab1..a993883 100644
--- a/ops/ops_test.cc
+++ b/ops/ops_test.cc
@@ -498,6 +498,67 @@ void TestAllRMSNorm() {
   TestRMSNorm<BF16, BF16>(rng);
 }
 
+void TestLayerNormSimple() {
+  const size_t kSize = 52;
+  std::vector<float> values(kSize);
+  // Alternating 1.0/-1.0, so mean=0.0, var=1.0, rsqrt(var+epsilon)=0.9999995
+  for (size_t i = 0; i < kSize; ++i) {
+    values[i] = (i % 2 == 0) ? 1.0f : -1.0f;
+  }
+  std::vector<float> scale(kSize, 1.2f);
+  std::vector<float> bias(kSize, 0.1f);
+  std::vector<float> result(kSize);
+  LayerNorm(values.data(), scale.data(), bias.data(), result.data(), kSize);
+
+  for (size_t i = 0; i < kSize; i++) {
+    const float max_error = 1e-6f;
+    float value = values[i];
+    float res = result[i];
+    // out = (x - 0.0) * 1.2 * 0.9999995 + 0.1 = 1.2999994 or -1.0999994
+    float expected = (i % 2 == 0) ? 1.2999994f : -1.0999994f;
+    EXPECT_NEAR(res, expected, max_error) << "Input: " << value;
+  }
+}
+
+// Note: there is no vectorized implementation of LayerNorm yet. So this test
+// currently only checks that the scalar version can be called for the below
+// combinations of float/BF16 inputs and outputs.
+template <typename VecT, typename WeightT, typename OutT>
+void TestLayerNorm(hwy::RandomState& rng) {
+  constexpr size_t kSize = 128;
+  VecT vec[kSize];
+  WeightT weight[kSize];
+  WeightT bias[kSize];
+  OutT expected[kSize];
+  OutT actual[kSize];
+
+  for (size_t i = 0; i < kSize; ++i) {
+    vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
+    weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
+    bias[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
+  }
+
+  ScalarLayerNorm(vec, weight, bias, expected, kSize);
+  LayerNorm(vec, weight, bias, actual, kSize);
+
+  for (size_t i = 0; i < kSize; i++) {
+    const float e = hwy::ConvertScalarTo<float>(expected[i]);
+    const float a = hwy::ConvertScalarTo<float>(actual[i]);
+    if (!IsNear(e, a, 1e-5f)) {
+      HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(),
+                TypeName<WeightT>(), TypeName<OutT>(), i, e, a);
+    }
+  }
+}
+
+void TestAllLayerNorm() {
+  hwy::RandomState rng;
+  TestLayerNorm<float, float, float>(rng);
+  TestLayerNorm<float, float, BF16>(rng);
+  TestLayerNorm<BF16, float, float>(rng);
+  TestLayerNorm<BF16, float, BF16>(rng);
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
@@ -516,6 +577,8 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
 HWY_AFTER_TEST();
 
 }  // namespace gcpp
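
For context, a minimal sketch of how the batched entry point added in ops/ops-inl.h might be called. The buffer sizes (kNumTokens, kModelDim) and the free-standing helper function are illustrative assumptions, not part of the patch, and the call assumes the header has been included inside the usual Highway foreach_target / HWY_NAMESPACE boilerplate, as in ops_test.cc:

  #include <cstddef>
  #include <vector>

  void ExampleLayerNormUsage() {
    constexpr size_t kNumTokens = 4;   // illustrative batch size
    constexpr size_t kModelDim = 128;  // illustrative model dimension
    std::vector<float> x(kNumTokens * kModelDim, 1.0f);  // per-token activations
    std::vector<float> scale(kModelDim, 1.0f);           // learned scale
    std::vector<float> bias(kModelDim, 0.0f);            // learned bias
    std::vector<float> out(kNumTokens * kModelDim);
    // Normalizes each token's kModelDim-sized row independently; the scale and
    // bias vectors are shared across tokens, matching LayerNormBatched above.
    LayerNormBatched(kNumTokens, x.data(), scale.data(), bias.data(),
                     out.data(), kModelDim);
  }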