mirror of https://github.com/google/gemma.cpp.git
Implement scalar version of LayerNorm
PiperOrigin-RevId: 675085495
This commit is contained in:
parent 1c8ddcdffe
commit 892f3bbcbe
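In short, the commit adds a scalar LayerNorm (out = (x - mean) * scale * rsqrt(var + epsilon) + bias), a batched wrapper, and tests. A minimal call on one row of float activations might look like the sketch below; buffer names and sizes are illustrative, only the LayerNorm signature is taken from the diff, and the call assumes the gemma.cpp ops header is included and made inside its HWY_NAMESPACE.

// Illustrative usage sketch, not part of this commit.
std::vector<float> x(2048);            // one token's activations, filled elsewhere
std::vector<float> scale(2048, 1.0f);  // learned per-channel scale
std::vector<float> bias(2048, 0.0f);   // learned per-channel bias
std::vector<float> out(2048);
LayerNorm(x.data(), scale.data(), bias.data(), out.data(), x.size());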
@@ -214,6 +214,57 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
  }
}

// Computes mean mu and mean of squares mu2 of a vector. Used in LayerNorm.
template <typename T>
HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, T& mu,
                            T& mu2) {
  HWY_ASSERT(size > 0);
  double sum = 0.0;
  double sum2 = 0.0;
  for (size_t i = 0; i < size; ++i) {
    const float f = hwy::ConvertScalarTo<float>(a[i]);
    sum += f;
    sum2 += f * f;
  }
  mu = sum / size;
  mu2 = sum2 / size;
}

// Compare py/flax/linen/normalization.py.
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE void ScalarLayerNorm(const VecT* x,
                                  const WeightT* HWY_RESTRICT scale,
                                  const WeightT* HWY_RESTRICT bias,
                                  OutT* out,
                                  size_t size) {
  constexpr float kEps = 1e-6f;
  VecT mu, mu2;
  ScalarMus(x, size, mu, mu2);
  VecT var = mu2 - mu * mu;
  VecT zero = 0.0f;
  var = HWY_MAX(var, zero);
  var = 1.0f / sqrtf(var + kEps);
  for (size_t j = 0; j < size; j++) {
    const float v = hwy::ConvertScalarTo<float>(x[j]);
    const float s = hwy::ConvertScalarTo<float>(scale[j]);
    const float b = hwy::ConvertScalarTo<float>(bias[j]);
    out[j] = hwy::ConvertScalarTo<OutT>((v - mu) * s * var + b);
  }
}

template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE HWY_MAYBE_UNUSED void LayerNorm(const VecT* x,
                                             const WeightT* HWY_RESTRICT weight,
                                             const WeightT* HWY_RESTRICT bias,
                                             OutT* out,
                                             const size_t size) {
  PROFILER_FUNC;
  // For now we only delegate to the scalar version.
  // TODO: implement vectorized version.
  ScalarLayerNorm(x, weight, bias, out, size);
}

static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
    float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
  PROFILER_ZONE("ops.AddAbsolutePositionalEmbeddings");
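Aside (not part of the diff): ScalarLayerNorm obtains the variance from the two means via Var[x] = E[x^2] - (E[x])^2, and because that single-pass form can come out slightly negative in floating point, the code clamps var at zero before the rsqrt. A small standalone check of the identity, with made-up values:

// Illustrative only: verifies Var[x] = E[x^2] - (E[x])^2, the identity behind
// `var = mu2 - mu * mu` above. Values are arbitrary.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<double> x = {1.0, -2.0, 0.5, 3.0};
  double mu = 0.0, mu2 = 0.0;
  for (double v : x) {
    mu += v;
    mu2 += v * v;
  }
  mu /= x.size();   // E[x]
  mu2 /= x.size();  // E[x^2]

  double var_direct = 0.0;  // two-pass variance for comparison
  for (double v : x) var_direct += (v - mu) * (v - mu);
  var_direct /= x.size();

  const double var_identity = mu2 - mu * mu;
  assert(std::fabs(var_direct - var_identity) < 1e-12);
  std::printf("variance = %f either way\n", var_direct);
  return 0;
}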
@@ -377,6 +428,16 @@ void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
  }
}

template <typename VecT, typename WeightT, typename OutT>
void LayerNormBatched(size_t num_tokens, const VecT* x,
                      const WeightT* HWY_RESTRICT weight,
                      const WeightT* HWY_RESTRICT bias, OutT* out,
                      const size_t size) {
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    LayerNorm(x + token_idx * size, weight, bias, out + token_idx * size, size);
  }
}

static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other,
                                      float* x, const size_t model_dim) {
  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
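LayerNormBatched treats x and out as row-major [num_tokens, size] buffers that share a single weight/bias vector, normalizing each row independently. A usage sketch (dimensions and buffer names are made up, only the signature comes from the diff):

// Illustrative only, not part of this commit.
constexpr size_t kNumTokens = 4;
constexpr size_t kModelDim = 2048;
std::vector<float> activations(kNumTokens * kModelDim);  // filled elsewhere
std::vector<float> weight(kModelDim, 1.0f);
std::vector<float> bias(kModelDim, 0.0f);
std::vector<float> normed(kNumTokens * kModelDim);
LayerNormBatched(kNumTokens, activations.data(), weight.data(), bias.data(),
                 normed.data(), kModelDim);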
@@ -498,6 +498,67 @@ void TestAllRMSNorm() {
  TestRMSNorm<BF16, BF16, BF16>(rng);
}

void TestLayerNormSimple() {
  const size_t kSize = 52;
  std::vector<float> values(kSize);
  // Alternating 1.0/-1.0, so mean=0.0, var=1.0, rsqrt(var+epsilon)=0.9999995
  for (int i = 0; i < kSize; ++i) {
    values[i] = (i % 2 == 0) ? 1.0f : -1.0f;
  }
  std::vector<float> scale(kSize, 1.2f);
  std::vector<float> bias(kSize, 0.1f);
  std::vector<float> result(kSize);
  LayerNorm(values.data(), scale.data(), bias.data(), result.data(), kSize);

  for (size_t i = 0; i < kSize; i++) {
    const float max_error = 1e-6f;
    float value = values[i];
    float res = result[i];
    // out = (x - 0.0) * 1.2 * 0.9999995 + 0.1 = 1.2999994 / -1.0999994;
    float expected = (i % 2 == 0) ? 1.2999994f : -1.0999994f;
    EXPECT_NEAR(res, expected, max_error) << "Input: " << value;
  }
}

// Note: there is no vectorized implementation of LayerNorm yet. So this test
// currently only checks that the scalar version can be called for the below
// combinations of float/BF16 inputs and outputs.
template <typename VecT, typename WeightT, typename OutT>
void TestLayerNorm(hwy::RandomState& rng) {
  constexpr size_t kSize = 128;
  VecT vec[kSize];
  WeightT weight[kSize];
  WeightT bias[kSize];
  OutT expected[kSize];
  OutT actual[kSize];

  for (size_t i = 0; i < kSize; ++i) {
    vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
    weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
    bias[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
  }

  ScalarLayerNorm(vec, weight, bias, expected, kSize);
  LayerNorm(vec, weight, bias, actual, kSize);

  for (size_t i = 0; i < kSize; i++) {
    const float e = hwy::ConvertScalarTo<float>(expected[i]);
    const float a = hwy::ConvertScalarTo<float>(actual[i]);
    if (!IsNear(e, a, 1e-5f)) {
      HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(),
                TypeName<WeightT>(), TypeName<OutT>(), i, e, a);
    }
  }
}

void TestAllLayerNorm() {
  hwy::RandomState rng;
  TestLayerNorm<float, float, float>(rng);
  TestLayerNorm<float, float, BF16>(rng);
  TestLayerNorm<float, BF16, float>(rng);
  TestLayerNorm<float, BF16, BF16>(rng);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
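For reference, the constants asserted in TestLayerNormSimple follow directly from the LayerNorm formula: the alternating +1/-1 input has mean 0 and variance 1, so the normalizer is 1 / sqrt(1 + 1e-6) ≈ 0.9999995, giving 1.0 * 1.2 * 0.9999995 + 0.1 = 1.2999994 for the even entries and -1.0 * 1.2 * 0.9999995 + 0.1 = -1.0999994 for the odd ones.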
@@ -516,6 +577,8 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllLayerNorm);
HWY_EXPORT_AND_TEST_P(OpsTest, TestLayerNormSimple);
HWY_AFTER_TEST();

}  // namespace gcpp