mirror of https://github.com/google/gemma.cpp.git
Avoid duplication of RMSNorm, support all activation/weight types
Add test for RMSNorm. Rename VectorizedRopeAndMulBy -> RopeAndMulBy. Move test_util to util/. PiperOrigin-RevId: 668332927
This commit is contained in:
parent
3c17911875
commit
4033ed9e78
12
BUILD.bazel
12
BUILD.bazel
|
|
@ -28,6 +28,16 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "test_util",
|
||||||
|
hdrs = ["util/test_util.h"],
|
||||||
|
deps = [
|
||||||
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:hwy_test_util",
|
||||||
|
"@hwy//:stats",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "threading",
|
name = "threading",
|
||||||
hdrs = ["util/threading.h"],
|
hdrs = ["util/threading.h"],
|
||||||
|
|
@ -101,7 +111,9 @@ cc_test(
|
||||||
":common",
|
":common",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":ops",
|
":ops",
|
||||||
|
":test_util",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark", #buildcleaner: keep
|
"@hwy//:nanobenchmark", #buildcleaner: keep
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,6 @@ set(SOURCES
|
||||||
compression/nuq-inl.h
|
compression/nuq-inl.h
|
||||||
compression/sfp.h
|
compression/sfp.h
|
||||||
compression/sfp-inl.h
|
compression/sfp-inl.h
|
||||||
compression/test_util.h
|
|
||||||
compression/weights_raw.h
|
compression/weights_raw.h
|
||||||
backprop/activations.h
|
backprop/activations.h
|
||||||
backprop/backward.cc
|
backprop/backward.cc
|
||||||
|
|
@ -107,6 +106,7 @@ set(SOURCES
|
||||||
util/allocator.h
|
util/allocator.h
|
||||||
util/app.h
|
util/app.h
|
||||||
util/args.h
|
util/args.h
|
||||||
|
util/test_util.h
|
||||||
util/threading.h
|
util/threading.h
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -136,9 +136,8 @@ static HWY_NOINLINE void RMSNormVJP(
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||||
const size_t offset = pos * model_dim;
|
const size_t offset = pos * model_dim;
|
||||||
constexpr float eps = 1e-6f;
|
const float ss = detail::RMSNormMul(x + offset, model_dim);
|
||||||
float ss = SquaredL2(x + offset, model_dim);
|
|
||||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(model_dim) + eps);
|
|
||||||
for (size_t i = 0; i < model_dim; ++i) {
|
for (size_t i = 0; i < model_dim; ++i) {
|
||||||
grad_w[i] += v[offset + i] * x[offset + i] * ss;
|
grad_w[i] += v[offset + i] * x[offset + i] * ss;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -64,25 +64,14 @@ cc_test(
|
||||||
srcs = ["distortion_test.cc"],
|
srcs = ["distortion_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":distortion",
|
":distortion",
|
||||||
":test_util",
|
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
|
"//:test_util",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark", # Unpredictable1
|
"@hwy//:nanobenchmark", # Unpredictable1
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "test_util",
|
|
||||||
hdrs = ["test_util.h"],
|
|
||||||
deps = [
|
|
||||||
":distortion",
|
|
||||||
"@hwy//:hwy",
|
|
||||||
"@hwy//:hwy_test_util",
|
|
||||||
"@hwy//:stats",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "sfp",
|
name = "sfp",
|
||||||
hdrs = ["sfp.h"],
|
hdrs = ["sfp.h"],
|
||||||
|
|
@ -102,10 +91,11 @@ cc_test(
|
||||||
# for test_suite.
|
# for test_suite.
|
||||||
tags = ["hwy_ops_test"],
|
tags = ["hwy_ops_test"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":distortion",
|
||||||
":sfp",
|
":sfp",
|
||||||
":test_util",
|
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"//:ops",
|
"//:ops",
|
||||||
|
"//:test_util",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark",
|
"@hwy//:nanobenchmark",
|
||||||
|
|
@ -133,10 +123,11 @@ cc_test(
|
||||||
# for test_suite.
|
# for test_suite.
|
||||||
tags = ["hwy_ops_test"],
|
tags = ["hwy_ops_test"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":distortion",
|
||||||
":nuq",
|
":nuq",
|
||||||
":sfp",
|
":sfp",
|
||||||
":test_util",
|
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
|
"//:test_util",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark",
|
"@hwy//:nanobenchmark",
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,25 @@ namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
// Adapters to store two f32 vectors to f32 or bf16; avoids duplicating
|
||||||
|
// RMSNorm and RMSNormInplace for the two output types.
|
||||||
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
|
void Store2(DF df, hn::Vec<DF> v0, hn::Vec<DF> v1, float* HWY_RESTRICT out) {
|
||||||
|
const size_t NF = hn::Lanes(df);
|
||||||
|
hn::StoreU(v0, df, out);
|
||||||
|
hn::StoreU(v1, df, out + NF);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DF, HWY_IF_F32_D(DF)>
|
||||||
|
void Store2(DF df, hn::Vec<DF> v0, hn::Vec<DF> v1, BF16* HWY_RESTRICT out) {
|
||||||
|
const hn::Repartition<BF16, decltype(df)> dbf;
|
||||||
|
hn::StoreU(hn::OrderedDemote2To(dbf, v0, v1), dbf, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
// Enables generic code independent of compression type.
|
// Enables generic code independent of compression type.
|
||||||
template <typename T> // primary, must specialize
|
template <typename T> // primary, must specialize
|
||||||
struct CompressTraits {};
|
struct CompressTraits {};
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include "compression/test_util.h"
|
#include "util/test_util.h"
|
||||||
#include "hwy/nanobenchmark.h"
|
#include "hwy/nanobenchmark.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
|
#include "hwy/tests/test_util-inl.h" // HWY_ASSERT_EQ
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,8 @@
|
||||||
#include <algorithm> // std::shuffle
|
#include <algorithm> // std::shuffle
|
||||||
#include <random>
|
#include <random>
|
||||||
|
|
||||||
#include "compression/test_util.h"
|
#include "compression/distortion.h"
|
||||||
|
#include "util/test_util.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/tests/test_util.h"
|
#include "hwy/tests/test_util.h"
|
||||||
|
|
|
||||||
|
|
@ -26,11 +26,11 @@
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
#include "compression/test_util.h"
|
#include "compression/distortion.h"
|
||||||
|
#include "util/test_util.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
#define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT
|
#define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT
|
||||||
|
|
|
||||||
|
|
@ -228,7 +228,7 @@ class GemmaAttention {
|
||||||
Rope(qk_out, kQKVDim / 2, inv_timescale, pos);
|
Rope(qk_out, kQKVDim / 2, inv_timescale, pos);
|
||||||
MulByConst(mul, qk_out, kQKVDim);
|
MulByConst(mul, qk_out, kQKVDim);
|
||||||
} else {
|
} else {
|
||||||
VectorizedRopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out);
|
RopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
257
ops/ops-inl.h
257
ops/ops-inl.h
|
|
@ -30,7 +30,6 @@
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/detect_targets.h"
|
#include "hwy/detect_targets.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
|
||||||
|
|
||||||
// Include guard for (potentially) SIMD code.
|
// Include guard for (potentially) SIMD code.
|
||||||
|
|
@ -41,6 +40,7 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include "compression/compress-inl.h"
|
||||||
#include "hwy/contrib/algo/transform-inl.h"
|
#include "hwy/contrib/algo/transform-inl.h"
|
||||||
#include "hwy/contrib/dot/dot-inl.h"
|
#include "hwy/contrib/dot/dot-inl.h"
|
||||||
#include "hwy/contrib/math/math-inl.h"
|
#include "hwy/contrib/math/math-inl.h"
|
||||||
|
|
@ -62,7 +62,7 @@ StaticCast(From from) noexcept {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class D, HWY_IF_F32_D(D)>
|
template <class D, HWY_IF_F32_D(D)>
|
||||||
static HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
||||||
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
|
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
|
||||||
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
|
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
|
||||||
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
|
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
|
||||||
|
|
@ -83,40 +83,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
||||||
[](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); });
|
[](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); });
|
||||||
}
|
}
|
||||||
|
|
||||||
// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
|
|
||||||
const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul,
|
|
||||||
hwy::bfloat16_t* HWY_RESTRICT out, size_t size) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
const hn::ScalableTag<float> df;
|
|
||||||
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
|
|
||||||
const size_t NF = hn::Lanes(df);
|
|
||||||
using VF = hn::Vec<decltype(df)>;
|
|
||||||
|
|
||||||
size_t i = 0;
|
|
||||||
if (size >= 2 * NF) {
|
|
||||||
for (; i <= size - 2 * NF; i += 2 * NF) {
|
|
||||||
const VF mul0 = hn::LoadU(df, mul + i);
|
|
||||||
const VF mul1 = hn::LoadU(df, mul + i + NF);
|
|
||||||
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
|
|
||||||
const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF)));
|
|
||||||
const hn::Vec<decltype(dbf)> bf = hn::OrderedDemote2To(dbf, g0, g1);
|
|
||||||
hn::StoreU(bf, dbf, out + i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (i != size) {
|
|
||||||
const size_t remaining = size - i;
|
|
||||||
const VF mul0 = hn::LoadN(df, mul + i, remaining);
|
|
||||||
const VF g0 =
|
|
||||||
hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining)));
|
|
||||||
const hn::Half<decltype(dbf)> dbfh;
|
|
||||||
const hn::Vec<decltype(dbfh)> bfh = hn::DemoteTo(dbfh, g0);
|
|
||||||
hn::StoreN(bfh, dbfh, out + i, remaining);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class D, HWY_IF_F32_D(D)>
|
template <class D, HWY_IF_F32_D(D)>
|
||||||
static HWY_INLINE hn::Vec<D> Sigmoid(D d, hn::Vec<D> v) {
|
HWY_INLINE hn::Vec<D> Sigmoid(D d, hn::Vec<D> v) {
|
||||||
using VF = hn::Vec<D>;
|
using VF = hn::Vec<D>;
|
||||||
// Chebyshev polynomial coefficients for rational approximation
|
// Chebyshev polynomial coefficients for rational approximation
|
||||||
const VF c0 = hn::Set(d, 0.00949107017368078f);
|
const VF c0 = hn::Set(d, 0.00949107017368078f);
|
||||||
|
|
@ -180,173 +148,101 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
||||||
return hn::Dot::Compute<kAssumptions>(d, a, b, size);
|
return hn::Dot::Compute<kAssumptions>(d, a, b, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
template <typename VecT>
|
||||||
const float* HWY_RESTRICT a, size_t size) {
|
float SquaredL2(const VecT* HWY_RESTRICT a, size_t size) {
|
||||||
PROFILER_ZONE("ops.SquaredL2");
|
using TraitsV = CompressTraits<VecT>;
|
||||||
|
|
||||||
const hn::ScalableTag<float> d;
|
const hn::ScalableTag<float> d;
|
||||||
using V = hn::Vec<decltype(d)>;
|
using V = hn::Vec<decltype(d)>;
|
||||||
const size_t N = hn::Lanes(d);
|
const size_t N = hn::Lanes(d);
|
||||||
HWY_DASSERT(size >= 2 * N);
|
HWY_DASSERT(size >= 2 * N);
|
||||||
HWY_DASSERT(size % (2 * N) == 0);
|
HWY_DASSERT(size % (2 * N) == 0);
|
||||||
|
|
||||||
|
// TODO: use more accurate Dot
|
||||||
V sum0 = hn::Zero(d);
|
V sum0 = hn::Zero(d);
|
||||||
V sum1 = hn::Zero(d);
|
V sum1 = hn::Zero(d);
|
||||||
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
|
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
|
||||||
const V a0 = hn::LoadU(d, a + i);
|
V a0, a1;
|
||||||
|
TraitsV::Decompress2(d, a, i, a0, a1);
|
||||||
sum0 = hn::MulAdd(a0, a0, sum0);
|
sum0 = hn::MulAdd(a0, a0, sum0);
|
||||||
const V a1 = hn::LoadU(d, a + i + N);
|
|
||||||
sum1 = hn::MulAdd(a1, a1, sum1);
|
sum1 = hn::MulAdd(a1, a1, sum1);
|
||||||
}
|
}
|
||||||
|
|
||||||
return hn::ReduceSum(d, hn::Add(sum0, sum1));
|
return hn::ReduceSum(d, hn::Add(sum0, sum1));
|
||||||
}
|
}
|
||||||
|
|
||||||
// float, float -> float; simple loop.
|
// Shared by RMSNorm and RMSNormInplace.
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
template <typename VecT>
|
||||||
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) {
|
||||||
float* HWY_RESTRICT out, size_t size) {
|
const float l2 = SquaredL2(x, size);
|
||||||
PROFILER_ZONE("ops.RMSNormF");
|
constexpr float kEps = 1e-6f; // avoid divide by zero
|
||||||
constexpr float kEps = 1e-6f;
|
return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
|
||||||
float ss = SquaredL2(x, size);
|
|
||||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
|
||||||
for (size_t j = 0; j < size; j++) {
|
|
||||||
// Note 1.0f centering here
|
|
||||||
out[j] = (1.0f + weight[j]) * (ss * x[j]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// x=f, w=bf16 -> out=f
|
} // namespace detail
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|
||||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
template <typename VecT, typename WeightT, typename OutT>
|
||||||
float* HWY_RESTRICT out, size_t size) {
|
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
|
||||||
PROFILER_ZONE("ops.RMSNormBF16");
|
const WeightT* HWY_RESTRICT weight,
|
||||||
|
OutT* HWY_RESTRICT out,
|
||||||
|
const size_t size) {
|
||||||
|
PROFILER_FUNC;
|
||||||
|
|
||||||
|
using TraitsV = CompressTraits<VecT>;
|
||||||
|
using TraitsW = CompressTraits<WeightT>;
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
const size_t NF = hn::Lanes(df);
|
||||||
|
|
||||||
constexpr float kEps = 1e-6f;
|
const VF mul = hn::Set(df, detail::RMSNormMul(x, size));
|
||||||
constexpr size_t kUnrollSize = 2;
|
|
||||||
|
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
|
||||||
const size_t N32 = hn::Lanes(df32);
|
|
||||||
|
|
||||||
const float ss = SquaredL2(x, size);
|
|
||||||
const auto vss =
|
|
||||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
|
||||||
|
|
||||||
HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0);
|
|
||||||
for (size_t i = 0; i < size; i += kUnrollSize * N32) {
|
|
||||||
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
|
||||||
const auto w0 = hn::PromoteLowerTo(df32, w16);
|
|
||||||
const auto w1 = hn::PromoteUpperTo(df32, w16);
|
|
||||||
const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
|
||||||
const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
|
||||||
|
|
||||||
|
HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
|
||||||
|
for (size_t i = 0; i < size; i += 2 * NF) {
|
||||||
|
VF v0, v1, w0, w1;
|
||||||
|
TraitsV::Decompress2(df, x, i, v0, v1);
|
||||||
|
TraitsW::Decompress2(df, weight, i, w0, w1);
|
||||||
|
const VF m0 = hn::Mul(mul, v0);
|
||||||
|
const VF m1 = hn::Mul(mul, v1);
|
||||||
// (1+weight) * m = m + weight*m = one FMA.
|
// (1+weight) * m = m + weight*m = one FMA.
|
||||||
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i);
|
const VF out0 = hn::MulAdd(m0, w0, m0);
|
||||||
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32);
|
const VF out1 = hn::MulAdd(m1, w1, m1);
|
||||||
|
detail::Store2(df, out0, out1, out + i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// float -> float; simple loop.
|
// Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer.
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
template <typename VecT, typename WeightT>
|
||||||
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
|
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
PROFILER_ZONE("ops.RMSNormInplaceF");
|
const WeightT* HWY_RESTRICT weight, VecT* HWY_RESTRICT inout,
|
||||||
constexpr float kEps = 1e-6f;
|
|
||||||
float ss = SquaredL2(inout, size);
|
|
||||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
|
||||||
for (size_t j = 0; j < size; j++) {
|
|
||||||
// Note 1.0f centering here
|
|
||||||
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// w=bf16 -> f
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
|
||||||
const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout,
|
|
||||||
const size_t size) {
|
const size_t size) {
|
||||||
PROFILER_ZONE("ops.RMSNormInplaceBF");
|
PROFILER_FUNC;
|
||||||
|
|
||||||
|
using TraitsV = CompressTraits<VecT>;
|
||||||
|
using TraitsW = CompressTraits<WeightT>;
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
const hn::ScalableTag<float> df;
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
using VF = hn::Vec<decltype(df)>;
|
||||||
using VF = hn::Vec<decltype(df32)>;
|
const size_t NF = hn::Lanes(df);
|
||||||
const size_t N32 = hn::Lanes(df32);
|
|
||||||
|
|
||||||
constexpr float kEps = 1e-6f;
|
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size));
|
||||||
const float ss = SquaredL2(inout, size);
|
|
||||||
const VF vss =
|
|
||||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
|
||||||
|
|
||||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
|
||||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
for (size_t i = 0; i < size; i += 2 * NF) {
|
||||||
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
VF v0, v1, w0, w1;
|
||||||
const VF w0 = hn::PromoteLowerTo(df32, w16);
|
TraitsV::Decompress2(df, inout, i, v0, v1);
|
||||||
const VF w1 = hn::PromoteUpperTo(df32, w16);
|
TraitsW::Decompress2(df, weight, i, w0, w1);
|
||||||
const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i));
|
const VF m0 = hn::Mul(mul, hn::LoadU(df, inout + i));
|
||||||
const VF m1 = hn::Mul(vss, hn::LoadU(df32, inout + i + N32));
|
const VF m1 = hn::Mul(mul, hn::LoadU(df, inout + i + NF));
|
||||||
// (1+weight) * m = m + weight*m = one FMA.
|
|
||||||
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i);
|
|
||||||
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// f, f -> bf
|
|
||||||
// TODO(janwas): consider generic function with adapter for loading bf16/f32
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|
||||||
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
|
||||||
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
|
||||||
PROFILER_ZONE("ops.RMSNormF F BF");
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
|
||||||
using VF = hn::Vec<decltype(df32)>;
|
|
||||||
const size_t N32 = hn::Lanes(df32);
|
|
||||||
|
|
||||||
constexpr float kEps = 1e-6f;
|
|
||||||
const float ss = SquaredL2(x, size);
|
|
||||||
const VF vss =
|
|
||||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
|
||||||
|
|
||||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
|
||||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
|
||||||
const VF w0 = hn::LoadU(df32, weight + i);
|
|
||||||
const VF w1 = hn::LoadU(df32, weight + i + N32);
|
|
||||||
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
|
||||||
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
|
||||||
// (1+weight) * m = m + weight*m = one FMA.
|
// (1+weight) * m = m + weight*m = one FMA.
|
||||||
const VF out0 = hn::MulAdd(m0, w0, m0);
|
const VF out0 = hn::MulAdd(m0, w0, m0);
|
||||||
const VF out1 = hn::MulAdd(m1, w1, m1);
|
const VF out1 = hn::MulAdd(m1, w1, m1);
|
||||||
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
|
detail::Store2(df, out0, out1, inout + i);
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// x=f, w=bf16 -> bf16 to enable W16A16 MatVec.
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|
||||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
|
||||||
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
|
||||||
PROFILER_ZONE("ops.RMSNormF BF BF");
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
|
||||||
using VF = hn::Vec<decltype(df32)>;
|
|
||||||
const size_t N32 = hn::Lanes(df32);
|
|
||||||
|
|
||||||
constexpr float kEps = 1e-6f;
|
|
||||||
const float ss = SquaredL2(x, size);
|
|
||||||
const VF vss =
|
|
||||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
|
||||||
|
|
||||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
|
||||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
|
||||||
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
|
||||||
const VF w0 = hn::PromoteLowerTo(df32, w16);
|
|
||||||
const VF w1 = hn::PromoteUpperTo(df32, w16);
|
|
||||||
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
|
||||||
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
|
||||||
// (1+weight) * m = m + weight*m = one FMA.
|
|
||||||
const VF out0 = hn::MulAdd(m0, w0, m0);
|
|
||||||
const VF out1 = hn::MulAdd(m1, w1, m1);
|
|
||||||
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -410,7 +306,6 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(janwas): vectorize
|
|
||||||
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
|
// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate.
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||||
const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
|
const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
|
||||||
|
|
@ -419,24 +314,6 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(
|
||||||
PROFILER_FUNC;
|
PROFILER_FUNC;
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
|
||||||
const float theta = StaticCast<float>(pos) * inv_timescale[dim];
|
|
||||||
const float cos_val = cosf(theta);
|
|
||||||
const float sin_val = sinf(theta);
|
|
||||||
const float x0 = x[dim];
|
|
||||||
const float x1 = x[dim + half_dim_qkv];
|
|
||||||
x_out[dim] = mul * (x0 * cos_val - x1 * sin_val);
|
|
||||||
x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void VectorizedRopeAndMulBy(
|
|
||||||
const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
|
|
||||||
const float* HWY_RESTRICT inv_timescale, int pos,
|
|
||||||
float* HWY_RESTRICT x_out) {
|
|
||||||
PROFILER_FUNC;
|
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
|
||||||
|
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
using V = hn::Vec<D>;
|
using V = hn::Vec<D>;
|
||||||
|
|
@ -685,7 +562,7 @@ SampleArgmax(const float* probabilities, size_t vocab_size) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t k>
|
template <size_t k>
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
||||||
create_distribution(std::array<float, k>& top_k, float temperature) {
|
create_distribution(std::array<float, k>& top_k, float temperature) {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
|
|
@ -702,7 +579,7 @@ create_distribution(std::array<float, k>& top_k, float temperature) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t k, typename TAcceptToken>
|
template <size_t k, typename TAcceptToken>
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
||||||
const float* HWY_RESTRICT probabilities, size_t vocab_size,
|
const float* HWY_RESTRICT probabilities, size_t vocab_size,
|
||||||
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
|
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
|
||||||
static_assert(k != 0, "");
|
static_assert(k != 0, "");
|
||||||
|
|
|
||||||
108
ops/ops_test.cc
108
ops/ops_test.cc
|
|
@ -27,8 +27,15 @@
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "compression/compress.h" // BF16
|
||||||
|
#include "gemma/activations.h"
|
||||||
|
#include "gemma/common.h"
|
||||||
|
#include "gemma/configs.h"
|
||||||
|
#include "util/allocator.h"
|
||||||
|
#include "util/test_util.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
|
@ -36,14 +43,9 @@
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
|
||||||
// After highway.h
|
// After highway.h
|
||||||
#include "gemma/activations.h"
|
|
||||||
#include "gemma/common.h"
|
|
||||||
#include "gemma/configs.h"
|
|
||||||
#include "ops/ops-inl.h"
|
#include "ops/ops-inl.h"
|
||||||
#include "util/allocator.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -366,6 +368,23 @@ void TestSigmoid() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scalar reference used to check RopeAndMulBy. Element `i` of the first half
// is paired with element `i` of the second half; each pair is rotated by the
// angle pos * inv_timescale[i] and the result is scaled by `mul`.
// `dim_qkv` must be even.
static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
    const float mul, const float* HWY_RESTRICT x, size_t dim_qkv,
    const float* HWY_RESTRICT inv_timescale, int pos,
    float* HWY_RESTRICT x_out) {
  HWY_DASSERT(dim_qkv % 2 == 0);
  const size_t half = dim_qkv / 2;
  const float pos_f = StaticCast<float>(pos);
  for (size_t i = 0; i < half; ++i) {
    const float angle = pos_f * inv_timescale[i];
    const float c = cosf(angle);
    const float s = sinf(angle);
    const float lo = x[i];
    const float hi = x[i + half];
    x_out[i] = mul * (lo * c - hi * s);
    x_out[i + half] = mul * (lo * s + hi * c);
  }
}
|
||||||
|
|
||||||
void TestRopeAndMulBy() {
|
void TestRopeAndMulBy() {
|
||||||
using Config = ConfigGemma2_9B<float>;
|
using Config = ConfigGemma2_9B<float>;
|
||||||
int dim_qkv = Config::kQKVDim;
|
int dim_qkv = Config::kQKVDim;
|
||||||
|
|
@ -392,10 +411,10 @@ void TestRopeAndMulBy() {
|
||||||
// Assert VectorizedRope computation is same as regular rope at different pos.
|
// Assert RopeAndMulBy matches the scalar reference (ScalarRopeAndMulBy) at different pos.
|
||||||
for (int pos = 1; pos < 500; pos++) {
|
for (int pos = 1; pos < 500; pos++) {
|
||||||
// Rope'd Q embeddings
|
// Rope'd Q embeddings
|
||||||
|
ScalarRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||||
|
qexpected.data());
|
||||||
RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
RopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||||
qexpected.data());
|
qactual.data());
|
||||||
VectorizedRopeAndMulBy(qmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
|
||||||
qactual.data());
|
|
||||||
|
|
||||||
for (int i = 0; i < dim_qkv; ++i) {
|
for (int i = 0; i < dim_qkv; ++i) {
|
||||||
EXPECT_NEAR(qactual[i], qexpected[i], 1e-4)
|
EXPECT_NEAR(qactual[i], qexpected[i], 1e-4)
|
||||||
|
|
@ -403,10 +422,10 @@ void TestRopeAndMulBy() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rope'd K embeddings
|
// Rope'd K embeddings
|
||||||
|
ScalarRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||||
|
kexpected.data());
|
||||||
RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
RopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
||||||
kexpected.data());
|
kactual.data());
|
||||||
VectorizedRopeAndMulBy(kmul, x.Const(), dim_qkv, inv_timescale.Const(), pos,
|
|
||||||
kactual.data());
|
|
||||||
|
|
||||||
for (int i = 0; i < dim_qkv; ++i) {
|
for (int i = 0; i < dim_qkv; ++i) {
|
||||||
EXPECT_NEAR(kactual[i], kexpected[i], 1e-4)
|
EXPECT_NEAR(kactual[i], kexpected[i], 1e-4)
|
||||||
|
|
@ -415,6 +434,70 @@ void TestRopeAndMulBy() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) {
|
||||||
|
double sum = 0.0;
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
const float f = hwy::ConvertScalarTo<float>(a[i]);
|
||||||
|
sum += f * f;
|
||||||
|
}
|
||||||
|
return static_cast<float>(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Supports bf16 and f32 inputs/outputs, which can be in-place.
|
||||||
|
template <typename VecT, typename WeightT, typename OutT>
|
||||||
|
HWY_NOINLINE void ScalarRMSNorm(const VecT* x,
|
||||||
|
const WeightT* HWY_RESTRICT weight, OutT* out,
|
||||||
|
size_t size) {
|
||||||
|
constexpr float kEps = 1e-6f;
|
||||||
|
float ss = ScalarSquaredL2(x, size);
|
||||||
|
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
||||||
|
for (size_t j = 0; j < size; j++) {
|
||||||
|
const float v = hwy::ConvertScalarTo<float>(x[j]);
|
||||||
|
const float w = hwy::ConvertScalarTo<float>(weight[j]);
|
||||||
|
// Note 1.0f centering here
|
||||||
|
out[j] = hwy::ConvertScalarTo<OutT>((1.0f + w) * (ss * v));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename VecT, typename WeightT, typename OutT>
|
||||||
|
void TestRMSNorm(hwy::RandomState& rng) {
|
||||||
|
constexpr size_t kSize = 128;
|
||||||
|
VecT vec[kSize];
|
||||||
|
WeightT weight[kSize];
|
||||||
|
OutT expected[kSize];
|
||||||
|
OutT actual[kSize];
|
||||||
|
|
||||||
|
for (size_t i = 0; i < kSize; ++i) {
|
||||||
|
vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
|
||||||
|
weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
|
||||||
|
}
|
||||||
|
|
||||||
|
ScalarRMSNorm(vec, weight, expected, kSize);
|
||||||
|
RMSNorm(vec, weight, actual, kSize);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < kSize; i++) {
|
||||||
|
const float e = hwy::ConvertScalarTo<float>(expected[i]);
|
||||||
|
const float a = hwy::ConvertScalarTo<float>(actual[i]);
|
||||||
|
if (!IsNear(e, a, 1e-5f)) {
|
||||||
|
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(VecT()),
|
||||||
|
TypeName(WeightT()), TypeName(OutT()), i, e, a);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs TestRMSNorm for all eight combinations of f32/bf16 input, weight and
// output types. A single RNG is threaded through every call, so the order of
// the calls determines which random values each combination sees.
void TestAllRMSNorm() {
  hwy::RandomState rng;

  // f32 input vector.
  TestRMSNorm<float, float, float>(rng);
  TestRMSNorm<float, float, BF16>(rng);
  TestRMSNorm<float, BF16, float>(rng);
  TestRMSNorm<float, BF16, BF16>(rng);

  // bf16 input vector.
  TestRMSNorm<BF16, float, float>(rng);
  TestRMSNorm<BF16, float, BF16>(rng);
  TestRMSNorm<BF16, BF16, float>(rng);
  TestRMSNorm<BF16, BF16, BF16>(rng);
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
@ -432,6 +515,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
|
||||||
HWY_AFTER_TEST();
|
HWY_AFTER_TEST();
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,8 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_
|
#define THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
@ -24,7 +24,6 @@
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
#include "compression/distortion.h"
|
|
||||||
#include "hwy/stats.h"
|
#include "hwy/stats.h"
|
||||||
#include "hwy/tests/test_util.h" // RandomState
|
#include "hwy/tests/test_util.h" // RandomState
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
|
|
@ -73,4 +72,4 @@ HWY_INLINE void VerifyGaussian(hwy::Stats& stats) {
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_TEST_UTIL_H_
|
||||||
Loading…
Reference in New Issue