Avoid duplication of RMSNorm, support all activation/weight types

Add test for RMSNorm
Rename VectorizedRopeAndMulBy -> RopeAndMulBy

Move test_util to util/

PiperOrigin-RevId: 668332927
This commit is contained in:
Jan Wassenberg 2024-08-28 01:25:52 -07:00 committed by Copybara-Service
parent 3c17911875
commit 4033ed9e78
12 changed files with 211 additions and 229 deletions

View File

@ -28,6 +28,16 @@ cc_library(
], ],
) )
# Header-only test helper target exposing util/test_util.h.
# Its only dependencies are the Highway targets listed below.
cc_library(
name = "test_util",
hdrs = ["util/test_util.h"],
deps = [
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:stats",
],
)
cc_library( cc_library(
name = "threading", name = "threading",
hdrs = ["util/threading.h"], hdrs = ["util/threading.h"],
@ -101,7 +111,9 @@ cc_test(
":common", ":common",
":gemma_lib", ":gemma_lib",
":ops", ":ops",
":test_util",
"@googletest//:gtest_main", # buildcleaner: keep "@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:hwy_test_util", "@hwy//:hwy_test_util",
"@hwy//:nanobenchmark", #buildcleaner: keep "@hwy//:nanobenchmark", #buildcleaner: keep

View File

@ -48,7 +48,6 @@ set(SOURCES
compression/nuq-inl.h compression/nuq-inl.h
compression/sfp.h compression/sfp.h
compression/sfp-inl.h compression/sfp-inl.h
compression/test_util.h
compression/weights_raw.h compression/weights_raw.h
backprop/activations.h backprop/activations.h
backprop/backward.cc backprop/backward.cc
@ -107,6 +106,7 @@ set(SOURCES
util/allocator.h util/allocator.h
util/app.h util/app.h
util/args.h util/args.h
util/test_util.h
util/threading.h util/threading.h
) )

View File

@ -136,9 +136,8 @@ static HWY_NOINLINE void RMSNormVJP(
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim; const size_t offset = pos * model_dim;
constexpr float eps = 1e-6f; const float ss = detail::RMSNormMul(x + offset, model_dim);
float ss = SquaredL2(x + offset, model_dim);
ss = 1.0f / sqrtf(ss / StaticCast<float>(model_dim) + eps);
for (size_t i = 0; i < model_dim; ++i) { for (size_t i = 0; i < model_dim; ++i) {
grad_w[i] += v[offset + i] * x[offset + i] * ss; grad_w[i] += v[offset + i] * x[offset + i] * ss;
} }

View File

@ -64,25 +64,14 @@ cc_test(
srcs = ["distortion_test.cc"], srcs = ["distortion_test.cc"],
deps = [ deps = [
":distortion", ":distortion",
":test_util",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"//:test_util",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:hwy_test_util", "@hwy//:hwy_test_util",
"@hwy//:nanobenchmark", # Unpredictable1 "@hwy//:nanobenchmark", # Unpredictable1
], ],
) )
cc_library(
name = "test_util",
hdrs = ["test_util.h"],
deps = [
":distortion",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:stats",
],
)
cc_library( cc_library(
name = "sfp", name = "sfp",
hdrs = ["sfp.h"], hdrs = ["sfp.h"],
@ -102,10 +91,11 @@ cc_test(
# for test_suite. # for test_suite.
tags = ["hwy_ops_test"], tags = ["hwy_ops_test"],
deps = [ deps = [
":distortion",
":sfp", ":sfp",
":test_util",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"//:ops", "//:ops",
"//:test_util",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:hwy_test_util", "@hwy//:hwy_test_util",
"@hwy//:nanobenchmark", "@hwy//:nanobenchmark",
@ -133,10 +123,11 @@ cc_test(
# for test_suite. # for test_suite.
tags = ["hwy_ops_test"], tags = ["hwy_ops_test"],
deps = [ deps = [
":distortion",
":nuq", ":nuq",
":sfp", ":sfp",
":test_util",
"@googletest//:gtest_main", "@googletest//:gtest_main",
"//:test_util",
"@hwy//:hwy", "@hwy//:hwy",
"@hwy//:hwy_test_util", "@hwy//:hwy_test_util",
"@hwy//:nanobenchmark", "@hwy//:nanobenchmark",

View File

@ -51,6 +51,25 @@ namespace gcpp {
namespace HWY_NAMESPACE { namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
namespace detail {

// Store2 writes two f32 vectors to `out`, `v0` first and `v1` after it.
// Overloading on the destination element type lets RMSNorm and RMSNormInplace
// share one implementation for f32 and bf16 outputs.

// f32 destination: two unaligned stores; `v1` begins Lanes(df) elements after
// the start of `out`.
template <class DF, HWY_IF_F32_D(DF)>
void Store2(DF df, hn::Vec<DF> v0, hn::Vec<DF> v1, float* HWY_RESTRICT out) {
  float* const upper = out + hn::Lanes(df);
  hn::StoreU(v0, df, out);
  hn::StoreU(v1, df, upper);
}

// bf16 destination: both inputs are demoted into a single bf16 vector via
// OrderedDemote2To and written with one unaligned store.
template <class DF, HWY_IF_F32_D(DF)>
void Store2(DF df, hn::Vec<DF> v0, hn::Vec<DF> v1, BF16* HWY_RESTRICT out) {
  const hn::Repartition<BF16, DF> dbf;
  const hn::Vec<decltype(dbf)> packed = hn::OrderedDemote2To(dbf, v0, v1);
  hn::StoreU(packed, dbf, out);
}

}  // namespace detail
// Enables generic code independent of compression type. // Enables generic code independent of compression type.
template <typename T> // primary, must specialize template <typename T> // primary, must specialize
struct CompressTraits {}; struct CompressTraits {};

View File

@ -17,7 +17,7 @@
#include <stdio.h> #include <stdio.h>
#include "compression/test_util.h" #include "util/test_util.h"
#include "hwy/nanobenchmark.h" #include "hwy/nanobenchmark.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ #include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ

View File

@ -27,7 +27,8 @@
#include <algorithm> // std::shuffle #include <algorithm> // std::shuffle
#include <random> #include <random>
#include "compression/test_util.h" #include "compression/distortion.h"
#include "util/test_util.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/tests/test_util.h" #include "hwy/tests/test_util.h"

View File

@ -26,11 +26,11 @@
#include <set> #include <set>
#include "compression/test_util.h" #include "compression/distortion.h"
#include "util/test_util.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/timer.h" #include "hwy/timer.h"
// clang-format off // clang-format off
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT #define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT

View File

@ -228,7 +228,7 @@ class GemmaAttention {
Rope(qk_out, kQKVDim / 2, inv_timescale, pos); Rope(qk_out, kQKVDim / 2, inv_timescale, pos);
MulByConst(mul, qk_out, kQKVDim); MulByConst(mul, qk_out, kQKVDim);
} else { } else {
VectorizedRopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out); RopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out);
} }
} }

View File

@ -30,7 +30,6 @@
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h" #include "hwy/detect_targets.h"
#include "hwy/profiler.h" #include "hwy/profiler.h"
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_ #endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
// Include guard for (potentially) SIMD code. // Include guard for (potentially) SIMD code.
@ -41,6 +40,7 @@
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE #define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
#endif #endif
#include "compression/compress-inl.h"
#include "hwy/contrib/algo/transform-inl.h" #include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/dot/dot-inl.h" #include "hwy/contrib/dot/dot-inl.h"
#include "hwy/contrib/math/math-inl.h" #include "hwy/contrib/math/math-inl.h"
@ -62,7 +62,7 @@ StaticCast(From from) noexcept {
} }
template <class D, HWY_IF_F32_D(D)> template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) { HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
const hn::Vec<D> kMul = hn::Set(d, 0.044715f); const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f); const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
const hn::Vec<D> kHalf = hn::Set(d, 0.5f); const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
@ -83,40 +83,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
[](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); }); [](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); });
} }
// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul,
hwy::bfloat16_t* HWY_RESTRICT out, size_t size) {
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
const size_t NF = hn::Lanes(df);
using VF = hn::Vec<decltype(df)>;
size_t i = 0;
if (size >= 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
const VF mul0 = hn::LoadU(df, mul + i);
const VF mul1 = hn::LoadU(df, mul + i + NF);
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF)));
const hn::Vec<decltype(dbf)> bf = hn::OrderedDemote2To(dbf, g0, g1);
hn::StoreU(bf, dbf, out + i);
}
}
if (i != size) {
const size_t remaining = size - i;
const VF mul0 = hn::LoadN(df, mul + i, remaining);
const VF g0 =
hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining)));
const hn::Half<decltype(dbf)> dbfh;
const hn::Vec<decltype(dbfh)> bfh = hn::DemoteTo(dbfh, g0);
hn::StoreN(bfh, dbfh, out + i, remaining);
}
}
template <class D, HWY_IF_F32_D(D)> template <class D, HWY_IF_F32_D(D)>
static HWY_INLINE hn::Vec<D> Sigmoid(D d, hn::Vec<D> v) { HWY_INLINE hn::Vec<D> Sigmoid(D d, hn::Vec<D> v) {
using VF = hn::Vec<D>; using VF = hn::Vec<D>;
// Chebyshev polynomial coefficients for rational approximation // Chebyshev polynomial coefficients for rational approximation
const VF c0 = hn::Set(d, 0.00949107017368078f); const VF c0 = hn::Set(d, 0.00949107017368078f);
@ -180,173 +148,101 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
return hn::Dot::Compute<kAssumptions>(d, a, b, size); return hn::Dot::Compute<kAssumptions>(d, a, b, size);
} }
namespace detail {
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT. // = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( template <typename VecT>
const float* HWY_RESTRICT a, size_t size) { float SquaredL2(const VecT* HWY_RESTRICT a, size_t size) {
PROFILER_ZONE("ops.SquaredL2"); using TraitsV = CompressTraits<VecT>;
const hn::ScalableTag<float> d; const hn::ScalableTag<float> d;
using V = hn::Vec<decltype(d)>; using V = hn::Vec<decltype(d)>;
const size_t N = hn::Lanes(d); const size_t N = hn::Lanes(d);
HWY_DASSERT(size >= 2 * N); HWY_DASSERT(size >= 2 * N);
HWY_DASSERT(size % (2 * N) == 0); HWY_DASSERT(size % (2 * N) == 0);
// TODO: use more accurate Dot
V sum0 = hn::Zero(d); V sum0 = hn::Zero(d);
V sum1 = hn::Zero(d); V sum1 = hn::Zero(d);
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
const V a0 = hn::LoadU(d, a + i); V a0, a1;
TraitsV::Decompress2(d, a, i, a0, a1);
sum0 = hn::MulAdd(a0, a0, sum0); sum0 = hn::MulAdd(a0, a0, sum0);
const V a1 = hn::LoadU(d, a + i + N);
sum1 = hn::MulAdd(a1, a1, sum1); sum1 = hn::MulAdd(a1, a1, sum1);
} }
return hn::ReduceSum(d, hn::Add(sum0, sum1)); return hn::ReduceSum(d, hn::Add(sum0, sum1));
} }
// float, float -> float; simple loop. // Shared by RMSNorm and RMSNormInplace.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( template <typename VecT>
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight, float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) {
float* HWY_RESTRICT out, size_t size) { const float l2 = SquaredL2(x, size);
PROFILER_ZONE("ops.RMSNormF"); constexpr float kEps = 1e-6f; // avoid divide by zero
constexpr float kEps = 1e-6f; return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + weight[j]) * (ss * x[j]);
}
} }
// x=f, w=bf16 -> out=f } // namespace detail
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
float* HWY_RESTRICT out, size_t size) {
PROFILER_ZONE("ops.RMSNormBF16");
namespace hn = hwy::HWY_NAMESPACE;
constexpr float kEps = 1e-6f; template <typename VecT, typename WeightT, typename OutT>
constexpr size_t kUnrollSize = 2; HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
const WeightT* HWY_RESTRICT weight,
const hn::ScalableTag<hwy::bfloat16_t> dbf; OutT* HWY_RESTRICT out,
const hn::Repartition<float, decltype(dbf)> df32;
const size_t N32 = hn::Lanes(df32);
const float ss = SquaredL2(x, size);
const auto vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += kUnrollSize * N32) {
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
const auto w0 = hn::PromoteLowerTo(df32, w16);
const auto w1 = hn::PromoteUpperTo(df32, w16);
const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
// (1+weight) * m = m + weight*m = one FMA.
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i);
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32);
}
}
// float -> float; simple loop.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
PROFILER_ZONE("ops.RMSNormInplaceF");
constexpr float kEps = 1e-6f;
float ss = SquaredL2(inout, size);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
}
}
// w=bf16 -> f
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout,
const size_t size) { const size_t size) {
PROFILER_ZONE("ops.RMSNormInplaceBF"); PROFILER_FUNC;
using TraitsV = CompressTraits<VecT>;
using TraitsW = CompressTraits<WeightT>;
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<hwy::bfloat16_t> dbf; const hn::ScalableTag<float> df;
const hn::Repartition<float, decltype(dbf)> df32; using VF = hn::Vec<decltype(df)>;
using VF = hn::Vec<decltype(df32)>; const size_t NF = hn::Lanes(df);
const size_t N32 = hn::Lanes(df32);
constexpr float kEps = 1e-6f; const VF mul = hn::Set(df, detail::RMSNormMul(x, size));
const float ss = SquaredL2(inout, size);
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) { for (size_t i = 0; i < size; i += 2 * NF) {
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i); VF v0, v1, w0, w1;
const VF w0 = hn::PromoteLowerTo(df32, w16); TraitsV::Decompress2(df, x, i, v0, v1);
const VF w1 = hn::PromoteUpperTo(df32, w16); TraitsW::Decompress2(df, weight, i, w0, w1);
const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i)); const VF m0 = hn::Mul(mul, v0);
const VF m1 = hn::Mul(vss, hn::LoadU(df32, inout + i + N32)); const VF m1 = hn::Mul(mul, v1);
// (1+weight) * m = m + weight*m = one FMA.
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i);
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32);
}
}
// f, f -> bf
// TODO(janwas): consider generic function with adapter for loading bf16/f32
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
PROFILER_ZONE("ops.RMSNormF F BF");
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<hwy::bfloat16_t> dbf;
const hn::Repartition<float, decltype(dbf)> df32;
using VF = hn::Vec<decltype(df32)>;
const size_t N32 = hn::Lanes(df32);
constexpr float kEps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
const VF w0 = hn::LoadU(df32, weight + i);
const VF w1 = hn::LoadU(df32, weight + i + N32);
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
// (1+weight) * m = m + weight*m = one FMA. // (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0); const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1); const VF out1 = hn::MulAdd(m1, w1, m1);
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); detail::Store2(df, out0, out1, out + i);
} }
} }
// x=f, w=bf16 -> bf16 to enable W16A16 MatVec. // Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( template <typename VecT, typename WeightT>
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) { const WeightT* HWY_RESTRICT weight, VecT* HWY_RESTRICT inout,
PROFILER_ZONE("ops.RMSNormF BF BF"); const size_t size) {
PROFILER_FUNC;
using TraitsV = CompressTraits<VecT>;
using TraitsW = CompressTraits<WeightT>;
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<hwy::bfloat16_t> dbf; const hn::ScalableTag<float> df;
const hn::Repartition<float, decltype(dbf)> df32; using VF = hn::Vec<decltype(df)>;
using VF = hn::Vec<decltype(df32)>; const size_t NF = hn::Lanes(df);
const size_t N32 = hn::Lanes(df32);
constexpr float kEps = 1e-6f; const VF mul = hn::Set(df, detail::RMSNormMul(inout, size));
const float ss = SquaredL2(x, size);
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) { for (size_t i = 0; i < size; i += 2 * NF) {
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i); VF v0, v1, w0, w1;
const VF w0 = hn::PromoteLowerTo(df32, w16); TraitsV::Decompress2(df, inout, i, v0, v1);
const VF w1 = hn::PromoteUpperTo(df32, w16); TraitsW::Decompress2(df, weight, i, w0, w1);
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); const VF m0 = hn::Mul(mul, hn::LoadU(df, inout + i));
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); const VF m1 = hn::Mul(mul, hn::LoadU(df, inout + i + NF));
// (1+weight) * m = m + weight*m = one FMA. // (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0); const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1); const VF out1 = hn::MulAdd(m1, w1, m1);
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); detail::Store2(df, out0, out1, inout + i);
} }
} }
@ -410,7 +306,6 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
} }
} }
// TODO(janwas): vectorize
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate. // `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
const float mul, const float* HWY_RESTRICT x, size_t dim_qkv, const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
@ -419,24 +314,6 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
PROFILER_FUNC; PROFILER_FUNC;
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float theta = StaticCast<float>(pos) * inv_timescale[dim];
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
const float x1 = x[dim + half_dim_qkv];
x_out[dim] = mul * (x0 * cos_val - x1 * sin_val);
x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
}
}
static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
const float* HWY_RESTRICT inv_timescale, int pos,
float* HWY_RESTRICT x_out) {
PROFILER_FUNC;
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
using D = hn::ScalableTag<float>; using D = hn::ScalableTag<float>;
using V = hn::Vec<D>; using V = hn::Vec<D>;
@ -685,7 +562,7 @@ SampleArgmax(const float* probabilities, size_t vocab_size) {
} }
template <size_t k> template <size_t k>
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int> HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
create_distribution(std::array<float, k>& top_k, float temperature) { create_distribution(std::array<float, k>& top_k, float temperature) {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>; using D = hn::ScalableTag<float>;
@ -702,7 +579,7 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
} }
template <size_t k, typename TAcceptToken> template <size_t k, typename TAcceptToken>
static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
const float* HWY_RESTRICT probabilities, size_t vocab_size, const float* HWY_RESTRICT probabilities, size_t vocab_size,
std::mt19937& gen, float temperature, TAcceptToken& accept_token) { std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
static_assert(k != 0, ""); static_assert(k != 0, "");

View File

@ -27,8 +27,15 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "compression/compress.h" // BF16
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "util/test_util.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
// clang-format off // clang-format off
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
@ -36,14 +43,9 @@
// clang-format on // clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h" #include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h // After highway.h
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "ops/ops-inl.h" #include "ops/ops-inl.h"
#include "util/allocator.h" #include "hwy/tests/test_util-inl.h"
#include "hwy/tests/hwy_gtest.h"
HWY_BEFORE_NAMESPACE(); HWY_BEFORE_NAMESPACE();
namespace gcpp { namespace gcpp {
@ -366,6 +368,23 @@ void TestSigmoid() {
} }
} }
// Scalar reference for RopeAndMulBy. For each i in [0, dim_qkv / 2), the pair
// (x[i], x[i + dim_qkv / 2]) is rotated by the angle pos * inv_timescale[i],
// scaled by `mul`, and written to the same two positions of `x_out`.
// `dim_qkv` must be even (checked in debug builds).
static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
    const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
    const float* HWY_RESTRICT inv_timescale, int pos,
    float* HWY_RESTRICT x_out) {
  HWY_DASSERT(dim_qkv % 2 == 0);
  const size_t half = dim_qkv / 2;
  // The position is loop-invariant, so convert it once.
  const float fpos = StaticCast<float>(pos);
  for (size_t lo = 0; lo < half; ++lo) {
    const size_t hi = lo + half;
    const float angle = fpos * inv_timescale[lo];
    const float c = cosf(angle);
    const float s = sinf(angle);
    const float first = x[lo];
    const float second = x[hi];
    x_out[lo] = mul * (first * c - second * s);
    x_out[hi] = mul * (first * s + second * c);
  }
}
void TestRopeAndMulBy() { void TestRopeAndMulBy() {
using Config = ConfigGemma2_9B<float>; using Config = ConfigGemma2_9B<float>;
int dim_qkv = Config::kQKVDim; int dim_qkv = Config::kQKVDim;
@ -392,9 +411,9 @@ void TestRopeAndMulBy() {
// Assert VectorizedRope computation is same as regular rope at different pos. // Assert VectorizedRope computation is same as regular rope at different pos.
for (int pos = 1; pos < 500; pos++) { for (int pos = 1; pos < 500; pos++) {
// Rope'd Q embeddings // Rope'd Q embeddings
RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, ScalarRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
qexpected.data()); qexpected.data());
VectorizedRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
qactual.data()); qactual.data());
for (int i = 0; i < dim_qkv; ++i) { for (int i = 0; i < dim_qkv; ++i) {
@ -403,9 +422,9 @@ void TestRopeAndMulBy() {
} }
// Rope'd K embeddings // Rope'd K embeddings
RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, ScalarRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
kexpected.data()); kexpected.data());
VectorizedRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos, RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
kactual.data()); kactual.data());
for (int i = 0; i < dim_qkv; ++i) { for (int i = 0; i < dim_qkv; ++i) {
@ -415,6 +434,70 @@ void TestRopeAndMulBy() {
} }
} }
// Reference sum of squares over a[0, size). Each element is converted to float
// and squared in float; the squares are accumulated in a double and the total
// is returned as float.
template <typename T>
HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) {
  double total = 0.0;
  const T* const end = a + size;
  for (const T* p = a; p != end; ++p) {
    const float value = hwy::ConvertScalarTo<float>(*p);
    total += value * value;
  }
  return static_cast<float>(total);
}
// Scalar reference for RMSNorm: out[i] = (1 + weight[i]) * scale * x[i], with
// scale = 1 / sqrt(ScalarSquaredL2(x) / size + kEps).
// Supports bf16 and f32 inputs/outputs; `x` and `out` are not HWY_RESTRICT, so
// they can be the same buffer (in-place).
template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE void ScalarRMSNorm(const VecT* x,
                                const WeightT* HWY_RESTRICT weight, OutT* out,
                                size_t size) {
  constexpr float kEps = 1e-6f;
  const float mean_square = ScalarSquaredL2(x, size) / StaticCast<float>(size);
  const float scale = 1.0f / sqrtf(mean_square + kEps);
  for (size_t i = 0; i < size; ++i) {
    const float value = hwy::ConvertScalarTo<float>(x[i]);
    // Weights are centered: the effective multiplier is 1.0f + weight.
    const float centered = 1.0f + hwy::ConvertScalarTo<float>(weight[i]);
    out[i] = hwy::ConvertScalarTo<OutT>(centered * (scale * value));
  }
}
// Checks RMSNorm against ScalarRMSNorm for one combination of activation
// (VecT), weight (WeightT) and output (OutT) element types. Inputs come from
// RandomGaussian(rng); the first element whose float values fail
// IsNear(expected, actual, 1e-5f) aborts with the type names, index and values.
template <typename VecT, typename WeightT, typename OutT>
void TestRMSNorm(hwy::RandomState& rng) {
  constexpr size_t kSize = 128;

  VecT vec[kSize];
  WeightT weight[kSize];
  // Draw order (vec[idx], then weight[idx]) fixes how `rng` is consumed.
  for (size_t idx = 0; idx < kSize; ++idx) {
    vec[idx] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
    weight[idx] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
  }

  OutT expected[kSize];
  OutT actual[kSize];
  ScalarRMSNorm(vec, weight, expected, kSize);
  RMSNorm(vec, weight, actual, kSize);

  for (size_t idx = 0; idx < kSize; ++idx) {
    const float want = hwy::ConvertScalarTo<float>(expected[idx]);
    const float got = hwy::ConvertScalarTo<float>(actual[idx]);
    if (IsNear(want, got, 1e-5f)) continue;
    HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(VecT()),
              TypeName(WeightT()), TypeName(OutT()), idx, want, got);
  }
}
// Runs TestRMSNorm for all eight combinations of f32/BF16 element types.
// Template arguments are <VecT, WeightT, OutT>, i.e. activation, weight and
// output type. All calls share one RandomState, so the call order determines
// which random inputs each combination sees.
void TestAllRMSNorm() {
hwy::RandomState rng;
// f32 activations.
TestRMSNorm<float, float, float>(rng);
TestRMSNorm<float, float, BF16>(rng);
TestRMSNorm<float, BF16, float>(rng);
TestRMSNorm<float, BF16, BF16>(rng);
// BF16 activations.
TestRMSNorm<BF16, float, float>(rng);
TestRMSNorm<BF16, float, BF16>(rng);
TestRMSNorm<BF16, BF16, float>(rng);
TestRMSNorm<BF16, BF16, BF16>(rng);
}
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
} // namespace gcpp } // namespace gcpp
@ -432,6 +515,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid); HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy); HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
HWY_AFTER_TEST(); HWY_AFTER_TEST();
} // namespace gcpp } // namespace gcpp

View File

@ -13,8 +13,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_ #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_ #define THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
@ -24,7 +24,6 @@
#include "hwy/base.h" #include "hwy/base.h"
// IWYU pragma: begin_exports // IWYU pragma: begin_exports
#include "compression/distortion.h"
#include "hwy/stats.h" #include "hwy/stats.h"
#include "hwy/tests/test_util.h" // RandomState #include "hwy/tests/test_util.h" // RandomState
// IWYU pragma: end_exports // IWYU pragma: end_exports
@ -73,4 +72,4 @@ HWY_INLINE void VerifyGaussian(hwy::Stats& stats) {
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_ #endif // THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_