From 85cac13fb106b33232c4d31fd2a972fe99c1e970 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 19 Jul 2024 11:21:16 -0700 Subject: [PATCH] Split up ops.h into ops/ops-inl and matmul-inl PiperOrigin-RevId: 654068303 --- BUILD.bazel | 24 +- CMakeLists.txt | 3 +- backprop/backward-inl.h | 3 +- backprop/backward_test.cc | 2 +- backprop/forward-inl.h | 3 +- gemma/gemma-inl.h | 3 +- gemma/ops.h => ops/matmul-inl.h | 633 +---------------------- gemma/ops_test.cc => ops/matmul_test.cc | 366 ++----------- ops/ops-inl.h | 652 ++++++++++++++++++++++++ ops/ops_test.cc | 384 ++++++++++++++ 10 files changed, 1105 insertions(+), 968 deletions(-) rename gemma/ops.h => ops/matmul-inl.h (57%) rename gemma/ops_test.cc => ops/matmul_test.cc (63%) create mode 100644 ops/ops-inl.h create mode 100644 ops/ops_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index e652d72..6ee5ec6 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -22,7 +22,10 @@ exports_files(["LICENSE"]) cc_library( name = "ops", - hdrs = ["gemma/ops.h"], + textual_hdrs = [ + "ops/ops-inl.h", + "ops/matmul-inl.h", + ], deps = [ "//compression:compress", "//compression:sfp", @@ -40,7 +43,24 @@ cc_test( name = "ops_test", size = "small", timeout = "long", - srcs = ["gemma/ops_test.cc"], + srcs = ["ops/ops_test.cc"], + local_defines = ["HWY_IS_TEST"], + # for test_suite. + tags = ["hwy_ops_test"], + deps = [ + ":ops", + "@googletest//:gtest_main", # buildcleaner: keep + "@hwy//:hwy", + "@hwy//:hwy_test_util", + "@hwy//:nanobenchmark", + ], +) + +cc_test( + name = "matmul_test", + size = "small", + timeout = "long", + srcs = ["ops/matmul_test.cc"], local_defines = ["HWY_IS_TEST"], # for test_suite. tags = ["hwy_ops_test"], diff --git a/CMakeLists.txt b/CMakeLists.txt index bc6dc14..b1347e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -93,11 +93,12 @@ set(SOURCES gemma/instantiations/tiny_sfp.cc gemma/kv_cache.cc gemma/kv_cache.h - gemma/ops.h gemma/tokenizer.cc gemma/tokenizer.h gemma/weights.cc gemma/weights.h + ops/matmul-inl.h + ops/ops-inl.h util/app.h util/args.h ) diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 492e4e7..62e2d13 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -41,7 +41,8 @@ #define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE #endif -#include "gemma/ops.h" +#include "ops/matmul-inl.h" +#include "ops/ops-inl.h" #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index 1e7dbd1..bf8cf5f 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -45,7 +45,7 @@ // After highway.h #include "backprop/backward-inl.h" #include "backprop/forward-inl.h" -#include "gemma/ops.h" +#include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index 636c23c..11ac050 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -39,7 +39,8 @@ #define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE #endif -#include "gemma/ops.h" +#include "ops/matmul-inl.h" +#include "ops/ops-inl.h" #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 864a59f..2dd1d63 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -37,9 +37,10 @@ #include "gemma/common.h" #include "gemma/configs.h" #include "gemma/gemma.h" -#include "gemma/ops.h" #include "gemma/weights.h" // Placeholder for internal test4, do not remove +#include "ops/matmul-inl.h" +#include "ops/ops-inl.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/bit_set.h" diff --git a/gemma/ops.h b/ops/matmul-inl.h similarity index 57% rename from gemma/ops.h rename to ops/matmul-inl.h index 769deae..3a912f0 100644 --- a/gemma/ops.h +++ b/ops/matmul-inl.h @@ -14,8 +14,8 @@ // limitations under the License. // Include guard for non-SIMD code. -#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_ -#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_ +#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_ #include #include @@ -24,7 +24,6 @@ #include #include -#include #include // std::enable_if_t #include "compression/compress.h" @@ -34,18 +33,17 @@ #include "hwy/detect_targets.h" #include "hwy/profiler.h" -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_ +#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_ // Include guard for (potentially) SIMD code. -#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE -#undef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE +#if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE #else -#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE +#define THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE #endif #include "compression/compress-inl.h" -#include "hwy/contrib/algo/transform-inl.h" #include "hwy/contrib/dot/dot-inl.h" #include "hwy/contrib/math/math-inl.h" #include "hwy/contrib/matvec/matvec-inl.h" @@ -60,40 +58,6 @@ HWY_INLINE constexpr size_t MaxCols() { return 2048; } -template -HWY_INLINE constexpr std::enable_if_t< - std::is_arithmetic_v && std::is_arithmetic_v, To> -StaticCast(From from) noexcept { - if constexpr (std::is_unsigned_v && std::is_floating_point_v) - return static_cast( - static_cast>(from)); - else - return static_cast(from); -} - -// For testing. -template -void AssertClose(const MatT* HWY_RESTRICT expected, - const MatT* HWY_RESTRICT actual, size_t num) { - for (size_t idx = 0; idx < num; idx++) { - const double expected_value = hwy::ConvertScalarTo(expected[idx]); - const double actual_value = hwy::ConvertScalarTo(actual[idx]); - - const double magnitude = std::abs(expected_value); - - const double tolerance = - 256.0 * hwy::ConvertScalarTo(hwy::Epsilon()) * - HWY_MAX(magnitude, 1.0); - - if (!(expected_value - tolerance <= actual_value && - actual_value <= expected_value + tolerance)) { - fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx, - expected_value, idx, actual_value); - HWY_ASSERT(0); - } - } -} - template HWY_INLINE constexpr size_t RowsPerStrip() { // Aim for 128 work items to reduce pool overhead. Must be at least one @@ -726,112 +690,6 @@ HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs, even_odd, out, pool); } -template -static HWY_INLINE hn::Vec Gelu(D d, hn::Vec v) { - const hn::Vec kMul = hn::Set(d, 0.044715f); - const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); - const hn::Vec kHalf = hn::Set(d, 0.5f); - - // tanh approximation matches training. - const hn::Vec v3 = hn::Mul(hn::Mul(v, v), v); - const hn::Vec arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v)); - // 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5). - const hn::Vec cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf); - return hn::Mul(v, cdf); -} - -static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, - size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - hn::Transform(D(), x, size, - [](D d, hn::Vec v) HWY_ATTR { return Gelu(d, v); }); -} - -// out[i] = BF(mul[i] * Gelu(gelu_in[i])) -static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16( - const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul, - hwy::bfloat16_t* HWY_RESTRICT out, size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag df; - const hn::Repartition dbf; - const size_t NF = hn::Lanes(df); - using VF = hn::Vec; - - size_t i = 0; - if (size >= 2 * NF) { - for (; i <= size - 2 * NF; i += 2 * NF) { - const VF mul0 = hn::LoadU(df, mul + i); - const VF mul1 = hn::LoadU(df, mul + i + NF); - const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i))); - const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF))); - const hn::Vec bf = hn::OrderedDemote2To(dbf, g0, g1); - hn::StoreU(bf, dbf, out + i); - } - } - if (i != size) { - const size_t remaining = size - i; - const VF mul0 = hn::LoadN(df, mul + i, remaining); - const VF g0 = - hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining))); - const hn::Half dbfh; - const hn::Vec bfh = hn::DemoteTo(dbfh, g0); - hn::StoreN(bfh, dbfh, out + i, remaining); - } -} - -template -static HWY_INLINE hn::Vec Sigmoid(D d, hn::Vec v) { - using VF = hn::Vec; - // Chebyshev polynomial coefficients for rational approximation - const VF c0 = hn::Set(d, 0.00949107017368078f); - const VF c1 = hn::Set(d, 0.0654858946800232f); - const VF c2 = hn::Set(d, 0.231547489762306f - 0.00949107017368078f); - const VF c3 = hn::Set(d, 0.530778527259827f); - const VF c4 = hn::Set(d, 0.855334937572479f); - const VF c5 = hn::Set(d, 0.500000894069672f); - - const VF d0 = hn::Set(d, 0.130970627069473f); - const VF d1 = hn::Set(d, 3.99615288415589e-07f); - const VF d2 = hn::Set(d, 1.06155431270599f - 0.130970627069473f); - const VF d3 = hn::Set(d, 1.35144250634767e-06f); - const VF d4 = hn::Set(d, 1); - - // The approximation works in range -12..12, but the input value is clamped - // in -11.5..11.5 since the approximation slightly overshoots after that. - // The function is nearly 0 for input values below -11.5 and nearly 1 for - // input values above 11.5. - const VF invtwelve = hn::Set(d, 1.0f / 12.0f); - const VF lo = hn::Set(d, -11.5f); - const VF hi = hn::Set(d, 11.5f); - - VF f = hn::Clamp(v, lo, hi); - f = hn::Mul(f, invtwelve); - VF f2 = hn::Add(f, f); - - VF a1 = hn::MulAdd(f2, c0, c1); - VF a2 = hn::MulAdd(f2, a1, c2); - VF a3 = hn::Sub(hn::MulAdd(f2, a2, c3), a1); - VF a4 = hn::Sub(hn::MulAdd(f2, a3, c4), a2); - VF f0 = hn::Sub(hn::MulAdd(f, a4, c5), a3); - - VF b1 = hn::MulAdd(f2, d0, d1); - VF b2 = hn::MulAdd(f2, b1, d2); - VF b3 = hn::Sub(hn::MulAdd(f2, b2, d3), b1); - VF f1 = hn::Sub(hn::MulAdd(f, b3, d4), b2); - - return hn::Div(f0, f1); -} - -// Sigmoid using the logistic function 1 / (1 + exp(-x[i])) -static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x, - size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - hn::Transform(D(), x, size, - [](D d, hn::Vec v) HWY_ATTR { return Sigmoid(d, v); }); -} - // Two matrices, same vector template @@ -897,483 +755,6 @@ HWY_NOINLINE void TwoMatVec(const ArrayT& mat0, const ArrayT& mat1, out0, out1, pool); } -static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, - const float* HWY_RESTRICT b, - size_t size) { - const hn::ScalableTag d; - HWY_DASSERT(size >= hn::Lanes(d)); - HWY_DASSERT(size % hn::Lanes(d) == 0); - constexpr int kAssumptions = - hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector; - return hn::Dot::Compute(d, a, b, size); -} - -// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT. -static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( - const float* HWY_RESTRICT a, size_t size) { - const hn::ScalableTag d; - using V = hn::Vec; - const size_t N = hn::Lanes(d); - HWY_DASSERT(size >= 2 * N); - HWY_DASSERT(size % (2 * N) == 0); - - V sum0 = hn::Zero(d); - V sum1 = hn::Zero(d); - for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { - const V a0 = hn::LoadU(d, a + i); - sum0 = hn::MulAdd(a0, a0, sum0); - const V a1 = hn::LoadU(d, a + i + N); - sum1 = hn::MulAdd(a1, a1, sum1); - } - - return hn::ReduceSum(d, hn::Add(sum0, sum1)); -} - -// float, float -> float; simple loop. -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( - const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight, - float* HWY_RESTRICT out, size_t size) { - constexpr float kEps = 1e-6f; - float ss = SquaredL2(x, size); - ss = 1.0f / sqrtf(ss / StaticCast(size) + kEps); - for (size_t j = 0; j < size; j++) { - // Note 1.0f centering here - out[j] = (1.0f + weight[j]) * (ss * x[j]); - } -} - -// x=f, w=bf16 -> out=f -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( - const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, - float* HWY_RESTRICT out, size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - - constexpr float kEps = 1e-6f; - constexpr size_t kUnrollSize = 2; - - const hn::ScalableTag dbf; - const hn::Repartition df32; - const size_t N32 = hn::Lanes(df32); - - const float ss = SquaredL2(x, size); - const auto vss = - hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); - - HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0); - for (size_t i = 0; i < size; i += kUnrollSize * N32) { - const hn::Vec w16 = hn::LoadU(dbf, weight + i); - const auto w0 = hn::PromoteLowerTo(df32, w16); - const auto w1 = hn::PromoteUpperTo(df32, w16); - const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); - const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); - - // (1+weight) * m = m + weight*m = one FMA. - hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i); - hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32); - } -} - -// float -> float; simple loop. -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( - const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) { - constexpr float kEps = 1e-6f; - float ss = SquaredL2(inout, size); - ss = 1.0f / sqrtf(ss / StaticCast(size) + kEps); - for (size_t j = 0; j < size; j++) { - // Note 1.0f centering here - inout[j] = (1.0f + weight[j]) * (ss * inout[j]); - } -} - -// w=bf16 -> f -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( - const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout, - const size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag dbf; - const hn::Repartition df32; - using VF = hn::Vec; - const size_t N32 = hn::Lanes(df32); - - constexpr float kEps = 1e-6f; - const float ss = SquaredL2(inout, size); - const VF vss = - hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); - - HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); - for (size_t i = 0; i < size; i += 2 * N32) { - const hn::Vec w16 = hn::LoadU(dbf, weight + i); - const VF w0 = hn::PromoteLowerTo(df32, w16); - const VF w1 = hn::PromoteUpperTo(df32, w16); - const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i)); - const VF m1 = hn::Mul(vss, hn::LoadU(df32, inout + i + N32)); - // (1+weight) * m = m + weight*m = one FMA. - hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i); - hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32); - } -} - -// f, f -> bf -// TODO(janwas): consider generic function with adapter for loading bf16/f32 -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( - const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight, - hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag dbf; - const hn::Repartition df32; - using VF = hn::Vec; - const size_t N32 = hn::Lanes(df32); - - constexpr float kEps = 1e-6f; - const float ss = SquaredL2(x, size); - const VF vss = - hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); - - HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); - for (size_t i = 0; i < size; i += 2 * N32) { - const VF w0 = hn::LoadU(df32, weight + i); - const VF w1 = hn::LoadU(df32, weight + i + N32); - const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); - const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); - // (1+weight) * m = m + weight*m = one FMA. - const VF out0 = hn::MulAdd(m0, w0, m0); - const VF out1 = hn::MulAdd(m1, w1, m1); - hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); - } -} - -// x=f, w=bf16 -> bf16 to enable W16A16 MatVec. -static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( - const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, - hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - const hn::ScalableTag dbf; - const hn::Repartition df32; - using VF = hn::Vec; - const size_t N32 = hn::Lanes(df32); - - constexpr float kEps = 1e-6f; - const float ss = SquaredL2(x, size); - const VF vss = - hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); - - HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); - for (size_t i = 0; i < size; i += 2 * N32) { - const hn::Vec w16 = hn::LoadU(dbf, weight + i); - const VF w0 = hn::PromoteLowerTo(df32, w16); - const VF w1 = hn::PromoteUpperTo(df32, w16); - const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); - const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); - // (1+weight) * m = m + weight*m = one FMA. - const VF out0 = hn::MulAdd(m0, w0, m0); - const VF out1 = hn::MulAdd(m1, w1, m1); - hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); - } -} - -static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( - float* HWY_RESTRICT x, size_t dim_model, size_t pos) { - const size_t num_timescales = dim_model / 2; - const float log_timescale_increment = - logf(10000.0f) / - (num_timescales != 0 ? StaticCast(num_timescales - 1) : 1.0f); - for (size_t dim = 0; dim < num_timescales; ++dim) { - const float inv_timescale = - expf(StaticCast(dim) * -log_timescale_increment); - x[dim] += sinf(StaticCast(pos) * inv_timescale); - x[num_timescales + dim] += cosf(StaticCast(pos) * inv_timescale); - } -} - -/* RoPE as in Rotary Position Embeddings from the RoFormer paper - (https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated - as a function of their absolute position using the rotation matrix R before - the self-attention operation. R is a d x d matrix. - - R = cos(m*theta_1) -sin(m*theta_1) ... 0 0 - sin(m*theta_1) cos(m*theta_1) - 0 0 ... 0 0 - 0 0 ... 0 0 - ... - 0 0 ... cos(m*theta_{d/2}) sin(m*theta_{d/2}) - 0 0 ... sin(m*theta_{d/2}) cos(m*theta_{d/2}) - - Here theta_i = 10000^(-2(i-1)/d), where d is the dimension of the vector and - i is the ith index of the vector. - - Applying the rotation matrix R to a vector v is equivalent to rotating every - consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle - m*theta_i. However in the Gemma implementation we choose to rotate - the pairs of dimensions v_{i} and v_{i + d//2} instead. - - pos parameter is deliberately an int because in the backward pass we - call this with negative values (for the VJP calculation we need the transpose - of this rotation matrix which is simply the same matrix with -pos parameter) -*/ - -static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, - size_t dim_qkv, int pos) { - HWY_DASSERT(dim_qkv % 2 == 0); - const size_t half_dim_qkv = dim_qkv / 2; - for (size_t dim = 0; dim < half_dim_qkv; ++dim) { - const float freq_exponents = - StaticCast(2 * dim) / StaticCast(dim_qkv); - // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. - const float timescale = powf(10000.0f, freq_exponents); - const float theta = StaticCast(pos) / timescale; - const float cos_val = cosf(theta); - const float sin_val = sinf(theta); - const float x0 = x[dim]; - const float x1 = x[dim + half_dim_qkv]; - x[dim] = x0 * cos_val - x1 * sin_val; - x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val; - } -} - -static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul, - float* HWY_RESTRICT x, - size_t dim_qkv, - int pos) { - HWY_DASSERT(dim_qkv % 2 == 0); - const size_t half_dim_qkv = dim_qkv / 2; - for (size_t dim = 0; dim < half_dim_qkv; ++dim) { - const float freq_exponents = - StaticCast(2 * dim) / StaticCast(dim_qkv); - // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. - const float timescale = powf(10000.0f, freq_exponents); - const float theta = StaticCast(pos) / timescale; - const float cos_val = cosf(theta); - const float sin_val = sinf(theta); - const float x0 = x[dim]; - const float x1 = x[dim + half_dim_qkv]; - x[dim] = mul * (x0 * cos_val - x1 * sin_val); - x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val); - } -} - -static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( - const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) { - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - using V = hn::Vec; - - hn::Transform1(D(), x, size, other, - [](const auto d, const V x, const V other) - HWY_ATTR { return hn::Add(x, other); }); -} - -// Simple loops unless/until batch sizes are large enough to parallelize. -template -void RMSNormBatched(size_t num_tokens, const float* activations, - const WeightT* weights, OutT* out, const size_t model_dim) { - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - RMSNorm(activations + token_idx * model_dim, weights, - out + token_idx * model_dim, model_dim); - } -} - -// TODO: pass RowVectorBatch argument. -template -void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights, - InOutT* inout, const size_t model_dim) { - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - RMSNormInplace(weights, inout + token_idx * model_dim, model_dim); - } -} - -static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other, - float* x, const size_t model_dim) { - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - AddFrom(other + token_idx * model_dim, x + token_idx * model_dim, - model_dim); - } -} - -static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other, - float* HWY_RESTRICT x, const size_t size, - const size_t max_pos) { - HWY_DASSERT(max_pos <= size); - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - using V = hn::Vec; - - hn::Transform1(D(), x, max_pos, other, - [](const auto d, const V x, const V other) - HWY_ATTR { return hn::Mul(x, other); }); -} - -static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other, - float* HWY_RESTRICT x, - const size_t size) { - return MulBy(other, x, size, size); -} - -static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x, - const size_t size, const size_t max_pos) { - HWY_DASSERT(max_pos <= size); - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - using V = hn::Vec; - hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR { - return hn::Mul(x, hn::Set(d, c)); - }); -} - -static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c, - float* HWY_RESTRICT x, - const size_t size) { - MulByConst(c, x, size, size); -} - -static HWY_NOINLINE void MulByConstAndAdd(const float c, - const float* HWY_RESTRICT x, - float* HWY_RESTRICT out, - const size_t size, - const size_t max_pos) { - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - using V = hn::Vec; - hn::Transform1(D(), out, max_pos, x, - [c](const auto d, const V v_out, const V v_x) HWY_ATTR { - return hn::MulAdd(v_x, hn::Set(d, c), v_out); - }); -} - -static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( - float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out, - size_t size) { - MulByConstAndAdd(c, x, out, size, size); -} - -static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, - const size_t mask_pos) { - HWY_DASSERT(size != 0); - HWY_DASSERT(mask_pos <= size); - - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - using V = hn::Vec; - const D d; - - const V vmin = hn::Set(d, hwy::LowestValue()); - V vmax = vmin; - V* pmax = &vmax; // workaround for SVE: cannot capture &vector directly - Foreach(d, x, mask_pos, vmin, [pmax](const auto d, const V value) HWY_ATTR { - *pmax = hn::Max(*pmax, value); - }); - vmax = hn::MaxOfLanes(d, vmax); - - // Subtract max (avoid precision loss for large exponents) and exponentiate. - hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR { -#if HWY_TARGET & HWY_ALL_SVE - // Temporary workaround for buggy SVE codegen: avoid inlined - // Exp(). - return hn::CallExp(d, hn::Sub(value, *pmax)); -#else - return hn::Exp(d, hn::Sub(value, *pmax)); -#endif - }); - - V sum = hn::Zero(d); - V* psum = ∑ - Foreach(d, x, mask_pos, sum, [psum](const auto d, const V value) HWY_ATTR { - *psum = hn::Add(*psum, value); - }); - - // Normalize to probability distribution - const float mul = 1.0f / hn::ReduceSum(d, sum); - MulByConst(mul, x, size, mask_pos); -} - -static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x, - const size_t size) { - Softmax(x, size, size); -} - -static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, - const size_t size, - const size_t max_pos) { - HWY_DASSERT(max_pos <= size); - - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - using V = hn::Vec; - - const float inv_cap = 1.0f / cap; - - hn::Transform(D(), x, max_pos, [cap, inv_cap](D d, V v) HWY_ATTR { - return hn::Mul(hn::Set(d, cap), - hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap)))); - }); -} - -static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap, - float* HWY_RESTRICT x, - const size_t size) { - LogitsSoftCap(cap, x, size, size); -} - -static HWY_NOINLINE HWY_MAYBE_UNUSED size_t -SampleArgmax(const float* probabilities, size_t vocab_size) { - size_t max_index = 0; - float max_prob = probabilities[0]; - for (size_t i = 1; i < vocab_size; ++i) { - if (probabilities[i] > max_prob) { - max_index = i; - max_prob = probabilities[i]; - } - } - return max_index; -} - -template -static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution -create_distribution(std::array& top_k, float temperature) { - namespace hn = hwy::HWY_NAMESPACE; - using D = hn::ScalableTag; - - // re-normalize distribution - const float temperature_inv = 1.0f / temperature; - hn::Transform(D(), top_k.data(), top_k.size(), - [temperature_inv](D d, hn::Vec v) HWY_ATTR { - return hn::Exp( - d, hn::Mul(hn::Log(d, v), hn::Set(d, temperature_inv))); - }); - - return std::discrete_distribution(std::begin(top_k), std::end(top_k)); -} - -template -static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( - const float* HWY_RESTRICT probabilities, size_t vocab_size, - std::mt19937& gen, float temperature, TAcceptToken& accept_token) { - static_assert(k != 0, ""); - // TODO: Optimize, potentially using new VQSort PartialSort. - std::array top_k{}; // sorted from highest [0], to lowest [k-1] - std::array indices{}; - for (size_t i = 0; i < vocab_size; ++i) { - if (probabilities[i] < top_k[k - 1] && - (!accept_token || accept_token(StaticCast(i), probabilities[i]))) { - continue; - } - for (size_t j = 0; j < k; ++j) { - if (probabilities[i] > top_k[j] && - (!accept_token || - accept_token(StaticCast(i), probabilities[i]))) { - // shift elements by 1, insert the new value, move on to next value - for (size_t idx = k - 1; idx > j; --idx) { - top_k[idx] = top_k[idx - 1]; - indices[idx] = indices[idx - 1]; - } - top_k[j] = probabilities[i]; - indices[j] = StaticCast(i); - break; - } - } - } - return indices[create_distribution(top_k, temperature)(gen)]; -} - // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/gemma/ops_test.cc b/ops/matmul_test.cc similarity index 63% rename from gemma/ops_test.cc rename to ops/matmul_test.cc index 04462ee..322cb89 100644 --- a/gemma/ops_test.cc +++ b/ops/matmul_test.cc @@ -13,8 +13,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS HWY_SCALAR #endif @@ -25,8 +23,7 @@ #include #include #include -#include -#include +#include #include "compression/compress.h" #include "hwy/aligned_allocator.h" @@ -36,13 +33,14 @@ // clang-format off #undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE "gemma/ops_test.cc" // NOLINT +#define HWY_TARGET_INCLUDE "ops/matmul_test.cc" // NOLINT // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" #include "hwy/tests/test_util-inl.h" // After highway.h -#include "gemma/ops.h" +#include "ops/matmul-inl.h" +#include "ops/ops-inl.h" // MulByConst HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -50,304 +48,6 @@ namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -template -struct ForeachCountAndMisalign { - template - HWY_NOINLINE void operator()(T /*unused*/, D d) const { - hwy::RandomState rng; - const size_t N = Lanes(d); - const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; - - for (size_t count = 0; count < 2 * N; ++count) { - for (size_t ma : misalignments) { - for (size_t mb : misalignments) { - Test()(d, count, ma, mb, rng); - } - } - } - } -}; - -template -T Random(hwy::RandomState& rng) { - const int32_t bits = static_cast(Random32(&rng)) & 1023; - const double val = (bits - 512) / 64.0; - // Clamp negative to zero for unsigned types. - return hwy::ConvertScalarTo( - HWY_MAX(hwy::ConvertScalarTo(hwy::LowestValue()), val)); -} - -HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other, - float* HWY_RESTRICT x, size_t size) { - for (size_t i = 0; i < size; ++i) { - x[i] += other[i]; - } -} - -HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other, - float* HWY_RESTRICT x, size_t size, - size_t max_pos) { - HWY_DASSERT(max_pos <= size); - for (size_t i = 0; i < max_pos; ++i) { - x[i] *= other[i]; - } -} - -HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size, - size_t max_pos) { - for (size_t i = 0; i < max_pos; ++i) { - x[i] *= c; - } -} - -HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x, - float* HWY_RESTRICT out, size_t size, - size_t max_pos) { - for (size_t i = 0; i < max_pos; ++i) { - out[i] += x[i] * c; - } -} - -HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size, - size_t mask_pos) { - HWY_DASSERT(size != 0); - HWY_DASSERT(mask_pos <= size); - float sum = 0.0; - const float maxval = *std::max_element(x, x + mask_pos); - for (size_t i = 0; i < mask_pos; ++i) { - x[i] = std::exp(x[i] - maxval); - sum += x[i]; - } - const float scale = 1.0f / sum; - for (size_t i = 0; i < mask_pos; ++i) { - x[i] *= scale; - } -} - -template -HWY_NOINLINE std::discrete_distribution SourceCreateDistribution( - std::array& top_k, float temperature) { - // re-normalize distribution - for (size_t i = 0; i < k; ++i) { - top_k[i] = exp(log(top_k[i]) / temperature); - } - float denominator = 0.0f; - for (size_t i = 0; i < k; ++i) { - denominator += top_k[i]; - } - denominator = 1.0f / denominator; - MulByConst(denominator, top_k.data(), k); - return std::discrete_distribution(std::begin(top_k), std::end(top_k)); -} - -struct TestAddFrom { - template - void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, - hwy::RandomState& rng) { - using T = hn::TFromD; - - hwy::AlignedFreeUniquePtr px = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - hwy::AlignedFreeUniquePtr pe = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - hwy::AlignedFreeUniquePtr po = - hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); - HWY_ASSERT(px && pe && po); - - T* x = px.get() + misalign_a; - T* e = pe.get() + misalign_a; - T* o = po.get() + misalign_b; - - for (size_t i = 0; i < count; ++i) { - x[i] = Random(rng); - e[i] = x[i]; - o[i] = Random(rng); - } - - SourceAddFrom(o, e, count); - AddFrom(o, x, count); - - hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, - __LINE__); - } -}; - -struct TestMulBy { - template - void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, - hwy::RandomState& rng) { - using T = hn::TFromD; - - hwy::AlignedFreeUniquePtr px = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - hwy::AlignedFreeUniquePtr pe = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - hwy::AlignedFreeUniquePtr po = - hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); - HWY_ASSERT(px && pe && po); - - T* x = px.get() + misalign_a; - T* e = pe.get() + misalign_a; - T* o = po.get() + misalign_b; - - for (size_t i = 0; i < count; ++i) { - x[i] = Random(rng); - e[i] = x[i]; - o[i] = Random(rng); - } - - SourceMulBy(o, e, count, count); - MulBy(o, x, count, count); - - hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, - __LINE__); - } -}; - -struct TestMulByConstAndAdd { - template - void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, - hwy::RandomState& rng) { - using T = hn::TFromD; - - hwy::AlignedFreeUniquePtr px = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - hwy::AlignedFreeUniquePtr pe = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - hwy::AlignedFreeUniquePtr po = - hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); - HWY_ASSERT(px && pe && po); - - T* x = px.get() + misalign_a; - T* e = pe.get() + misalign_a; - T* o = po.get() + misalign_b; - - for (size_t i = 0; i < count; ++i) { - x[i] = Random(rng); - e[i] = x[i]; - o[i] = Random(rng); - } - T constant = Random(rng); - - SourceMulByConstAndAdd(constant, o, e, count, count); - MulByConstAndAdd(constant, o, x, count, count); - - hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, - __LINE__); - } -}; - -struct TestMulByConst { - template - void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, - hwy::RandomState& rng) { - if (misalign_b == 0) return; - using T = hn::TFromD; - - hwy::AlignedFreeUniquePtr px = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - hwy::AlignedFreeUniquePtr pe = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - HWY_ASSERT(px && pe); - - T* x = px.get() + misalign_a; - T* e = pe.get() + misalign_a; - - for (size_t i = 0; i < count; ++i) { - x[i] = Random(rng); - e[i] = x[i]; - } - T constant = Random(rng); - - SourceMulByConst(constant, e, count, count); - MulByConst(constant, x, count, count); - - hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, - __LINE__); - } -}; - -struct TestSoftmax { - template - void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, - hwy::RandomState& rng) { - if (count == 0) return; // *Softmax would assert - if (misalign_b == 0) return; - using T = hn::TFromD; - - hwy::AlignedFreeUniquePtr px = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - hwy::AlignedFreeUniquePtr pe = - hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); - HWY_ASSERT(px && pe); - - T* x = px.get() + misalign_a; - T* e = pe.get() + misalign_a; - - for (size_t i = 0; i < count; ++i) { - x[i] = Random(rng); - e[i] = x[i]; - } - - SourceSoftmax(e, count, count); - Softmax(x, count, count); - - T sum = 0.0f; - for (size_t i = 0; i < count; ++i) { - sum += x[i]; - double rel = std::abs(x[i] - e[i]) / e[i]; - ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of " - << count; - } - ASSERT_NEAR(sum, 1.0, 1e-6); - } -}; - -template -struct TestCreateDistribution { - void operator()(hwy::RandomState& rng) { - std::array x; - std::array e; - - for (size_t i = 0; i < k; ++i) { - x[i] = Random(rng); - e[i] = x[i]; - } - const float constant = Random(rng); - auto expected = SourceCreateDistribution(e, constant); - auto output = create_distribution(x, constant); - - AssertEqual(expected, output, hwy::TargetName(HWY_TARGET), __FILE__, - __LINE__); - } -}; - -void TestAllAddFrom() { - hn::ForPartialVectors>()(float()); -} - -void TestAllMulBy() { - hn::ForPartialVectors>()(float()); -} - -void TestAllMulByConst() { - hn::ForPartialVectors>()(float()); -} - -void TestAllMulByConstAndAdd() { - hn::ForPartialVectors>()( - float()); -} - -void TestAllSoftmax() { - hn::ForPartialVectors>()(float()); -} - -void TestAllCreateDistribution() { - TestCreateDistribution<2048>(); - TestCreateDistribution<5000>(); -} - template CompressedArray GenerateMat(size_t offset, hwy::ThreadPool& pool) { @@ -500,6 +200,28 @@ hwy::AlignedFreeUniquePtr SimpleMatVecAdd( return out; } +template +void AssertClose(const MatT* HWY_RESTRICT expected, + const MatT* HWY_RESTRICT actual, size_t num) { + for (size_t idx = 0; idx < num; idx++) { + const double expected_value = hwy::ConvertScalarTo(expected[idx]); + const double actual_value = hwy::ConvertScalarTo(actual[idx]); + + const double magnitude = std::abs(expected_value); + + const double tolerance = + 256.0 * hwy::ConvertScalarTo(hwy::Epsilon()) * + HWY_MAX(magnitude, 1.0); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx, + expected_value, idx, actual_value); + HWY_ASSERT(0); + } + } +} + template void AssertClose(const hwy::AlignedFreeUniquePtr& a, const hwy::AlignedFreeUniquePtr& b) { @@ -715,23 +437,6 @@ void TestTwoOfsMatVecAddLoop() { AssertClose(actual_out1, expected_out1); } -void TestSigmoid() { - std::vector values; - for (int i = -150; i <= 150; ++i) { - values.push_back(.1f * i); - } - std::vector result = values; - Sigmoid(result.data(), result.size()); - - for (size_t i = 0; i < values.size(); i++) { - const float max_error = 0.00007; - float value = values[i]; - float approx = result[i]; - float expected = (1 / (1 + std::exp(-values[i]))); - EXPECT_NEAR(approx, expected, max_error) << "Input: " << value; - } -} - // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp @@ -740,21 +445,12 @@ HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace gcpp { -HWY_BEFORE_TEST(OpsTest); -HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom); -HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy); -HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst); -HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd); -HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); -HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); -HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul); -HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd); -HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd); -HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop); -HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid); -#ifdef HWY_AFTER_TEST +HWY_BEFORE_TEST(MatmulTest); +HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllTiledBatchMatMul); +HWY_EXPORT_AND_TEST_P(MatmulTest, TestMatVecAdd); +HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoMatVecAdd); +HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoOfsMatVecAddLoop); HWY_AFTER_TEST(); -#endif } // namespace gcpp diff --git a/ops/ops-inl.h b/ops/ops-inl.h new file mode 100644 index 0000000..eb5c15e --- /dev/null +++ b/ops/ops-inl.h @@ -0,0 +1,652 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard for non-SIMD code. +#ifndef THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_ + +#include +#include +#include + +#include +#include +#include +#include // std::enable_if_t + +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/detect_targets.h" +#include "hwy/profiler.h" + +#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_ + +// Include guard for (potentially) SIMD code. +#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE +#endif + +#include "hwy/contrib/algo/transform-inl.h" +#include "hwy/contrib/dot/dot-inl.h" +#include "hwy/contrib/math/math-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +template +HWY_INLINE constexpr std::enable_if_t< + std::is_arithmetic_v && std::is_arithmetic_v, To> +StaticCast(From from) noexcept { + if constexpr (std::is_unsigned_v && std::is_floating_point_v) + return static_cast( + static_cast>(from)); + else + return static_cast(from); +} + +template +static HWY_INLINE hn::Vec Gelu(D d, hn::Vec v) { + const hn::Vec kMul = hn::Set(d, 0.044715f); + const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); + const hn::Vec kHalf = hn::Set(d, 0.5f); + + // tanh approximation matches training. + const hn::Vec v3 = hn::Mul(hn::Mul(v, v), v); + const hn::Vec arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v)); + // 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5). + const hn::Vec cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf); + return hn::Mul(v, cdf); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x, + size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + hn::Transform(D(), x, size, + [](D d, hn::Vec v) HWY_ATTR { return Gelu(d, v); }); +} + +// out[i] = BF(mul[i] * Gelu(gelu_in[i])) +static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16( + const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul, + hwy::bfloat16_t* HWY_RESTRICT out, size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag df; + const hn::Repartition dbf; + const size_t NF = hn::Lanes(df); + using VF = hn::Vec; + + size_t i = 0; + if (size >= 2 * NF) { + for (; i <= size - 2 * NF; i += 2 * NF) { + const VF mul0 = hn::LoadU(df, mul + i); + const VF mul1 = hn::LoadU(df, mul + i + NF); + const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i))); + const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF))); + const hn::Vec bf = hn::OrderedDemote2To(dbf, g0, g1); + hn::StoreU(bf, dbf, out + i); + } + } + if (i != size) { + const size_t remaining = size - i; + const VF mul0 = hn::LoadN(df, mul + i, remaining); + const VF g0 = + hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining))); + const hn::Half dbfh; + const hn::Vec bfh = hn::DemoteTo(dbfh, g0); + hn::StoreN(bfh, dbfh, out + i, remaining); + } +} + +template +static HWY_INLINE hn::Vec Sigmoid(D d, hn::Vec v) { + using VF = hn::Vec; + // Chebyshev polynomial coefficients for rational approximation + const VF c0 = hn::Set(d, 0.00949107017368078f); + const VF c1 = hn::Set(d, 0.0654858946800232f); + const VF c2 = hn::Set(d, 0.231547489762306f - 0.00949107017368078f); + const VF c3 = hn::Set(d, 0.530778527259827f); + const VF c4 = hn::Set(d, 0.855334937572479f); + const VF c5 = hn::Set(d, 0.500000894069672f); + + const VF d0 = hn::Set(d, 0.130970627069473f); + const VF d1 = hn::Set(d, 3.99615288415589e-07f); + const VF d2 = hn::Set(d, 1.06155431270599f - 0.130970627069473f); + const VF d3 = hn::Set(d, 1.35144250634767e-06f); + const VF d4 = hn::Set(d, 1); + + // The approximation works in range -12..12, but the input value is clamped + // in -11.5..11.5 since the approximation slightly overshoots after that. + // The function is nearly 0 for input values below -11.5 and nearly 1 for + // input values above 11.5. + const VF invtwelve = hn::Set(d, 1.0f / 12.0f); + const VF lo = hn::Set(d, -11.5f); + const VF hi = hn::Set(d, 11.5f); + + VF f = hn::Clamp(v, lo, hi); + f = hn::Mul(f, invtwelve); + VF f2 = hn::Add(f, f); + + VF a1 = hn::MulAdd(f2, c0, c1); + VF a2 = hn::MulAdd(f2, a1, c2); + VF a3 = hn::Sub(hn::MulAdd(f2, a2, c3), a1); + VF a4 = hn::Sub(hn::MulAdd(f2, a3, c4), a2); + VF f0 = hn::Sub(hn::MulAdd(f, a4, c5), a3); + + VF b1 = hn::MulAdd(f2, d0, d1); + VF b2 = hn::MulAdd(f2, b1, d2); + VF b3 = hn::Sub(hn::MulAdd(f2, b2, d3), b1); + VF f1 = hn::Sub(hn::MulAdd(f, b3, d4), b2); + + return hn::Div(f0, f1); +} + +// Sigmoid using the logistic function 1 / (1 + exp(-x[i])) +static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x, + size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + hn::Transform(D(), x, size, + [](D d, hn::Vec v) HWY_ATTR { return Sigmoid(d, v); }); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, + const float* HWY_RESTRICT b, + size_t size) { + const hn::ScalableTag d; + HWY_DASSERT(size >= hn::Lanes(d)); + HWY_DASSERT(size % hn::Lanes(d) == 0); + constexpr int kAssumptions = + hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector; + return hn::Dot::Compute(d, a, b, size); +} + +// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT. +static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( + const float* HWY_RESTRICT a, size_t size) { + const hn::ScalableTag d; + using V = hn::Vec; + const size_t N = hn::Lanes(d); + HWY_DASSERT(size >= 2 * N); + HWY_DASSERT(size % (2 * N) == 0); + + V sum0 = hn::Zero(d); + V sum1 = hn::Zero(d); + for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { + const V a0 = hn::LoadU(d, a + i); + sum0 = hn::MulAdd(a0, a0, sum0); + const V a1 = hn::LoadU(d, a + i + N); + sum1 = hn::MulAdd(a1, a1, sum1); + } + + return hn::ReduceSum(d, hn::Add(sum0, sum1)); +} + +// float, float -> float; simple loop. +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight, + float* HWY_RESTRICT out, size_t size) { + constexpr float kEps = 1e-6f; + float ss = SquaredL2(x, size); + ss = 1.0f / sqrtf(ss / StaticCast(size) + kEps); + for (size_t j = 0; j < size; j++) { + // Note 1.0f centering here + out[j] = (1.0f + weight[j]) * (ss * x[j]); + } +} + +// x=f, w=bf16 -> out=f +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, + float* HWY_RESTRICT out, size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + + constexpr float kEps = 1e-6f; + constexpr size_t kUnrollSize = 2; + + const hn::ScalableTag dbf; + const hn::Repartition df32; + const size_t N32 = hn::Lanes(df32); + + const float ss = SquaredL2(x, size); + const auto vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); + + HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += kUnrollSize * N32) { + const hn::Vec w16 = hn::LoadU(dbf, weight + i); + const auto w0 = hn::PromoteLowerTo(df32, w16); + const auto w1 = hn::PromoteUpperTo(df32, w16); + const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); + const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); + + // (1+weight) * m = m + weight*m = one FMA. + hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i); + hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32); + } +} + +// float -> float; simple loop. +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( + const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) { + constexpr float kEps = 1e-6f; + float ss = SquaredL2(inout, size); + ss = 1.0f / sqrtf(ss / StaticCast(size) + kEps); + for (size_t j = 0; j < size; j++) { + // Note 1.0f centering here + inout[j] = (1.0f + weight[j]) * (ss * inout[j]); + } +} + +// w=bf16 -> f +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( + const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout, + const size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag dbf; + const hn::Repartition df32; + using VF = hn::Vec; + const size_t N32 = hn::Lanes(df32); + + constexpr float kEps = 1e-6f; + const float ss = SquaredL2(inout, size); + const VF vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); + + HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += 2 * N32) { + const hn::Vec w16 = hn::LoadU(dbf, weight + i); + const VF w0 = hn::PromoteLowerTo(df32, w16); + const VF w1 = hn::PromoteUpperTo(df32, w16); + const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i)); + const VF m1 = hn::Mul(vss, hn::LoadU(df32, inout + i + N32)); + // (1+weight) * m = m + weight*m = one FMA. + hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i); + hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32); + } +} + +// f, f -> bf +// TODO(janwas): consider generic function with adapter for loading bf16/f32 +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight, + hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag dbf; + const hn::Repartition df32; + using VF = hn::Vec; + const size_t N32 = hn::Lanes(df32); + + constexpr float kEps = 1e-6f; + const float ss = SquaredL2(x, size); + const VF vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); + + HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += 2 * N32) { + const VF w0 = hn::LoadU(df32, weight + i); + const VF w1 = hn::LoadU(df32, weight + i + N32); + const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); + const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); + // (1+weight) * m = m + weight*m = one FMA. + const VF out0 = hn::MulAdd(m0, w0, m0); + const VF out1 = hn::MulAdd(m1, w1, m1); + hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); + } +} + +// x=f, w=bf16 -> bf16 to enable W16A16 MatVec. +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( + const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight, + hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag dbf; + const hn::Repartition df32; + using VF = hn::Vec; + const size_t N32 = hn::Lanes(df32); + + constexpr float kEps = 1e-6f; + const float ss = SquaredL2(x, size); + const VF vss = + hn::Set(df32, 1.0f / sqrtf(ss / StaticCast(size) + kEps)); + + HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); + for (size_t i = 0; i < size; i += 2 * N32) { + const hn::Vec w16 = hn::LoadU(dbf, weight + i); + const VF w0 = hn::PromoteLowerTo(df32, w16); + const VF w1 = hn::PromoteUpperTo(df32, w16); + const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i)); + const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32)); + // (1+weight) * m = m + weight*m = one FMA. + const VF out0 = hn::MulAdd(m0, w0, m0); + const VF out1 = hn::MulAdd(m1, w1, m1); + hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( + float* HWY_RESTRICT x, size_t dim_model, size_t pos) { + const size_t num_timescales = dim_model / 2; + const float log_timescale_increment = + logf(10000.0f) / + (num_timescales != 0 ? StaticCast(num_timescales - 1) : 1.0f); + for (size_t dim = 0; dim < num_timescales; ++dim) { + const float inv_timescale = + expf(StaticCast(dim) * -log_timescale_increment); + x[dim] += sinf(StaticCast(pos) * inv_timescale); + x[num_timescales + dim] += cosf(StaticCast(pos) * inv_timescale); + } +} + +/* RoPE as in Rotary Position Embeddings from the RoFormer paper + (https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated + as a function of their absolute position using the rotation matrix R before + the self-attention operation. R is a d x d matrix. + + R = cos(m*theta_1) -sin(m*theta_1) ... 0 0 + sin(m*theta_1) cos(m*theta_1) + 0 0 ... 0 0 + 0 0 ... 0 0 + ... + 0 0 ... cos(m*theta_{d/2}) sin(m*theta_{d/2}) + 0 0 ... sin(m*theta_{d/2}) cos(m*theta_{d/2}) + + Here theta_i = 10000^(-2(i-1)/d), where d is the dimension of the vector and + i is the ith index of the vector. + + Applying the rotation matrix R to a vector v is equivalent to rotating every + consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle + m*theta_i. However in the Gemma implementation we choose to rotate + the pairs of dimensions v_{i} and v_{i + d//2} instead. + + pos parameter is deliberately an int because in the backward pass we + call this with negative values (for the VJP calculation we need the transpose + of this rotation matrix which is simply the same matrix with -pos parameter) +*/ + +static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, + size_t dim_qkv, int pos) { + HWY_DASSERT(dim_qkv % 2 == 0); + const size_t half_dim_qkv = dim_qkv / 2; + for (size_t dim = 0; dim < half_dim_qkv; ++dim) { + const float freq_exponents = + StaticCast(2 * dim) / StaticCast(dim_qkv); + // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. + const float timescale = powf(10000.0f, freq_exponents); + const float theta = StaticCast(pos) / timescale; + const float cos_val = cosf(theta); + const float sin_val = sinf(theta); + const float x0 = x[dim]; + const float x1 = x[dim + half_dim_qkv]; + x[dim] = x0 * cos_val - x1 * sin_val; + x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val; + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul, + float* HWY_RESTRICT x, + size_t dim_qkv, + int pos) { + HWY_DASSERT(dim_qkv % 2 == 0); + const size_t half_dim_qkv = dim_qkv / 2; + for (size_t dim = 0; dim < half_dim_qkv; ++dim) { + const float freq_exponents = + StaticCast(2 * dim) / StaticCast(dim_qkv); + // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. + const float timescale = powf(10000.0f, freq_exponents); + const float theta = StaticCast(pos) / timescale; + const float cos_val = cosf(theta); + const float sin_val = sinf(theta); + const float x0 = x[dim]; + const float x1 = x[dim + half_dim_qkv]; + x[dim] = mul * (x0 * cos_val - x1 * sin_val); + x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( + const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + using V = hn::Vec; + + hn::Transform1(D(), x, size, other, + [](const auto d, const V x, const V other) + HWY_ATTR { return hn::Add(x, other); }); +} + +// Simple loops unless/until batch sizes are large enough to parallelize. +template +void RMSNormBatched(size_t num_tokens, const float* activations, + const WeightT* weights, OutT* out, const size_t model_dim) { + for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + RMSNorm(activations + token_idx * model_dim, weights, + out + token_idx * model_dim, model_dim); + } +} + +// TODO: pass RowVectorBatch argument. +template +void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights, + InOutT* inout, const size_t model_dim) { + for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + RMSNormInplace(weights, inout + token_idx * model_dim, model_dim); + } +} + +static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other, + float* x, const size_t model_dim) { + for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + AddFrom(other + token_idx * model_dim, x + token_idx * model_dim, + model_dim); + } +} + +static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other, + float* HWY_RESTRICT x, const size_t size, + const size_t max_pos) { + HWY_DASSERT(max_pos <= size); + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + using V = hn::Vec; + + hn::Transform1(D(), x, max_pos, other, + [](const auto d, const V x, const V other) + HWY_ATTR { return hn::Mul(x, other); }); +} + +static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other, + float* HWY_RESTRICT x, + const size_t size) { + return MulBy(other, x, size, size); +} + +static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x, + const size_t size, const size_t max_pos) { + HWY_DASSERT(max_pos <= size); + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + using V = hn::Vec; + hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR { + return hn::Mul(x, hn::Set(d, c)); + }); +} + +static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c, + float* HWY_RESTRICT x, + const size_t size) { + MulByConst(c, x, size, size); +} + +static HWY_NOINLINE void MulByConstAndAdd(const float c, + const float* HWY_RESTRICT x, + float* HWY_RESTRICT out, + const size_t size, + const size_t max_pos) { + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + using V = hn::Vec; + hn::Transform1(D(), out, max_pos, x, + [c](const auto d, const V v_out, const V v_x) HWY_ATTR { + return hn::MulAdd(v_x, hn::Set(d, c), v_out); + }); +} + +static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd( + float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out, + size_t size) { + MulByConstAndAdd(c, x, out, size, size); +} + +static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size, + const size_t mask_pos) { + HWY_DASSERT(size != 0); + HWY_DASSERT(mask_pos <= size); + + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + using V = hn::Vec; + const D d; + + const V vmin = hn::Set(d, hwy::LowestValue()); + V vmax = vmin; + V* pmax = &vmax; // workaround for SVE: cannot capture &vector directly + Foreach(d, x, mask_pos, vmin, [pmax](const auto d, const V value) HWY_ATTR { + *pmax = hn::Max(*pmax, value); + }); + vmax = hn::MaxOfLanes(d, vmax); + + // Subtract max (avoid precision loss for large exponents) and exponentiate. + hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR { +#if HWY_TARGET & HWY_ALL_SVE + // Temporary workaround for buggy SVE codegen: avoid inlined + // Exp(). + return hn::CallExp(d, hn::Sub(value, *pmax)); +#else + return hn::Exp(d, hn::Sub(value, *pmax)); +#endif + }); + + V sum = hn::Zero(d); + V* psum = ∑ + Foreach(d, x, mask_pos, sum, [psum](const auto d, const V value) HWY_ATTR { + *psum = hn::Add(*psum, value); + }); + + // Normalize to probability distribution + const float mul = 1.0f / hn::ReduceSum(d, sum); + MulByConst(mul, x, size, mask_pos); +} + +static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x, + const size_t size) { + Softmax(x, size, size); +} + +static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, + const size_t size, + const size_t max_pos) { + HWY_DASSERT(max_pos <= size); + + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + using V = hn::Vec; + + const float inv_cap = 1.0f / cap; + + hn::Transform(D(), x, max_pos, [cap, inv_cap](D d, V v) HWY_ATTR { + return hn::Mul(hn::Set(d, cap), + hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap)))); + }); +} + +static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap, + float* HWY_RESTRICT x, + const size_t size) { + LogitsSoftCap(cap, x, size, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED size_t +SampleArgmax(const float* probabilities, size_t vocab_size) { + size_t max_index = 0; + float max_prob = probabilities[0]; + for (size_t i = 1; i < vocab_size; ++i) { + if (probabilities[i] > max_prob) { + max_index = i; + max_prob = probabilities[i]; + } + } + return max_index; +} + +template +static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution +create_distribution(std::array& top_k, float temperature) { + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + + // re-normalize distribution + const float temperature_inv = 1.0f / temperature; + hn::Transform(D(), top_k.data(), top_k.size(), + [temperature_inv](D d, hn::Vec v) HWY_ATTR { + return hn::Exp( + d, hn::Mul(hn::Log(d, v), hn::Set(d, temperature_inv))); + }); + + return std::discrete_distribution(std::begin(top_k), std::end(top_k)); +} + +template +static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( + const float* HWY_RESTRICT probabilities, size_t vocab_size, + std::mt19937& gen, float temperature, TAcceptToken& accept_token) { + static_assert(k != 0, ""); + // TODO: Optimize, potentially using new VQSort PartialSort. + std::array top_k{}; // sorted from highest [0], to lowest [k-1] + std::array indices{}; + for (size_t i = 0; i < vocab_size; ++i) { + if (probabilities[i] < top_k[k - 1] && + (!accept_token || accept_token(StaticCast(i), probabilities[i]))) { + continue; + } + for (size_t j = 0; j < k; ++j) { + if (probabilities[i] > top_k[j] && + (!accept_token || + accept_token(StaticCast(i), probabilities[i]))) { + // shift elements by 1, insert the new value, move on to next value + for (size_t idx = k - 1; idx > j; --idx) { + top_k[idx] = top_k[idx - 1]; + indices[idx] = indices[idx - 1]; + } + top_k[j] = probabilities[i]; + indices[j] = StaticCast(i); + break; + } + } + } + return indices[create_distribution(top_k, temperature)(gen)]; +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // NOLINT diff --git a/ops/ops_test.cc b/ops/ops_test.cc new file mode 100644 index 0000000..0a66f13 --- /dev/null +++ b/ops/ops_test.cc @@ -0,0 +1,384 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// OrderedDemote2To is not supported by HWY_SCALAR. +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS HWY_SCALAR +#endif + +#include +#include + +#include +#include +#include +#include +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "ops/ops_test.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" +// After highway.h +#include "ops/ops-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +template +struct ForeachCountAndMisalign { + template + HWY_NOINLINE void operator()(T /*unused*/, D d) const { + hwy::RandomState rng; + const size_t N = Lanes(d); + const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; + + for (size_t count = 0; count < 2 * N; ++count) { + for (size_t ma : misalignments) { + for (size_t mb : misalignments) { + Test()(d, count, ma, mb, rng); + } + } + } + } +}; + +template +T Random(hwy::RandomState& rng) { + const int32_t bits = static_cast(Random32(&rng)) & 1023; + const double val = (bits - 512) / 64.0; + // Clamp negative to zero for unsigned types. + return hwy::ConvertScalarTo( + HWY_MAX(hwy::ConvertScalarTo(hwy::LowestValue()), val)); +} + +HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other, + float* HWY_RESTRICT x, size_t size) { + for (size_t i = 0; i < size; ++i) { + x[i] += other[i]; + } +} + +HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other, + float* HWY_RESTRICT x, size_t size, + size_t max_pos) { + HWY_DASSERT(max_pos <= size); + for (size_t i = 0; i < max_pos; ++i) { + x[i] *= other[i]; + } +} + +HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size, + size_t max_pos) { + for (size_t i = 0; i < max_pos; ++i) { + x[i] *= c; + } +} + +HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x, + float* HWY_RESTRICT out, size_t size, + size_t max_pos) { + for (size_t i = 0; i < max_pos; ++i) { + out[i] += x[i] * c; + } +} + +HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size, + size_t mask_pos) { + HWY_DASSERT(size != 0); + HWY_DASSERT(mask_pos <= size); + float sum = 0.0; + const float maxval = *std::max_element(x, x + mask_pos); + for (size_t i = 0; i < mask_pos; ++i) { + x[i] = std::exp(x[i] - maxval); + sum += x[i]; + } + const float scale = 1.0f / sum; + for (size_t i = 0; i < mask_pos; ++i) { + x[i] *= scale; + } +} + +template +HWY_NOINLINE std::discrete_distribution SourceCreateDistribution( + std::array& top_k, float temperature) { + // re-normalize distribution + for (size_t i = 0; i < k; ++i) { + top_k[i] = exp(log(top_k[i]) / temperature); + } + float denominator = 0.0f; + for (size_t i = 0; i < k; ++i) { + denominator += top_k[i]; + } + denominator = 1.0f / denominator; + MulByConst(denominator, top_k.data(), k); + return std::discrete_distribution(std::begin(top_k), std::end(top_k)); +} + +struct TestAddFrom { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr po = + hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); + HWY_ASSERT(px && pe && po); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + T* o = po.get() + misalign_b; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + o[i] = Random(rng); + } + + SourceAddFrom(o, e, count); + AddFrom(o, x, count); + + hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +struct TestMulBy { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr po = + hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); + HWY_ASSERT(px && pe && po); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + T* o = po.get() + misalign_b; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + o[i] = Random(rng); + } + + SourceMulBy(o, e, count, count); + MulBy(o, x, count, count); + + hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +struct TestMulByConstAndAdd { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr po = + hwy::AllocateAligned(HWY_MAX(1, misalign_b + count)); + HWY_ASSERT(px && pe && po); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + T* o = po.get() + misalign_b; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + o[i] = Random(rng); + } + T constant = Random(rng); + + SourceMulByConstAndAdd(constant, o, e, count, count); + MulByConstAndAdd(constant, o, x, count, count); + + hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +struct TestMulByConst { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + if (misalign_b == 0) return; + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + HWY_ASSERT(px && pe); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + } + T constant = Random(rng); + + SourceMulByConst(constant, e, count, count); + MulByConst(constant, x, count, count); + + hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +struct TestSoftmax { + template + void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b, + hwy::RandomState& rng) { + if (count == 0) return; // *Softmax would assert + if (misalign_b == 0) return; + using T = hn::TFromD; + + hwy::AlignedFreeUniquePtr px = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + hwy::AlignedFreeUniquePtr pe = + hwy::AllocateAligned(HWY_MAX(1, misalign_a + count)); + HWY_ASSERT(px && pe); + + T* x = px.get() + misalign_a; + T* e = pe.get() + misalign_a; + + for (size_t i = 0; i < count; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + } + + SourceSoftmax(e, count, count); + Softmax(x, count, count); + + T sum = 0.0f; + for (size_t i = 0; i < count; ++i) { + sum += x[i]; + double rel = std::abs(x[i] - e[i]) / e[i]; + ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of " + << count; + } + ASSERT_NEAR(sum, 1.0, 1e-6); + } +}; + +template +struct TestCreateDistribution { + void operator()(hwy::RandomState& rng) { + std::array x; + std::array e; + + for (size_t i = 0; i < k; ++i) { + x[i] = Random(rng); + e[i] = x[i]; + } + const float constant = Random(rng); + auto expected = SourceCreateDistribution(e, constant); + auto output = create_distribution(x, constant); + + AssertEqual(expected, output, hwy::TargetName(HWY_TARGET), __FILE__, + __LINE__); + } +}; + +void TestAllAddFrom() { + hn::ForPartialVectors>()(float()); +} + +void TestAllMulBy() { + hn::ForPartialVectors>()(float()); +} + +void TestAllMulByConst() { + hn::ForPartialVectors>()(float()); +} + +void TestAllMulByConstAndAdd() { + hn::ForPartialVectors>()( + float()); +} + +void TestAllSoftmax() { + hn::ForPartialVectors>()(float()); +} + +void TestAllCreateDistribution() { + TestCreateDistribution<2048>(); + TestCreateDistribution<5000>(); +} + +void TestSigmoid() { + std::vector values; + for (int i = -150; i <= 150; ++i) { + values.push_back(.1f * i); + } + std::vector result = values; + Sigmoid(result.data(), result.size()); + + for (size_t i = 0; i < values.size(); i++) { + const float max_error = 0.00007; + float value = values[i]; + float approx = result[i]; + float expected = (1 / (1 + std::exp(-values[i]))); + EXPECT_NEAR(approx, expected, max_error) << "Input: " << value; + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(OpsTest); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); +HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); +HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid); +HWY_AFTER_TEST(); + +} // namespace gcpp + +#endif