diff --git a/BUILD.bazel b/BUILD.bazel index 3749297..d672a38 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -28,6 +28,16 @@ cc_library( ], ) +cc_library( + name = "test_util", + hdrs = ["util/test_util.h"], + deps = [ + "@hwy//:hwy", + "@hwy//:hwy_test_util", + "@hwy//:stats", + ], +) + cc_library( name = "threading", hdrs = ["util/threading.h"], @@ -101,7 +111,9 @@ cc_test( ":common", ":gemma_lib", ":ops", + ":test_util", "@googletest//:gtest_main", # buildcleaner: keep + "//compression:compress", "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:nanobenchmark", #buildcleaner: keep diff --git a/CMakeLists.txt b/CMakeLists.txt index 3bc4097..3ed8397 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,7 +48,6 @@ set(SOURCES compression/nuq-inl.h compression/sfp.h compression/sfp-inl.h - compression/test_util.h compression/weights_raw.h backprop/activations.h backprop/backward.cc @@ -107,6 +106,7 @@ set(SOURCES util/allocator.h util/app.h util/args.h + util/test_util.h util/threading.h ) diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 41dc21c..39b03f7 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -136,9 +136,8 @@ static HWY_NOINLINE void RMSNormVJP( hwy::ThreadPool& pool) { for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t offset = pos * model_dim; - constexpr float eps = 1e-6f; - float ss = SquaredL2(x + offset, model_dim); - ss = 1.0f / sqrtf(ss / StaticCast(model_dim) + eps); + const float ss = detail::RMSNormMul(x + offset, model_dim); + for (size_t i = 0; i < model_dim; ++i) { grad_w[i] += v[offset + i] * x[offset + i] * ss; } diff --git a/compression/BUILD b/compression/BUILD index ca1dd2e..565dbba 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -64,25 +64,14 @@ cc_test( srcs = ["distortion_test.cc"], deps = [ ":distortion", - ":test_util", "@googletest//:gtest_main", + "//:test_util", "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:nanobenchmark", # Unpredictable1 ], ) -cc_library( - name = 
"test_util", - hdrs = ["test_util.h"], - deps = [ - ":distortion", - "@hwy//:hwy", - "@hwy//:hwy_test_util", - "@hwy//:stats", - ], -) - cc_library( name = "sfp", hdrs = ["sfp.h"], @@ -102,10 +91,11 @@ cc_test( # for test_suite. tags = ["hwy_ops_test"], deps = [ + ":distortion", ":sfp", - ":test_util", "@googletest//:gtest_main", "//:ops", + "//:test_util", "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:nanobenchmark", @@ -133,10 +123,11 @@ cc_test( # for test_suite. tags = ["hwy_ops_test"], deps = [ + ":distortion", ":nuq", ":sfp", - ":test_util", "@googletest//:gtest_main", + "//:test_util", "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:nanobenchmark", diff --git a/compression/compress-inl.h b/compression/compress-inl.h index a246dfa..9579e18 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -51,6 +51,25 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +namespace detail { + +// Adapters to store two f32 vectors to f32 or bf16; avoids duplicating +// RMSNorm and RMSNormInplace for the two output types. +template +void Store2(DF df, hn::Vec v0, hn::Vec v1, float* HWY_RESTRICT out) { + const size_t NF = hn::Lanes(df); + hn::StoreU(v0, df, out); + hn::StoreU(v1, df, out + NF); +} + +template +void Store2(DF df, hn::Vec v0, hn::Vec v1, BF16* HWY_RESTRICT out) { + const hn::Repartition dbf; + hn::StoreU(hn::OrderedDemote2To(dbf, v0, v1), dbf, out); +} + +} // namespace detail + // Enables generic code independent of compression type. 
template // primary, must specialize struct CompressTraits {}; diff --git a/compression/distortion_test.cc b/compression/distortion_test.cc index 0889c41..7ee9b0a 100644 --- a/compression/distortion_test.cc +++ b/compression/distortion_test.cc @@ -17,7 +17,7 @@ #include -#include "compression/test_util.h" +#include "util/test_util.h" #include "hwy/nanobenchmark.h" #include "hwy/tests/hwy_gtest.h" #include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index 2823a50..dbb0980 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -27,7 +27,8 @@ #include // std::shuffle #include -#include "compression/test_util.h" +#include "compression/distortion.h" +#include "util/test_util.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/tests/test_util.h" diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index 41983d3..6e1f5ce 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -26,11 +26,11 @@ #include -#include "compression/test_util.h" +#include "compression/distortion.h" +#include "util/test_util.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/timer.h" - // clang-format off #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 24ac462..6dc122b 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -228,7 +228,7 @@ class GemmaAttention { Rope(qk_out, kQKVDim / 2, inv_timescale, pos); MulByConst(mul, qk_out, kQKVDim); } else { - VectorizedRopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out); + RopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out); } } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 342c4a4..a7f4bbf 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -30,7 +30,6 @@ #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_targets.h" #include "hwy/profiler.h" - #endif // 
THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_ // Include guard for (potentially) SIMD code. @@ -41,6 +40,7 @@ #define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE #endif +#include "compression/compress-inl.h" #include "hwy/contrib/algo/transform-inl.h" #include "hwy/contrib/dot/dot-inl.h" #include "hwy/contrib/math/math-inl.h" @@ -62,7 +62,7 @@ StaticCast(From from) noexcept { } template -static HWY_INLINE hn::Vec Gelu(D d, hn::Vec v) { +HWY_INLINE hn::Vec Gelu(D d, hn::Vec v) { const hn::Vec kMul = hn::Set(d, 0.044715f); const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); const hn::Vec kHalf = hn::Set(d, 0.5f); @@ -83,40 +83,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, [](D d, hn::Vec v) HWY_ATTR { return Gelu(d, v); }); } -// out[i] = BF(mul[i] * Gelu(gelu_in[i])) -static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16( - const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul, - hwy::bfloat16_t* HWY_RESTRICT out, size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag df; - const hn::Repartition dbf; - const size_t NF = hn::Lanes(df); - using VF = hn::Vec; - - size_t i = 0; - if (size >= 2 * NF) { - for (; i <= size - 2 * NF; i += 2 * NF) { - const VF mul0 = hn::LoadU(df, mul + i); - const VF mul1 = hn::LoadU(df, mul + i + NF); - const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i))); - const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF))); - const hn::Vec bf = hn::OrderedDemote2To(dbf, g0, g1); - hn::StoreU(bf, dbf, out + i); - } - } - if (i != size) { - const size_t remaining = size - i; - const VF mul0 = hn::LoadN(df, mul + i, remaining); - const VF g0 = - hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining))); - const hn::Half dbfh; - const hn::Vec bfh = hn::DemoteTo(dbfh, g0); - hn::StoreN(bfh, dbfh, out + i, remaining); - } -} - template -static HWY_INLINE hn::Vec Sigmoid(D d, hn::Vec v) { +HWY_INLINE hn::Vec Sigmoid(D d, hn::Vec v) { using VF = hn::Vec; // Chebyshev 
polynomial coefficients for rational approximation const VF c0 = hn::Set(d, 0.00949107017368078f); @@ -180,173 +148,101 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, return hn::Dot::Compute(d, a, b, size); } +namespace detail { + // = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT. -static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( - const float* HWY_RESTRICT a, size_t size) { - PROFILER_ZONE("ops.SquaredL2"); +template +float SquaredL2(const VecT* HWY_RESTRICT a, size_t size) { + using TraitsV = CompressTraits; + const hn::ScalableTag d; using V = hn::Vec; const size_t N = hn::Lanes(d); HWY_DASSERT(size >= 2 * N); HWY_DASSERT(size % (2 * N) == 0); + // TODO: use more accurate Dot V sum0 = hn::Zero(d); V sum1 = hn::Zero(d); for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { - const V a0 = hn::LoadU(d, a + i); + V a0, a1; + TraitsV::Decompress2(d, a, i, a0, a1); sum0 = hn::MulAdd(a0, a0, sum0); - const V a1 = hn::LoadU(d, a + i + N); sum1 = hn::MulAdd(a1, a1, sum1); } return hn::ReduceSum(d, hn::Add(sum0, sum1)); } -// float, float -> float; simple loop. -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( - const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight, - float* HWY_RESTRICT out, size_t size) { - PROFILER_ZONE("ops.RMSNormF"); - constexpr float kEps = 1e-6f; - float ss = SquaredL2(x, size); - ss = 1.0f / sqrtf(ss / StaticCast(size) + kEps); - for (size_t j = 0; j < size; j++) { - // Note 1.0f centering here - out[j] = (1.0f + weight[j]) * (ss * x[j]); - } +// Shared by RMSNorm and RMSNormInplace. 
+template +float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) { + const float l2 = SquaredL2(x, size); + constexpr float kEps = 1e-6f; // avoid divide by zero + return 1.0f / sqrtf(l2 / StaticCast(size) + kEps); } -// x=f, w=bf16 -> out=f -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( - const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, - float* HWY_RESTRICT out, size_t size) { - PROFILER_ZONE("ops.RMSNormBF16"); +} // namespace detail + +template +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x, + const WeightT* HWY_RESTRICT weight, + OutT* HWY_RESTRICT out, + const size_t size) { + PROFILER_FUNC; + + using TraitsV = CompressTraits; + using TraitsW = CompressTraits; + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag df; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); - constexpr float kEps = 1e-6f; - constexpr size_t kUnrollSize = 2; - - const hn::ScalableTag dbf; - const hn::Repartition df32; - const size_t N32 = hn::Lanes(df32); - - const float ss = SquaredL2(x, size); - const auto vss = - hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); - - HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0); - for (size_t i = 0; i < size; i += kUnrollSize * N32) { - const hn::Vec w16 = hn::LoadU(dbf, weight + i); - const auto w0 = hn::PromoteLowerTo(df32, w16); - const auto w1 = hn::PromoteUpperTo(df32, w16); - const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); - const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); + const VF mul = hn::Set(df, detail::RMSNormMul(x, size)); + HWY_DASSERT(size % (2 * MaxLanes(df)) == 0); + for (size_t i = 0; i < size; i += 2 * NF) { + VF v0, v1, w0, w1; + TraitsV::Decompress2(df, x, i, v0, v1); + TraitsW::Decompress2(df, weight, i, w0, w1); + const VF m0 = hn::Mul(mul, v0); + const VF m1 = hn::Mul(mul, v1); // (1+weight) * m = m + weight*m = one FMA. 
- hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i); - hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32); + const VF out0 = hn::MulAdd(m0, w0, m0); + const VF out1 = hn::MulAdd(m1, w1, m1); + detail::Store2(df, out0, out1, out + i); } } -// float -> float; simple loop. -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( - const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) { - PROFILER_ZONE("ops.RMSNormInplaceF"); - constexpr float kEps = 1e-6f; - float ss = SquaredL2(inout, size); - ss = 1.0f / sqrtf(ss / StaticCast(size) + kEps); - for (size_t j = 0; j < size; j++) { - // Note 1.0f centering here - inout[j] = (1.0f + weight[j]) * (ss * inout[j]); - } -} - -// w=bf16 -> f -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( - const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout, +// Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer. +template +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( + const WeightT* HWY_RESTRICT weight, VecT* HWY_RESTRICT inout, const size_t size) { - PROFILER_ZONE("ops.RMSNormInplaceBF"); + PROFILER_FUNC; + + using TraitsV = CompressTraits; + using TraitsW = CompressTraits; + namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag dbf; - const hn::Repartition df32; - using VF = hn::Vec; - const size_t N32 = hn::Lanes(df32); + const hn::ScalableTag df; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); - constexpr float kEps = 1e-6f; - const float ss = SquaredL2(inout, size); - const VF vss = - hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); + const VF mul = hn::Set(df, detail::RMSNormMul(inout, size)); - HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); - for (size_t i = 0; i < size; i += 2 * N32) { - const hn::Vec w16 = hn::LoadU(dbf, weight + i); - const VF w0 = hn::PromoteLowerTo(df32, w16); - const VF w1 = hn::PromoteUpperTo(df32, w16); - const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i)); - const VF m1 = hn::Mul(vss, hn::LoadU(df32, 
inout + i + N32)); - // (1+weight) * m = m + weight*m = one FMA. - hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i); - hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32); - } -} - -// f, f -> bf -// TODO(janwas): consider generic function with adapter for loading bf16/f32 -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( - const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight, - hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) { - PROFILER_ZONE("ops.RMSNormF F BF"); - namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag dbf; - const hn::Repartition df32; - using VF = hn::Vec; - const size_t N32 = hn::Lanes(df32); - - constexpr float kEps = 1e-6f; - const float ss = SquaredL2(x, size); - const VF vss = - hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); - - HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); - for (size_t i = 0; i < size; i += 2 * N32) { - const VF w0 = hn::LoadU(df32, weight + i); - const VF w1 = hn::LoadU(df32, weight + i + N32); - const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); - const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); + HWY_DASSERT(size % (2 * MaxLanes(df)) == 0); + for (size_t i = 0; i < size; i += 2 * NF) { + VF v0, v1, w0, w1; + TraitsV::Decompress2(df, inout, i, v0, v1); + TraitsW::Decompress2(df, weight, i, w0, w1); + const VF m0 = hn::Mul(mul, v0); /* v0/v1 come from Decompress2 above, not a raw LoadU of inout */ + const VF m1 = hn::Mul(mul, v1); // (1+weight) * m = m + weight*m = one FMA. const VF out0 = hn::MulAdd(m0, w0, m0); const VF out1 = hn::MulAdd(m1, w1, m1); - hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); - } -} - -// x=f, w=bf16 -> bf16 to enable W16A16 MatVec.
-static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( - const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, - hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) { - PROFILER_ZONE("ops.RMSNormF BF BF"); - namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag dbf; - const hn::Repartition df32; - using VF = hn::Vec; - const size_t N32 = hn::Lanes(df32); - - constexpr float kEps = 1e-6f; - const float ss = SquaredL2(x, size); - const VF vss = - hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); - - HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); - for (size_t i = 0; i < size; i += 2 * N32) { - const hn::Vec w16 = hn::LoadU(dbf, weight + i); - const VF w0 = hn::PromoteLowerTo(df32, w16); - const VF w1 = hn::PromoteUpperTo(df32, w16); - const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); - const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); - // (1+weight) * m = m + weight*m = one FMA. - const VF out0 = hn::MulAdd(m0, w0, m0); - const VF out1 = hn::MulAdd(m1, w1, m1); - hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); + detail::Store2(df, out0, out1, inout + i); } } @@ -410,7 +306,6 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( } } -// TODO(janwas): vectorize // `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate. 
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( const float mul, const float* HWY_RESTRICT x, size_t dim_qkv, @@ -419,24 +314,6 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( PROFILER_FUNC; HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; - for (size_t dim = 0; dim < half_dim_qkv; ++dim) { - const float theta = StaticCast(pos) * inv_timescale[dim]; - const float cos_val = cosf(theta); - const float sin_val = sinf(theta); - const float x0 = x[dim]; - const float x1 = x[dim + half_dim_qkv]; - x_out[dim] = mul * (x0 * cos_val - x1 * sin_val); - x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val); - } -} - -static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy( - const float mul, const float* HWY_RESTRICT x, size_t dim_qkv, - const float* HWY_RESTRICT inv_timescale, int pos, - float* HWY_RESTRICT x_out) { - PROFILER_FUNC; - HWY_DASSERT(dim_qkv % 2 == 0); - const size_t half_dim_qkv = dim_qkv / 2; using D = hn::ScalableTag; using V = hn::Vec; @@ -685,7 +562,7 @@ SampleArgmax(const float* probabilities, size_t vocab_size) { } template -static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution +HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution create_distribution(std::array& top_k, float temperature) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; @@ -702,7 +579,7 @@ create_distribution(std::array& top_k, float temperature) { } template -static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( +HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( const float* HWY_RESTRICT probabilities, size_t vocab_size, std::mt19937& gen, float temperature, TAcceptToken& accept_token) { static_assert(k != 0, ""); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index a8d2fee..790a637 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -27,8 +27,15 @@ #include #include +#include "compression/compress.h" // BF16 +#include "gemma/activations.h" +#include "gemma/common.h" +#include "gemma/configs.h" 
+#include "util/allocator.h" +#include "util/test_util.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" +#include "hwy/tests/hwy_gtest.h" // clang-format off #undef HWY_TARGET_INCLUDE @@ -36,14 +43,9 @@ // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" -#include "hwy/tests/test_util-inl.h" // After highway.h -#include "gemma/activations.h" -#include "gemma/common.h" -#include "gemma/configs.h" #include "ops/ops-inl.h" -#include "util/allocator.h" -#include "hwy/tests/hwy_gtest.h" +#include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -366,6 +368,23 @@ void TestSigmoid() { } } +static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy( + const float mul, const float* HWY_RESTRICT x, size_t dim_qkv, + const float* HWY_RESTRICT inv_timescale, int pos, + float* HWY_RESTRICT x_out) { + HWY_DASSERT(dim_qkv % 2 == 0); + const size_t half_dim_qkv = dim_qkv / 2; + for (size_t dim = 0; dim < half_dim_qkv; ++dim) { + const float theta = StaticCast(pos) * inv_timescale[dim]; + const float cos_val = cosf(theta); + const float sin_val = sinf(theta); + const float x0 = x[dim]; + const float x1 = x[dim + half_dim_qkv]; + x_out[dim] = mul * (x0 * cos_val - x1 * sin_val); + x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val); + } +} + void TestRopeAndMulBy() { using Config = ConfigGemma2_9B; int dim_qkv = Config::kQKVDim; @@ -392,10 +411,10 @@ void TestRopeAndMulBy() { // Assert VectorizedRope computation is same as regular rope at different pos. 
for (int pos = 1; pos < 500; pos++) { // Rope'd Q embeddings + ScalarRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, + qexpected.data()); RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, - qexpected.data()); - VectorizedRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, - qactual.data()); + qactual.data()); for (int i = 0; i < dim_qkv; ++i) { EXPECT_NEAR(qactual[i], qexpected[i], 1e-4) @@ -403,10 +422,10 @@ void TestRopeAndMulBy() { } // Rope'd K embeddings + ScalarRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, + kexpected.data()); RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, - kexpected.data()); - VectorizedRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, - kactual.data()); + kactual.data()); for (int i = 0; i < dim_qkv; ++i) { EXPECT_NEAR(kactual[i], kexpected[i], 1e-4) @@ -415,6 +434,70 @@ void TestRopeAndMulBy() { } } +template +HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) { + double sum = 0.0; + for (size_t i = 0; i < size; ++i) { + const float f = hwy::ConvertScalarTo(a[i]); + sum += f * f; + } + return static_cast(sum); +} + +// Supports bf16 and f32 inputs/outputs, which can be in-place. 
+template +HWY_NOINLINE void ScalarRMSNorm(const VecT* x, + const WeightT* HWY_RESTRICT weight, OutT* out, + size_t size) { + constexpr float kEps = 1e-6f; + float ss = ScalarSquaredL2(x, size); + ss = 1.0f / sqrtf(ss / StaticCast(size) + kEps); + for (size_t j = 0; j < size; j++) { + const float v = hwy::ConvertScalarTo(x[j]); + const float w = hwy::ConvertScalarTo(weight[j]); + // Note 1.0f centering here + out[j] = hwy::ConvertScalarTo((1.0f + w) * (ss * v)); + } +} + +template +void TestRMSNorm(hwy::RandomState& rng) { + constexpr size_t kSize = 128; + VecT vec[kSize]; + WeightT weight[kSize]; + OutT expected[kSize]; + OutT actual[kSize]; + + for (size_t i = 0; i < kSize; ++i) { + vec[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + weight[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + } + + ScalarRMSNorm(vec, weight, expected, kSize); + RMSNorm(vec, weight, actual, kSize); + + for (size_t i = 0; i < kSize; i++) { + const float e = hwy::ConvertScalarTo(expected[i]); + const float a = hwy::ConvertScalarTo(actual[i]); + if (!IsNear(e, a, 1e-5f)) { + HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(VecT()), + TypeName(WeightT()), TypeName(OutT()), i, e, a); + } + } +} + +void TestAllRMSNorm() { + hwy::RandomState rng; + TestRMSNorm(rng); + TestRMSNorm(rng); + TestRMSNorm(rng); + TestRMSNorm(rng); + TestRMSNorm(rng); + TestRMSNorm(rng); + TestRMSNorm(rng); + TestRMSNorm(rng); +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp @@ -432,6 +515,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid); HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm); HWY_AFTER_TEST(); } // namespace gcpp diff --git a/compression/test_util.h b/util/test_util.h similarity index 92% rename from compression/test_util.h rename to util/test_util.h index 
745b00f..ead1874 100644 --- a/compression/test_util.h +++ b/util/test_util.h @@ -13,8 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_ -#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_ +#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_ +#define THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_ #include #include @@ -24,7 +24,6 @@ #include "hwy/base.h" // IWYU pragma: begin_exports -#include "compression/distortion.h" #include "hwy/stats.h" #include "hwy/tests/test_util.h" // RandomState // IWYU pragma: end_exports @@ -73,4 +72,4 @@ HWY_INLINE void VerifyGaussian(hwy::Stats& stats) { } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_ +#endif // THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_