mirror of https://github.com/google/gemma.cpp.git
Split up ops.h into ops/ops-inl and matmul-inl
PiperOrigin-RevId: 654068303
This commit is contained in:
parent
5844e6a1e5
commit
85cac13fb1
24
BUILD.bazel
24
BUILD.bazel
|
|
@ -22,7 +22,10 @@ exports_files(["LICENSE"])
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "ops",
|
name = "ops",
|
||||||
hdrs = ["gemma/ops.h"],
|
textual_hdrs = [
|
||||||
|
"ops/ops-inl.h",
|
||||||
|
"ops/matmul-inl.h",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:sfp",
|
"//compression:sfp",
|
||||||
|
|
@ -40,7 +43,24 @@ cc_test(
|
||||||
name = "ops_test",
|
name = "ops_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
timeout = "long",
|
timeout = "long",
|
||||||
srcs = ["gemma/ops_test.cc"],
|
srcs = ["ops/ops_test.cc"],
|
||||||
|
local_defines = ["HWY_IS_TEST"],
|
||||||
|
# for test_suite.
|
||||||
|
tags = ["hwy_ops_test"],
|
||||||
|
deps = [
|
||||||
|
":ops",
|
||||||
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:hwy_test_util",
|
||||||
|
"@hwy//:nanobenchmark",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "matmul_test",
|
||||||
|
size = "small",
|
||||||
|
timeout = "long",
|
||||||
|
srcs = ["ops/matmul_test.cc"],
|
||||||
local_defines = ["HWY_IS_TEST"],
|
local_defines = ["HWY_IS_TEST"],
|
||||||
# for test_suite.
|
# for test_suite.
|
||||||
tags = ["hwy_ops_test"],
|
tags = ["hwy_ops_test"],
|
||||||
|
|
|
||||||
|
|
@ -93,11 +93,12 @@ set(SOURCES
|
||||||
gemma/instantiations/tiny_sfp.cc
|
gemma/instantiations/tiny_sfp.cc
|
||||||
gemma/kv_cache.cc
|
gemma/kv_cache.cc
|
||||||
gemma/kv_cache.h
|
gemma/kv_cache.h
|
||||||
gemma/ops.h
|
|
||||||
gemma/tokenizer.cc
|
gemma/tokenizer.cc
|
||||||
gemma/tokenizer.h
|
gemma/tokenizer.h
|
||||||
gemma/weights.cc
|
gemma/weights.cc
|
||||||
gemma/weights.h
|
gemma/weights.h
|
||||||
|
ops/matmul-inl.h
|
||||||
|
ops/ops-inl.h
|
||||||
util/app.h
|
util/app.h
|
||||||
util/args.h
|
util/args.h
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,8 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "gemma/ops.h"
|
#include "ops/matmul-inl.h"
|
||||||
|
#include "ops/ops-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@
|
||||||
// After highway.h
|
// After highway.h
|
||||||
#include "backprop/backward-inl.h"
|
#include "backprop/backward-inl.h"
|
||||||
#include "backprop/forward-inl.h"
|
#include "backprop/forward-inl.h"
|
||||||
#include "gemma/ops.h"
|
#include "ops/ops-inl.h"
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,8 @@
|
||||||
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "gemma/ops.h"
|
#include "ops/matmul-inl.h"
|
||||||
|
#include "ops/ops-inl.h"
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
|
|
||||||
|
|
@ -37,9 +37,10 @@
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "gemma/ops.h"
|
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
// Placeholder for internal test4, do not remove
|
// Placeholder for internal test4, do not remove
|
||||||
|
#include "ops/matmul-inl.h"
|
||||||
|
#include "ops/ops-inl.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/bit_set.h"
|
#include "hwy/bit_set.h"
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,8 @@
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// Include guard for non-SIMD code.
|
// Include guard for non-SIMD code.
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
|
||||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
#define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
|
||||||
|
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
@ -24,7 +24,6 @@
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <random>
|
|
||||||
#include <type_traits> // std::enable_if_t
|
#include <type_traits> // std::enable_if_t
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
|
|
@ -34,18 +33,17 @@
|
||||||
#include "hwy/detect_targets.h"
|
#include "hwy/detect_targets.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
|
||||||
|
|
||||||
// Include guard for (potentially) SIMD code.
|
// Include guard for (potentially) SIMD code.
|
||||||
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
#if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||||
#ifdef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
#ifdef THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
|
||||||
#undef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
#undef THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
|
||||||
#else
|
#else
|
||||||
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
#define THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
#include "hwy/contrib/algo/transform-inl.h"
|
|
||||||
#include "hwy/contrib/dot/dot-inl.h"
|
#include "hwy/contrib/dot/dot-inl.h"
|
||||||
#include "hwy/contrib/math/math-inl.h"
|
#include "hwy/contrib/math/math-inl.h"
|
||||||
#include "hwy/contrib/matvec/matvec-inl.h"
|
#include "hwy/contrib/matvec/matvec-inl.h"
|
||||||
|
|
@ -60,40 +58,6 @@ HWY_INLINE constexpr size_t MaxCols() {
|
||||||
return 2048;
|
return 2048;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename To, typename From>
|
|
||||||
HWY_INLINE constexpr std::enable_if_t<
|
|
||||||
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
|
|
||||||
StaticCast(From from) noexcept {
|
|
||||||
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
|
|
||||||
return static_cast<To>(
|
|
||||||
static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
|
|
||||||
else
|
|
||||||
return static_cast<To>(from);
|
|
||||||
}
|
|
||||||
|
|
||||||
// For testing.
|
|
||||||
template <typename MatT>
|
|
||||||
void AssertClose(const MatT* HWY_RESTRICT expected,
|
|
||||||
const MatT* HWY_RESTRICT actual, size_t num) {
|
|
||||||
for (size_t idx = 0; idx < num; idx++) {
|
|
||||||
const double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
|
|
||||||
const double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
|
|
||||||
|
|
||||||
const double magnitude = std::abs(expected_value);
|
|
||||||
|
|
||||||
const double tolerance =
|
|
||||||
256.0 * hwy::ConvertScalarTo<double>(hwy::Epsilon<MatT>()) *
|
|
||||||
HWY_MAX(magnitude, 1.0);
|
|
||||||
|
|
||||||
if (!(expected_value - tolerance <= actual_value &&
|
|
||||||
actual_value <= expected_value + tolerance)) {
|
|
||||||
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
|
|
||||||
expected_value, idx, actual_value);
|
|
||||||
HWY_ASSERT(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t kOuter>
|
template <size_t kOuter>
|
||||||
HWY_INLINE constexpr size_t RowsPerStrip() {
|
HWY_INLINE constexpr size_t RowsPerStrip() {
|
||||||
// Aim for 128 work items to reduce pool overhead. Must be at least one
|
// Aim for 128 work items to reduce pool overhead. Must be at least one
|
||||||
|
|
@ -726,112 +690,6 @@ HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs,
|
||||||
even_odd, out, pool);
|
even_odd, out, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class D, HWY_IF_F32_D(D)>
|
|
||||||
static HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
|
||||||
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
|
|
||||||
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
|
|
||||||
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
|
|
||||||
|
|
||||||
// tanh approximation matches training.
|
|
||||||
const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v);
|
|
||||||
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
|
|
||||||
// 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5).
|
|
||||||
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
|
|
||||||
return hn::Mul(v, cdf);
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
|
||||||
size_t size) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
using D = hn::ScalableTag<float>;
|
|
||||||
hn::Transform(D(), x, size,
|
|
||||||
[](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); });
|
|
||||||
}
|
|
||||||
|
|
||||||
// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
|
|
||||||
const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul,
|
|
||||||
hwy::bfloat16_t* HWY_RESTRICT out, size_t size) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
const hn::ScalableTag<float> df;
|
|
||||||
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
|
|
||||||
const size_t NF = hn::Lanes(df);
|
|
||||||
using VF = hn::Vec<decltype(df)>;
|
|
||||||
|
|
||||||
size_t i = 0;
|
|
||||||
if (size >= 2 * NF) {
|
|
||||||
for (; i <= size - 2 * NF; i += 2 * NF) {
|
|
||||||
const VF mul0 = hn::LoadU(df, mul + i);
|
|
||||||
const VF mul1 = hn::LoadU(df, mul + i + NF);
|
|
||||||
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
|
|
||||||
const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF)));
|
|
||||||
const hn::Vec<decltype(dbf)> bf = hn::OrderedDemote2To(dbf, g0, g1);
|
|
||||||
hn::StoreU(bf, dbf, out + i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (i != size) {
|
|
||||||
const size_t remaining = size - i;
|
|
||||||
const VF mul0 = hn::LoadN(df, mul + i, remaining);
|
|
||||||
const VF g0 =
|
|
||||||
hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining)));
|
|
||||||
const hn::Half<decltype(dbf)> dbfh;
|
|
||||||
const hn::Vec<decltype(dbfh)> bfh = hn::DemoteTo(dbfh, g0);
|
|
||||||
hn::StoreN(bfh, dbfh, out + i, remaining);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class D, HWY_IF_F32_D(D)>
|
|
||||||
static HWY_INLINE hn::Vec<D> Sigmoid(D d, hn::Vec<D> v) {
|
|
||||||
using VF = hn::Vec<D>;
|
|
||||||
// Chebyshev polynomial coefficients for rational approximation
|
|
||||||
const VF c0 = hn::Set(d, 0.00949107017368078f);
|
|
||||||
const VF c1 = hn::Set(d, 0.0654858946800232f);
|
|
||||||
const VF c2 = hn::Set(d, 0.231547489762306f - 0.00949107017368078f);
|
|
||||||
const VF c3 = hn::Set(d, 0.530778527259827f);
|
|
||||||
const VF c4 = hn::Set(d, 0.855334937572479f);
|
|
||||||
const VF c5 = hn::Set(d, 0.500000894069672f);
|
|
||||||
|
|
||||||
const VF d0 = hn::Set(d, 0.130970627069473f);
|
|
||||||
const VF d1 = hn::Set(d, 3.99615288415589e-07f);
|
|
||||||
const VF d2 = hn::Set(d, 1.06155431270599f - 0.130970627069473f);
|
|
||||||
const VF d3 = hn::Set(d, 1.35144250634767e-06f);
|
|
||||||
const VF d4 = hn::Set(d, 1);
|
|
||||||
|
|
||||||
// The approximation works in range -12..12, but the input value is clamped
|
|
||||||
// in -11.5..11.5 since the approximation slightly overshoots after that.
|
|
||||||
// The function is nearly 0 for input values below -11.5 and nearly 1 for
|
|
||||||
// input values above 11.5.
|
|
||||||
const VF invtwelve = hn::Set(d, 1.0f / 12.0f);
|
|
||||||
const VF lo = hn::Set(d, -11.5f);
|
|
||||||
const VF hi = hn::Set(d, 11.5f);
|
|
||||||
|
|
||||||
VF f = hn::Clamp(v, lo, hi);
|
|
||||||
f = hn::Mul(f, invtwelve);
|
|
||||||
VF f2 = hn::Add(f, f);
|
|
||||||
|
|
||||||
VF a1 = hn::MulAdd(f2, c0, c1);
|
|
||||||
VF a2 = hn::MulAdd(f2, a1, c2);
|
|
||||||
VF a3 = hn::Sub(hn::MulAdd(f2, a2, c3), a1);
|
|
||||||
VF a4 = hn::Sub(hn::MulAdd(f2, a3, c4), a2);
|
|
||||||
VF f0 = hn::Sub(hn::MulAdd(f, a4, c5), a3);
|
|
||||||
|
|
||||||
VF b1 = hn::MulAdd(f2, d0, d1);
|
|
||||||
VF b2 = hn::MulAdd(f2, b1, d2);
|
|
||||||
VF b3 = hn::Sub(hn::MulAdd(f2, b2, d3), b1);
|
|
||||||
VF f1 = hn::Sub(hn::MulAdd(f, b3, d4), b2);
|
|
||||||
|
|
||||||
return hn::Div(f0, f1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sigmoid using the logistic function 1 / (1 + exp(-x[i]))
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
|
|
||||||
size_t size) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
using D = hn::ScalableTag<float>;
|
|
||||||
hn::Transform(D(), x, size,
|
|
||||||
[](D d, hn::Vec<D> v) HWY_ATTR { return Sigmoid(d, v); });
|
|
||||||
}
|
|
||||||
|
|
||||||
// Two matrices, same vector
|
// Two matrices, same vector
|
||||||
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
template <bool kAdd, size_t kOuter, size_t kInner, typename ArrayT,
|
||||||
typename VecT, typename AddT>
|
typename VecT, typename AddT>
|
||||||
|
|
@ -897,483 +755,6 @@ HWY_NOINLINE void TwoMatVec(const ArrayT& mat0, const ArrayT& mat1,
|
||||||
out0, out1, pool);
|
out0, out1, pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
|
||||||
const float* HWY_RESTRICT b,
|
|
||||||
size_t size) {
|
|
||||||
const hn::ScalableTag<float> d;
|
|
||||||
HWY_DASSERT(size >= hn::Lanes(d));
|
|
||||||
HWY_DASSERT(size % hn::Lanes(d) == 0);
|
|
||||||
constexpr int kAssumptions =
|
|
||||||
hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
|
|
||||||
return hn::Dot::Compute<kAssumptions>(d, a, b, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
|
||||||
const float* HWY_RESTRICT a, size_t size) {
|
|
||||||
const hn::ScalableTag<float> d;
|
|
||||||
using V = hn::Vec<decltype(d)>;
|
|
||||||
const size_t N = hn::Lanes(d);
|
|
||||||
HWY_DASSERT(size >= 2 * N);
|
|
||||||
HWY_DASSERT(size % (2 * N) == 0);
|
|
||||||
|
|
||||||
V sum0 = hn::Zero(d);
|
|
||||||
V sum1 = hn::Zero(d);
|
|
||||||
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
|
|
||||||
const V a0 = hn::LoadU(d, a + i);
|
|
||||||
sum0 = hn::MulAdd(a0, a0, sum0);
|
|
||||||
const V a1 = hn::LoadU(d, a + i + N);
|
|
||||||
sum1 = hn::MulAdd(a1, a1, sum1);
|
|
||||||
}
|
|
||||||
|
|
||||||
return hn::ReduceSum(d, hn::Add(sum0, sum1));
|
|
||||||
}
|
|
||||||
|
|
||||||
// float, float -> float; simple loop.
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|
||||||
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
|
||||||
float* HWY_RESTRICT out, size_t size) {
|
|
||||||
constexpr float kEps = 1e-6f;
|
|
||||||
float ss = SquaredL2(x, size);
|
|
||||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
|
||||||
for (size_t j = 0; j < size; j++) {
|
|
||||||
// Note 1.0f centering here
|
|
||||||
out[j] = (1.0f + weight[j]) * (ss * x[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// x=f, w=bf16 -> out=f
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|
||||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
|
||||||
float* HWY_RESTRICT out, size_t size) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
|
|
||||||
constexpr float kEps = 1e-6f;
|
|
||||||
constexpr size_t kUnrollSize = 2;
|
|
||||||
|
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
|
||||||
const size_t N32 = hn::Lanes(df32);
|
|
||||||
|
|
||||||
const float ss = SquaredL2(x, size);
|
|
||||||
const auto vss =
|
|
||||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
|
||||||
|
|
||||||
HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0);
|
|
||||||
for (size_t i = 0; i < size; i += kUnrollSize * N32) {
|
|
||||||
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
|
||||||
const auto w0 = hn::PromoteLowerTo(df32, w16);
|
|
||||||
const auto w1 = hn::PromoteUpperTo(df32, w16);
|
|
||||||
const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
|
||||||
const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
|
||||||
|
|
||||||
// (1+weight) * m = m + weight*m = one FMA.
|
|
||||||
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i);
|
|
||||||
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// float -> float; simple loop.
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
|
||||||
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
|
|
||||||
constexpr float kEps = 1e-6f;
|
|
||||||
float ss = SquaredL2(inout, size);
|
|
||||||
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
|
||||||
for (size_t j = 0; j < size; j++) {
|
|
||||||
// Note 1.0f centering here
|
|
||||||
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// w=bf16 -> f
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
|
||||||
const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout,
|
|
||||||
const size_t size) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
|
||||||
using VF = hn::Vec<decltype(df32)>;
|
|
||||||
const size_t N32 = hn::Lanes(df32);
|
|
||||||
|
|
||||||
constexpr float kEps = 1e-6f;
|
|
||||||
const float ss = SquaredL2(inout, size);
|
|
||||||
const VF vss =
|
|
||||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
|
||||||
|
|
||||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
|
||||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
|
||||||
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
|
||||||
const VF w0 = hn::PromoteLowerTo(df32, w16);
|
|
||||||
const VF w1 = hn::PromoteUpperTo(df32, w16);
|
|
||||||
const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i));
|
|
||||||
const VF m1 = hn::Mul(vss, hn::LoadU(df32, inout + i + N32));
|
|
||||||
// (1+weight) * m = m + weight*m = one FMA.
|
|
||||||
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i);
|
|
||||||
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// f, f -> bf
|
|
||||||
// TODO(janwas): consider generic function with adapter for loading bf16/f32
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|
||||||
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
|
||||||
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
|
||||||
using VF = hn::Vec<decltype(df32)>;
|
|
||||||
const size_t N32 = hn::Lanes(df32);
|
|
||||||
|
|
||||||
constexpr float kEps = 1e-6f;
|
|
||||||
const float ss = SquaredL2(x, size);
|
|
||||||
const VF vss =
|
|
||||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
|
||||||
|
|
||||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
|
||||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
|
||||||
const VF w0 = hn::LoadU(df32, weight + i);
|
|
||||||
const VF w1 = hn::LoadU(df32, weight + i + N32);
|
|
||||||
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
|
||||||
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
|
||||||
// (1+weight) * m = m + weight*m = one FMA.
|
|
||||||
const VF out0 = hn::MulAdd(m0, w0, m0);
|
|
||||||
const VF out1 = hn::MulAdd(m1, w1, m1);
|
|
||||||
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// x=f, w=bf16 -> bf16 to enable W16A16 MatVec.
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
|
||||||
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
|
||||||
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
|
||||||
const hn::Repartition<float, decltype(dbf)> df32;
|
|
||||||
using VF = hn::Vec<decltype(df32)>;
|
|
||||||
const size_t N32 = hn::Lanes(df32);
|
|
||||||
|
|
||||||
constexpr float kEps = 1e-6f;
|
|
||||||
const float ss = SquaredL2(x, size);
|
|
||||||
const VF vss =
|
|
||||||
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
|
||||||
|
|
||||||
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
|
||||||
for (size_t i = 0; i < size; i += 2 * N32) {
|
|
||||||
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
|
||||||
const VF w0 = hn::PromoteLowerTo(df32, w16);
|
|
||||||
const VF w1 = hn::PromoteUpperTo(df32, w16);
|
|
||||||
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
|
||||||
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
|
||||||
// (1+weight) * m = m + weight*m = one FMA.
|
|
||||||
const VF out0 = hn::MulAdd(m0, w0, m0);
|
|
||||||
const VF out1 = hn::MulAdd(m1, w1, m1);
|
|
||||||
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
|
||||||
float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
|
|
||||||
const size_t num_timescales = dim_model / 2;
|
|
||||||
const float log_timescale_increment =
|
|
||||||
logf(10000.0f) /
|
|
||||||
(num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
|
|
||||||
for (size_t dim = 0; dim < num_timescales; ++dim) {
|
|
||||||
const float inv_timescale =
|
|
||||||
expf(StaticCast<float>(dim) * -log_timescale_increment);
|
|
||||||
x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
|
|
||||||
x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* RoPE as in Rotary Position Embeddings from the RoFormer paper
|
|
||||||
(https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated
|
|
||||||
as a function of their absolute position using the rotation matrix R before
|
|
||||||
the self-attention operation. R is a d x d matrix.
|
|
||||||
|
|
||||||
R = cos(m*theta_1) -sin(m*theta_1) ... 0 0
|
|
||||||
sin(m*theta_1) cos(m*theta_1)
|
|
||||||
0 0 ... 0 0
|
|
||||||
0 0 ... 0 0
|
|
||||||
...
|
|
||||||
0 0 ... cos(m*theta_{d/2}) sin(m*theta_{d/2})
|
|
||||||
0 0 ... sin(m*theta_{d/2}) cos(m*theta_{d/2})
|
|
||||||
|
|
||||||
Here theta_i = 10000^(-2(i-1)/d), where d is the dimension of the vector and
|
|
||||||
i is the ith index of the vector.
|
|
||||||
|
|
||||||
Applying the rotation matrix R to a vector v is equivalent to rotating every
|
|
||||||
consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle
|
|
||||||
m*theta_i. However in the Gemma implementation we choose to rotate
|
|
||||||
the pairs of dimensions v_{i} and v_{i + d//2} instead.
|
|
||||||
|
|
||||||
pos parameter is deliberately an int because in the backward pass we
|
|
||||||
call this with negative values (for the VJP calculation we need the transpose
|
|
||||||
of this rotation matrix which is simply the same matrix with -pos parameter)
|
|
||||||
*/
|
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
|
|
||||||
size_t dim_qkv, int pos) {
|
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
|
||||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
|
||||||
const float freq_exponents =
|
|
||||||
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
|
||||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
|
||||||
const float timescale = powf(10000.0f, freq_exponents);
|
|
||||||
const float theta = StaticCast<float>(pos) / timescale;
|
|
||||||
const float cos_val = cosf(theta);
|
|
||||||
const float sin_val = sinf(theta);
|
|
||||||
const float x0 = x[dim];
|
|
||||||
const float x1 = x[dim + half_dim_qkv];
|
|
||||||
x[dim] = x0 * cos_val - x1 * sin_val;
|
|
||||||
x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
|
|
||||||
float* HWY_RESTRICT x,
|
|
||||||
size_t dim_qkv,
|
|
||||||
int pos) {
|
|
||||||
HWY_DASSERT(dim_qkv % 2 == 0);
|
|
||||||
const size_t half_dim_qkv = dim_qkv / 2;
|
|
||||||
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
|
||||||
const float freq_exponents =
|
|
||||||
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
|
||||||
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
|
||||||
const float timescale = powf(10000.0f, freq_exponents);
|
|
||||||
const float theta = StaticCast<float>(pos) / timescale;
|
|
||||||
const float cos_val = cosf(theta);
|
|
||||||
const float sin_val = sinf(theta);
|
|
||||||
const float x0 = x[dim];
|
|
||||||
const float x1 = x[dim + half_dim_qkv];
|
|
||||||
x[dim] = mul * (x0 * cos_val - x1 * sin_val);
|
|
||||||
x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
|
||||||
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
using D = hn::ScalableTag<float>;
|
|
||||||
using V = hn::Vec<D>;
|
|
||||||
|
|
||||||
hn::Transform1(D(), x, size, other,
|
|
||||||
[](const auto d, const V x, const V other)
|
|
||||||
HWY_ATTR { return hn::Add(x, other); });
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simple loops unless/until batch sizes are large enough to parallelize.
|
|
||||||
template <typename WeightT, typename OutT>
|
|
||||||
void RMSNormBatched(size_t num_tokens, const float* activations,
|
|
||||||
const WeightT* weights, OutT* out, const size_t model_dim) {
|
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
|
||||||
RMSNorm(activations + token_idx * model_dim, weights,
|
|
||||||
out + token_idx * model_dim, model_dim);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: pass RowVectorBatch argument.
|
|
||||||
template <typename WeightT, typename InOutT>
|
|
||||||
void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
|
|
||||||
InOutT* inout, const size_t model_dim) {
|
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
|
||||||
RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other,
|
|
||||||
float* x, const size_t model_dim) {
|
|
||||||
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
|
||||||
AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
|
|
||||||
model_dim);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
|
|
||||||
float* HWY_RESTRICT x, const size_t size,
|
|
||||||
const size_t max_pos) {
|
|
||||||
HWY_DASSERT(max_pos <= size);
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
using D = hn::ScalableTag<float>;
|
|
||||||
using V = hn::Vec<D>;
|
|
||||||
|
|
||||||
hn::Transform1(D(), x, max_pos, other,
|
|
||||||
[](const auto d, const V x, const V other)
|
|
||||||
HWY_ATTR { return hn::Mul(x, other); });
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
|
|
||||||
float* HWY_RESTRICT x,
|
|
||||||
const size_t size) {
|
|
||||||
return MulBy(other, x, size, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
|
|
||||||
const size_t size, const size_t max_pos) {
|
|
||||||
HWY_DASSERT(max_pos <= size);
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
using D = hn::ScalableTag<float>;
|
|
||||||
using V = hn::Vec<D>;
|
|
||||||
hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR {
|
|
||||||
return hn::Mul(x, hn::Set(d, c));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
|
|
||||||
float* HWY_RESTRICT x,
|
|
||||||
const size_t size) {
|
|
||||||
MulByConst(c, x, size, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE void MulByConstAndAdd(const float c,
|
|
||||||
const float* HWY_RESTRICT x,
|
|
||||||
float* HWY_RESTRICT out,
|
|
||||||
const size_t size,
|
|
||||||
const size_t max_pos) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
using D = hn::ScalableTag<float>;
|
|
||||||
using V = hn::Vec<D>;
|
|
||||||
hn::Transform1(D(), out, max_pos, x,
|
|
||||||
[c](const auto d, const V v_out, const V v_x) HWY_ATTR {
|
|
||||||
return hn::MulAdd(v_x, hn::Set(d, c), v_out);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
|
||||||
float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out,
|
|
||||||
size_t size) {
|
|
||||||
MulByConstAndAdd(c, x, out, size, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
|
||||||
const size_t mask_pos) {
|
|
||||||
HWY_DASSERT(size != 0);
|
|
||||||
HWY_DASSERT(mask_pos <= size);
|
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
using D = hn::ScalableTag<float>;
|
|
||||||
using V = hn::Vec<D>;
|
|
||||||
const D d;
|
|
||||||
|
|
||||||
const V vmin = hn::Set(d, hwy::LowestValue<float>());
|
|
||||||
V vmax = vmin;
|
|
||||||
V* pmax = &vmax; // workaround for SVE: cannot capture &vector directly
|
|
||||||
Foreach(d, x, mask_pos, vmin, [pmax](const auto d, const V value) HWY_ATTR {
|
|
||||||
*pmax = hn::Max(*pmax, value);
|
|
||||||
});
|
|
||||||
vmax = hn::MaxOfLanes(d, vmax);
|
|
||||||
|
|
||||||
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
|
||||||
hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
|
|
||||||
#if HWY_TARGET & HWY_ALL_SVE
|
|
||||||
// Temporary workaround for buggy SVE codegen: avoid inlined
|
|
||||||
// Exp().
|
|
||||||
return hn::CallExp(d, hn::Sub(value, *pmax));
|
|
||||||
#else
|
|
||||||
return hn::Exp(d, hn::Sub(value, *pmax));
|
|
||||||
#endif
|
|
||||||
});
|
|
||||||
|
|
||||||
V sum = hn::Zero(d);
|
|
||||||
V* psum = ∑
|
|
||||||
Foreach(d, x, mask_pos, sum, [psum](const auto d, const V value) HWY_ATTR {
|
|
||||||
*psum = hn::Add(*psum, value);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Normalize to probability distribution
|
|
||||||
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
|
||||||
MulByConst(mul, x, size, mask_pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
|
|
||||||
const size_t size) {
|
|
||||||
Softmax(x, size, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
|
||||||
const size_t size,
|
|
||||||
const size_t max_pos) {
|
|
||||||
HWY_DASSERT(max_pos <= size);
|
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
using D = hn::ScalableTag<float>;
|
|
||||||
using V = hn::Vec<D>;
|
|
||||||
|
|
||||||
const float inv_cap = 1.0f / cap;
|
|
||||||
|
|
||||||
hn::Transform(D(), x, max_pos, [cap, inv_cap](D d, V v) HWY_ATTR {
|
|
||||||
return hn::Mul(hn::Set(d, cap),
|
|
||||||
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
|
|
||||||
float* HWY_RESTRICT x,
|
|
||||||
const size_t size) {
|
|
||||||
LogitsSoftCap(cap, x, size, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
|
|
||||||
SampleArgmax(const float* probabilities, size_t vocab_size) {
|
|
||||||
size_t max_index = 0;
|
|
||||||
float max_prob = probabilities[0];
|
|
||||||
for (size_t i = 1; i < vocab_size; ++i) {
|
|
||||||
if (probabilities[i] > max_prob) {
|
|
||||||
max_index = i;
|
|
||||||
max_prob = probabilities[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return max_index;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t k>
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
|
||||||
create_distribution(std::array<float, k>& top_k, float temperature) {
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
|
||||||
using D = hn::ScalableTag<float>;
|
|
||||||
|
|
||||||
// re-normalize distribution
|
|
||||||
const float temperature_inv = 1.0f / temperature;
|
|
||||||
hn::Transform(D(), top_k.data(), top_k.size(),
|
|
||||||
[temperature_inv](D d, hn::Vec<D> v) HWY_ATTR {
|
|
||||||
return hn::Exp(
|
|
||||||
d, hn::Mul(hn::Log(d, v), hn::Set(d, temperature_inv)));
|
|
||||||
});
|
|
||||||
|
|
||||||
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t k, typename TAcceptToken>
|
|
||||||
static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
|
||||||
const float* HWY_RESTRICT probabilities, size_t vocab_size,
|
|
||||||
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
|
|
||||||
static_assert(k != 0, "");
|
|
||||||
// TODO: Optimize, potentially using new VQSort PartialSort.
|
|
||||||
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
|
|
||||||
std::array<int, k> indices{};
|
|
||||||
for (size_t i = 0; i < vocab_size; ++i) {
|
|
||||||
if (probabilities[i] < top_k[k - 1] &&
|
|
||||||
(!accept_token || accept_token(StaticCast<int>(i), probabilities[i]))) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (size_t j = 0; j < k; ++j) {
|
|
||||||
if (probabilities[i] > top_k[j] &&
|
|
||||||
(!accept_token ||
|
|
||||||
accept_token(StaticCast<int>(i), probabilities[i]))) {
|
|
||||||
// shift elements by 1, insert the new value, move on to next value
|
|
||||||
for (size_t idx = k - 1; idx > j; --idx) {
|
|
||||||
top_k[idx] = top_k[idx - 1];
|
|
||||||
indices[idx] = indices[idx - 1];
|
|
||||||
}
|
|
||||||
top_k[j] = probabilities[i];
|
|
||||||
indices[j] = StaticCast<int>(i);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return indices[create_distribution<k>(top_k, temperature)(gen)];
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
@ -13,8 +13,6 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -25,8 +23,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <random>
|
#include <memory>
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
|
|
@ -36,13 +33,14 @@
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
#define HWY_TARGET_INCLUDE "gemma/ops_test.cc" // NOLINT
|
#define HWY_TARGET_INCLUDE "ops/matmul_test.cc" // NOLINT
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
// After highway.h
|
// After highway.h
|
||||||
#include "gemma/ops.h"
|
#include "ops/matmul-inl.h"
|
||||||
|
#include "ops/ops-inl.h" // MulByConst
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -50,304 +48,6 @@ namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
template <class Test>
|
|
||||||
struct ForeachCountAndMisalign {
|
|
||||||
template <typename T, class D>
|
|
||||||
HWY_NOINLINE void operator()(T /*unused*/, D d) const {
|
|
||||||
hwy::RandomState rng;
|
|
||||||
const size_t N = Lanes(d);
|
|
||||||
const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
|
|
||||||
|
|
||||||
for (size_t count = 0; count < 2 * N; ++count) {
|
|
||||||
for (size_t ma : misalignments) {
|
|
||||||
for (size_t mb : misalignments) {
|
|
||||||
Test()(d, count, ma, mb, rng);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T Random(hwy::RandomState& rng) {
|
|
||||||
const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
|
|
||||||
const double val = (bits - 512) / 64.0;
|
|
||||||
// Clamp negative to zero for unsigned types.
|
|
||||||
return hwy::ConvertScalarTo<T>(
|
|
||||||
HWY_MAX(hwy::ConvertScalarTo<double>(hwy::LowestValue<T>()), val));
|
|
||||||
}
|
|
||||||
|
|
||||||
HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other,
|
|
||||||
float* HWY_RESTRICT x, size_t size) {
|
|
||||||
for (size_t i = 0; i < size; ++i) {
|
|
||||||
x[i] += other[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other,
|
|
||||||
float* HWY_RESTRICT x, size_t size,
|
|
||||||
size_t max_pos) {
|
|
||||||
HWY_DASSERT(max_pos <= size);
|
|
||||||
for (size_t i = 0; i < max_pos; ++i) {
|
|
||||||
x[i] *= other[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size,
|
|
||||||
size_t max_pos) {
|
|
||||||
for (size_t i = 0; i < max_pos; ++i) {
|
|
||||||
x[i] *= c;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
|
|
||||||
float* HWY_RESTRICT out, size_t size,
|
|
||||||
size_t max_pos) {
|
|
||||||
for (size_t i = 0; i < max_pos; ++i) {
|
|
||||||
out[i] += x[i] * c;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
|
|
||||||
size_t mask_pos) {
|
|
||||||
HWY_DASSERT(size != 0);
|
|
||||||
HWY_DASSERT(mask_pos <= size);
|
|
||||||
float sum = 0.0;
|
|
||||||
const float maxval = *std::max_element(x, x + mask_pos);
|
|
||||||
for (size_t i = 0; i < mask_pos; ++i) {
|
|
||||||
x[i] = std::exp(x[i] - maxval);
|
|
||||||
sum += x[i];
|
|
||||||
}
|
|
||||||
const float scale = 1.0f / sum;
|
|
||||||
for (size_t i = 0; i < mask_pos; ++i) {
|
|
||||||
x[i] *= scale;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t k>
|
|
||||||
HWY_NOINLINE std::discrete_distribution<int> SourceCreateDistribution(
|
|
||||||
std::array<float, k>& top_k, float temperature) {
|
|
||||||
// re-normalize distribution
|
|
||||||
for (size_t i = 0; i < k; ++i) {
|
|
||||||
top_k[i] = exp(log(top_k[i]) / temperature);
|
|
||||||
}
|
|
||||||
float denominator = 0.0f;
|
|
||||||
for (size_t i = 0; i < k; ++i) {
|
|
||||||
denominator += top_k[i];
|
|
||||||
}
|
|
||||||
denominator = 1.0f / denominator;
|
|
||||||
MulByConst(denominator, top_k.data(), k);
|
|
||||||
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
|
||||||
}
|
|
||||||
|
|
||||||
struct TestAddFrom {
|
|
||||||
template <class D>
|
|
||||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
||||||
hwy::RandomState& rng) {
|
|
||||||
using T = hn::TFromD<D>;
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> po =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
|
||||||
HWY_ASSERT(px && pe && po);
|
|
||||||
|
|
||||||
T* x = px.get() + misalign_a;
|
|
||||||
T* e = pe.get() + misalign_a;
|
|
||||||
T* o = po.get() + misalign_b;
|
|
||||||
|
|
||||||
for (size_t i = 0; i < count; ++i) {
|
|
||||||
x[i] = Random<T>(rng);
|
|
||||||
e[i] = x[i];
|
|
||||||
o[i] = Random<T>(rng);
|
|
||||||
}
|
|
||||||
|
|
||||||
SourceAddFrom(o, e, count);
|
|
||||||
AddFrom(o, x, count);
|
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
||||||
__LINE__);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TestMulBy {
|
|
||||||
template <class D>
|
|
||||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
||||||
hwy::RandomState& rng) {
|
|
||||||
using T = hn::TFromD<D>;
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> po =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
|
||||||
HWY_ASSERT(px && pe && po);
|
|
||||||
|
|
||||||
T* x = px.get() + misalign_a;
|
|
||||||
T* e = pe.get() + misalign_a;
|
|
||||||
T* o = po.get() + misalign_b;
|
|
||||||
|
|
||||||
for (size_t i = 0; i < count; ++i) {
|
|
||||||
x[i] = Random<T>(rng);
|
|
||||||
e[i] = x[i];
|
|
||||||
o[i] = Random<T>(rng);
|
|
||||||
}
|
|
||||||
|
|
||||||
SourceMulBy(o, e, count, count);
|
|
||||||
MulBy(o, x, count, count);
|
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
||||||
__LINE__);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TestMulByConstAndAdd {
|
|
||||||
template <class D>
|
|
||||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
||||||
hwy::RandomState& rng) {
|
|
||||||
using T = hn::TFromD<D>;
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> po =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
|
||||||
HWY_ASSERT(px && pe && po);
|
|
||||||
|
|
||||||
T* x = px.get() + misalign_a;
|
|
||||||
T* e = pe.get() + misalign_a;
|
|
||||||
T* o = po.get() + misalign_b;
|
|
||||||
|
|
||||||
for (size_t i = 0; i < count; ++i) {
|
|
||||||
x[i] = Random<T>(rng);
|
|
||||||
e[i] = x[i];
|
|
||||||
o[i] = Random<T>(rng);
|
|
||||||
}
|
|
||||||
T constant = Random<T>(rng);
|
|
||||||
|
|
||||||
SourceMulByConstAndAdd(constant, o, e, count, count);
|
|
||||||
MulByConstAndAdd(constant, o, x, count, count);
|
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
||||||
__LINE__);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TestMulByConst {
|
|
||||||
template <class D>
|
|
||||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
||||||
hwy::RandomState& rng) {
|
|
||||||
if (misalign_b == 0) return;
|
|
||||||
using T = hn::TFromD<D>;
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
HWY_ASSERT(px && pe);
|
|
||||||
|
|
||||||
T* x = px.get() + misalign_a;
|
|
||||||
T* e = pe.get() + misalign_a;
|
|
||||||
|
|
||||||
for (size_t i = 0; i < count; ++i) {
|
|
||||||
x[i] = Random<T>(rng);
|
|
||||||
e[i] = x[i];
|
|
||||||
}
|
|
||||||
T constant = Random<T>(rng);
|
|
||||||
|
|
||||||
SourceMulByConst(constant, e, count, count);
|
|
||||||
MulByConst(constant, x, count, count);
|
|
||||||
|
|
||||||
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
||||||
__LINE__);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TestSoftmax {
|
|
||||||
template <class D>
|
|
||||||
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
|
||||||
hwy::RandomState& rng) {
|
|
||||||
if (count == 0) return; // *Softmax would assert
|
|
||||||
if (misalign_b == 0) return;
|
|
||||||
using T = hn::TFromD<D>;
|
|
||||||
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> px =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
hwy::AlignedFreeUniquePtr<T[]> pe =
|
|
||||||
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
|
||||||
HWY_ASSERT(px && pe);
|
|
||||||
|
|
||||||
T* x = px.get() + misalign_a;
|
|
||||||
T* e = pe.get() + misalign_a;
|
|
||||||
|
|
||||||
for (size_t i = 0; i < count; ++i) {
|
|
||||||
x[i] = Random<T>(rng);
|
|
||||||
e[i] = x[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
SourceSoftmax(e, count, count);
|
|
||||||
Softmax(x, count, count);
|
|
||||||
|
|
||||||
T sum = 0.0f;
|
|
||||||
for (size_t i = 0; i < count; ++i) {
|
|
||||||
sum += x[i];
|
|
||||||
double rel = std::abs(x[i] - e[i]) / e[i];
|
|
||||||
ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of "
|
|
||||||
<< count;
|
|
||||||
}
|
|
||||||
ASSERT_NEAR(sum, 1.0, 1e-6);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <size_t k>
|
|
||||||
struct TestCreateDistribution {
|
|
||||||
void operator()(hwy::RandomState& rng) {
|
|
||||||
std::array<float, k> x;
|
|
||||||
std::array<float, k> e;
|
|
||||||
|
|
||||||
for (size_t i = 0; i < k; ++i) {
|
|
||||||
x[i] = Random<float>(rng);
|
|
||||||
e[i] = x[i];
|
|
||||||
}
|
|
||||||
const float constant = Random<float>(rng);
|
|
||||||
auto expected = SourceCreateDistribution(e, constant);
|
|
||||||
auto output = create_distribution(x, constant);
|
|
||||||
|
|
||||||
AssertEqual(expected, output, hwy::TargetName(HWY_TARGET), __FILE__,
|
|
||||||
__LINE__);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
void TestAllAddFrom() {
|
|
||||||
hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestAllMulBy() {
|
|
||||||
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulBy>>()(float());
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestAllMulByConst() {
|
|
||||||
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestAllMulByConstAndAdd() {
|
|
||||||
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConstAndAdd>>()(
|
|
||||||
float());
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestAllSoftmax() {
|
|
||||||
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestAllCreateDistribution() {
|
|
||||||
TestCreateDistribution<2048>();
|
|
||||||
TestCreateDistribution<5000>();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename MatT, size_t kOuter, size_t kInner>
|
template <typename MatT, size_t kOuter, size_t kInner>
|
||||||
CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
|
CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
|
|
@ -500,6 +200,28 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename MatT>
|
||||||
|
void AssertClose(const MatT* HWY_RESTRICT expected,
|
||||||
|
const MatT* HWY_RESTRICT actual, size_t num) {
|
||||||
|
for (size_t idx = 0; idx < num; idx++) {
|
||||||
|
const double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
|
||||||
|
const double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
|
||||||
|
|
||||||
|
const double magnitude = std::abs(expected_value);
|
||||||
|
|
||||||
|
const double tolerance =
|
||||||
|
256.0 * hwy::ConvertScalarTo<double>(hwy::Epsilon<MatT>()) *
|
||||||
|
HWY_MAX(magnitude, 1.0);
|
||||||
|
|
||||||
|
if (!(expected_value - tolerance <= actual_value &&
|
||||||
|
actual_value <= expected_value + tolerance)) {
|
||||||
|
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
|
||||||
|
expected_value, idx, actual_value);
|
||||||
|
HWY_ASSERT(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <size_t length>
|
template <size_t length>
|
||||||
void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
|
void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
|
||||||
const hwy::AlignedFreeUniquePtr<float[]>& b) {
|
const hwy::AlignedFreeUniquePtr<float[]>& b) {
|
||||||
|
|
@ -715,23 +437,6 @@ void TestTwoOfsMatVecAddLoop() {
|
||||||
AssertClose<kOuter>(actual_out1, expected_out1);
|
AssertClose<kOuter>(actual_out1, expected_out1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestSigmoid() {
|
|
||||||
std::vector<float> values;
|
|
||||||
for (int i = -150; i <= 150; ++i) {
|
|
||||||
values.push_back(.1f * i);
|
|
||||||
}
|
|
||||||
std::vector<float> result = values;
|
|
||||||
Sigmoid(result.data(), result.size());
|
|
||||||
|
|
||||||
for (size_t i = 0; i < values.size(); i++) {
|
|
||||||
const float max_error = 0.00007;
|
|
||||||
float value = values[i];
|
|
||||||
float approx = result[i];
|
|
||||||
float expected = (1 / (1 + std::exp(-values[i])));
|
|
||||||
EXPECT_NEAR(approx, expected, max_error) << "Input: " << value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
@ -740,21 +445,12 @@ HWY_AFTER_NAMESPACE();
|
||||||
#if HWY_ONCE
|
#if HWY_ONCE
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
HWY_BEFORE_TEST(OpsTest);
|
HWY_BEFORE_TEST(MatmulTest);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
|
HWY_EXPORT_AND_TEST_P(MatmulTest, TestAllTiledBatchMatMul);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
|
HWY_EXPORT_AND_TEST_P(MatmulTest, TestMatVecAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
|
HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoMatVecAdd);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
HWY_EXPORT_AND_TEST_P(MatmulTest, TestTwoOfsMatVecAddLoop);
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllTiledBatchMatMul);
|
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
|
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
|
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);
|
|
||||||
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
|
|
||||||
#ifdef HWY_AFTER_TEST
|
|
||||||
HWY_AFTER_TEST();
|
HWY_AFTER_TEST();
|
||||||
#endif
|
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
|
|
@ -0,0 +1,652 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Include guard for non-SIMD code.
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
|
||||||
|
|
||||||
|
#include <math.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <cmath>
|
||||||
|
#include <random>
|
||||||
|
#include <type_traits> // std::enable_if_t
|
||||||
|
|
||||||
|
#include "hwy/base.h"
|
||||||
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/detect_targets.h"
|
||||||
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
|
||||||
|
|
||||||
|
// Include guard for (potentially) SIMD code.
|
||||||
|
#if defined(THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||||
|
#ifdef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
||||||
|
#undef THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
||||||
|
#else
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_OPS_TOGGLE
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "hwy/contrib/algo/transform-inl.h"
|
||||||
|
#include "hwy/contrib/dot/dot-inl.h"
|
||||||
|
#include "hwy/contrib/math/math-inl.h"
|
||||||
|
|
||||||
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
namespace gcpp {
|
||||||
|
namespace HWY_NAMESPACE {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
|
template <typename To, typename From>
|
||||||
|
HWY_INLINE constexpr std::enable_if_t<
|
||||||
|
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
|
||||||
|
StaticCast(From from) noexcept {
|
||||||
|
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
|
||||||
|
return static_cast<To>(
|
||||||
|
static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
|
||||||
|
else
|
||||||
|
return static_cast<To>(from);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class D, HWY_IF_F32_D(D)>
|
||||||
|
static HWY_INLINE hn::Vec<D> Gelu(D d, hn::Vec<D> v) {
|
||||||
|
const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
|
||||||
|
const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
|
||||||
|
const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
|
||||||
|
|
||||||
|
// tanh approximation matches training.
|
||||||
|
const hn::Vec<D> v3 = hn::Mul(hn::Mul(v, v), v);
|
||||||
|
const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
|
||||||
|
// 0.5 * (1 + tan) = MulAdd(0.5, tan, 0.5).
|
||||||
|
const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::Tanh(d, arg), kHalf);
|
||||||
|
return hn::Mul(v, cdf);
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void Gelu(float* HWY_RESTRICT x,
|
||||||
|
size_t size) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
hn::Transform(D(), x, size,
|
||||||
|
[](D d, hn::Vec<D> v) HWY_ATTR { return Gelu(d, v); });
|
||||||
|
}
|
||||||
|
|
||||||
|
// out[i] = BF(mul[i] * Gelu(gelu_in[i]))
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
|
||||||
|
const float* HWY_RESTRICT gelu_in, const float* HWY_RESTRICT mul,
|
||||||
|
hwy::bfloat16_t* HWY_RESTRICT out, size_t size) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
|
||||||
|
const size_t NF = hn::Lanes(df);
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
|
if (size >= 2 * NF) {
|
||||||
|
for (; i <= size - 2 * NF; i += 2 * NF) {
|
||||||
|
const VF mul0 = hn::LoadU(df, mul + i);
|
||||||
|
const VF mul1 = hn::LoadU(df, mul + i + NF);
|
||||||
|
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
|
||||||
|
const VF g1 = hn::Mul(mul1, Gelu(df, hn::LoadU(df, gelu_in + i + NF)));
|
||||||
|
const hn::Vec<decltype(dbf)> bf = hn::OrderedDemote2To(dbf, g0, g1);
|
||||||
|
hn::StoreU(bf, dbf, out + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (i != size) {
|
||||||
|
const size_t remaining = size - i;
|
||||||
|
const VF mul0 = hn::LoadN(df, mul + i, remaining);
|
||||||
|
const VF g0 =
|
||||||
|
hn::Mul(mul0, Gelu(df, hn::LoadN(df, gelu_in + i, remaining)));
|
||||||
|
const hn::Half<decltype(dbf)> dbfh;
|
||||||
|
const hn::Vec<decltype(dbfh)> bfh = hn::DemoteTo(dbfh, g0);
|
||||||
|
hn::StoreN(bfh, dbfh, out + i, remaining);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class D, HWY_IF_F32_D(D)>
|
||||||
|
static HWY_INLINE hn::Vec<D> Sigmoid(D d, hn::Vec<D> v) {
|
||||||
|
using VF = hn::Vec<D>;
|
||||||
|
// Chebyshev polynomial coefficients for rational approximation
|
||||||
|
const VF c0 = hn::Set(d, 0.00949107017368078f);
|
||||||
|
const VF c1 = hn::Set(d, 0.0654858946800232f);
|
||||||
|
const VF c2 = hn::Set(d, 0.231547489762306f - 0.00949107017368078f);
|
||||||
|
const VF c3 = hn::Set(d, 0.530778527259827f);
|
||||||
|
const VF c4 = hn::Set(d, 0.855334937572479f);
|
||||||
|
const VF c5 = hn::Set(d, 0.500000894069672f);
|
||||||
|
|
||||||
|
const VF d0 = hn::Set(d, 0.130970627069473f);
|
||||||
|
const VF d1 = hn::Set(d, 3.99615288415589e-07f);
|
||||||
|
const VF d2 = hn::Set(d, 1.06155431270599f - 0.130970627069473f);
|
||||||
|
const VF d3 = hn::Set(d, 1.35144250634767e-06f);
|
||||||
|
const VF d4 = hn::Set(d, 1);
|
||||||
|
|
||||||
|
// The approximation works in range -12..12, but the input value is clamped
|
||||||
|
// in -11.5..11.5 since the approximation slightly overshoots after that.
|
||||||
|
// The function is nearly 0 for input values below -11.5 and nearly 1 for
|
||||||
|
// input values above 11.5.
|
||||||
|
const VF invtwelve = hn::Set(d, 1.0f / 12.0f);
|
||||||
|
const VF lo = hn::Set(d, -11.5f);
|
||||||
|
const VF hi = hn::Set(d, 11.5f);
|
||||||
|
|
||||||
|
VF f = hn::Clamp(v, lo, hi);
|
||||||
|
f = hn::Mul(f, invtwelve);
|
||||||
|
VF f2 = hn::Add(f, f);
|
||||||
|
|
||||||
|
VF a1 = hn::MulAdd(f2, c0, c1);
|
||||||
|
VF a2 = hn::MulAdd(f2, a1, c2);
|
||||||
|
VF a3 = hn::Sub(hn::MulAdd(f2, a2, c3), a1);
|
||||||
|
VF a4 = hn::Sub(hn::MulAdd(f2, a3, c4), a2);
|
||||||
|
VF f0 = hn::Sub(hn::MulAdd(f, a4, c5), a3);
|
||||||
|
|
||||||
|
VF b1 = hn::MulAdd(f2, d0, d1);
|
||||||
|
VF b2 = hn::MulAdd(f2, b1, d2);
|
||||||
|
VF b3 = hn::Sub(hn::MulAdd(f2, b2, d3), b1);
|
||||||
|
VF f1 = hn::Sub(hn::MulAdd(f, b3, d4), b2);
|
||||||
|
|
||||||
|
return hn::Div(f0, f1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sigmoid using the logistic function 1 / (1 + exp(-x[i]))
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
|
||||||
|
size_t size) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
hn::Transform(D(), x, size,
|
||||||
|
[](D d, hn::Vec<D> v) HWY_ATTR { return Sigmoid(d, v); });
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
|
||||||
|
const float* HWY_RESTRICT b,
|
||||||
|
size_t size) {
|
||||||
|
const hn::ScalableTag<float> d;
|
||||||
|
HWY_DASSERT(size >= hn::Lanes(d));
|
||||||
|
HWY_DASSERT(size % hn::Lanes(d) == 0);
|
||||||
|
constexpr int kAssumptions =
|
||||||
|
hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
|
||||||
|
return hn::Dot::Compute<kAssumptions>(d, a, b, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2(
|
||||||
|
const float* HWY_RESTRICT a, size_t size) {
|
||||||
|
const hn::ScalableTag<float> d;
|
||||||
|
using V = hn::Vec<decltype(d)>;
|
||||||
|
const size_t N = hn::Lanes(d);
|
||||||
|
HWY_DASSERT(size >= 2 * N);
|
||||||
|
HWY_DASSERT(size % (2 * N) == 0);
|
||||||
|
|
||||||
|
V sum0 = hn::Zero(d);
|
||||||
|
V sum1 = hn::Zero(d);
|
||||||
|
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
|
||||||
|
const V a0 = hn::LoadU(d, a + i);
|
||||||
|
sum0 = hn::MulAdd(a0, a0, sum0);
|
||||||
|
const V a1 = hn::LoadU(d, a + i + N);
|
||||||
|
sum1 = hn::MulAdd(a1, a1, sum1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return hn::ReduceSum(d, hn::Add(sum0, sum1));
|
||||||
|
}
|
||||||
|
|
||||||
|
// float, float -> float; simple loop.
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
||||||
|
float* HWY_RESTRICT out, size_t size) {
|
||||||
|
constexpr float kEps = 1e-6f;
|
||||||
|
float ss = SquaredL2(x, size);
|
||||||
|
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
||||||
|
for (size_t j = 0; j < size; j++) {
|
||||||
|
// Note 1.0f centering here
|
||||||
|
out[j] = (1.0f + weight[j]) * (ss * x[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// x=f, w=bf16 -> out=f
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
||||||
|
float* HWY_RESTRICT out, size_t size) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
|
constexpr float kEps = 1e-6f;
|
||||||
|
constexpr size_t kUnrollSize = 2;
|
||||||
|
|
||||||
|
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||||
|
const hn::Repartition<float, decltype(dbf)> df32;
|
||||||
|
const size_t N32 = hn::Lanes(df32);
|
||||||
|
|
||||||
|
const float ss = SquaredL2(x, size);
|
||||||
|
const auto vss =
|
||||||
|
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
||||||
|
|
||||||
|
HWY_DASSERT(size % (kUnrollSize * MaxLanes(df32)) == 0);
|
||||||
|
for (size_t i = 0; i < size; i += kUnrollSize * N32) {
|
||||||
|
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
||||||
|
const auto w0 = hn::PromoteLowerTo(df32, w16);
|
||||||
|
const auto w1 = hn::PromoteUpperTo(df32, w16);
|
||||||
|
const auto m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
||||||
|
const auto m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
||||||
|
|
||||||
|
// (1+weight) * m = m + weight*m = one FMA.
|
||||||
|
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, out + i);
|
||||||
|
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, out + i + N32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// float -> float; simple loop.
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
|
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
|
||||||
|
constexpr float kEps = 1e-6f;
|
||||||
|
float ss = SquaredL2(inout, size);
|
||||||
|
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
|
||||||
|
for (size_t j = 0; j < size; j++) {
|
||||||
|
// Note 1.0f centering here
|
||||||
|
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// w=bf16 -> f
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
|
||||||
|
const hwy::bfloat16_t* HWY_RESTRICT weight, float* HWY_RESTRICT inout,
|
||||||
|
const size_t size) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||||
|
const hn::Repartition<float, decltype(dbf)> df32;
|
||||||
|
using VF = hn::Vec<decltype(df32)>;
|
||||||
|
const size_t N32 = hn::Lanes(df32);
|
||||||
|
|
||||||
|
constexpr float kEps = 1e-6f;
|
||||||
|
const float ss = SquaredL2(inout, size);
|
||||||
|
const VF vss =
|
||||||
|
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
||||||
|
|
||||||
|
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||||
|
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||||
|
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
||||||
|
const VF w0 = hn::PromoteLowerTo(df32, w16);
|
||||||
|
const VF w1 = hn::PromoteUpperTo(df32, w16);
|
||||||
|
const VF m0 = hn::Mul(vss, hn::LoadU(df32, inout + i));
|
||||||
|
const VF m1 = hn::Mul(vss, hn::LoadU(df32, inout + i + N32));
|
||||||
|
// (1+weight) * m = m + weight*m = one FMA.
|
||||||
|
hn::StoreU(hn::MulAdd(m0, w0, m0), df32, inout + i);
|
||||||
|
hn::StoreU(hn::MulAdd(m1, w1, m1), df32, inout + i + N32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// f, f -> bf
|
||||||
|
// TODO(janwas): consider generic function with adapter for loading bf16/f32
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
const float* HWY_RESTRICT x, const float* HWY_RESTRICT weight,
|
||||||
|
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||||
|
const hn::Repartition<float, decltype(dbf)> df32;
|
||||||
|
using VF = hn::Vec<decltype(df32)>;
|
||||||
|
const size_t N32 = hn::Lanes(df32);
|
||||||
|
|
||||||
|
constexpr float kEps = 1e-6f;
|
||||||
|
const float ss = SquaredL2(x, size);
|
||||||
|
const VF vss =
|
||||||
|
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
||||||
|
|
||||||
|
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||||
|
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||||
|
const VF w0 = hn::LoadU(df32, weight + i);
|
||||||
|
const VF w1 = hn::LoadU(df32, weight + i + N32);
|
||||||
|
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
||||||
|
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
||||||
|
// (1+weight) * m = m + weight*m = one FMA.
|
||||||
|
const VF out0 = hn::MulAdd(m0, w0, m0);
|
||||||
|
const VF out1 = hn::MulAdd(m1, w1, m1);
|
||||||
|
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// x=f, w=bf16 -> bf16 to enable W16A16 MatVec.
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
|
||||||
|
const float* HWY_RESTRICT x, const hwy::bfloat16_t* HWY_RESTRICT weight,
|
||||||
|
hwy::bfloat16_t* HWY_RESTRICT out, const size_t size) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
const hn::ScalableTag<hwy::bfloat16_t> dbf;
|
||||||
|
const hn::Repartition<float, decltype(dbf)> df32;
|
||||||
|
using VF = hn::Vec<decltype(df32)>;
|
||||||
|
const size_t N32 = hn::Lanes(df32);
|
||||||
|
|
||||||
|
constexpr float kEps = 1e-6f;
|
||||||
|
const float ss = SquaredL2(x, size);
|
||||||
|
const VF vss =
|
||||||
|
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps));
|
||||||
|
|
||||||
|
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
|
||||||
|
for (size_t i = 0; i < size; i += 2 * N32) {
|
||||||
|
const hn::Vec<decltype(dbf)> w16 = hn::LoadU(dbf, weight + i);
|
||||||
|
const VF w0 = hn::PromoteLowerTo(df32, w16);
|
||||||
|
const VF w1 = hn::PromoteUpperTo(df32, w16);
|
||||||
|
const VF m0 = hn::Mul(vss, hn::LoadU(df32, x + i));
|
||||||
|
const VF m1 = hn::Mul(vss, hn::LoadU(df32, x + i + N32));
|
||||||
|
// (1+weight) * m = m + weight*m = one FMA.
|
||||||
|
const VF out0 = hn::MulAdd(m0, w0, m0);
|
||||||
|
const VF out1 = hn::MulAdd(m1, w1, m1);
|
||||||
|
hn::StoreU(hn::OrderedDemote2To(dbf, out0, out1), dbf, out + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
|
||||||
|
float* HWY_RESTRICT x, size_t dim_model, size_t pos) {
|
||||||
|
const size_t num_timescales = dim_model / 2;
|
||||||
|
const float log_timescale_increment =
|
||||||
|
logf(10000.0f) /
|
||||||
|
(num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
|
||||||
|
for (size_t dim = 0; dim < num_timescales; ++dim) {
|
||||||
|
const float inv_timescale =
|
||||||
|
expf(StaticCast<float>(dim) * -log_timescale_increment);
|
||||||
|
x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
|
||||||
|
x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* RoPE as in Rotary Position Embeddings from the RoFormer paper
|
||||||
|
(https://arxiv.org/abs/2104.09864v5). The query and key vectors are rotated
|
||||||
|
as a function of their absolute position using the rotation matrix R before
|
||||||
|
the self-attention operation. R is a d x d matrix.
|
||||||
|
|
||||||
|
R = cos(m*theta_1) -sin(m*theta_1) ... 0 0
|
||||||
|
sin(m*theta_1) cos(m*theta_1)
|
||||||
|
0 0 ... 0 0
|
||||||
|
0 0 ... 0 0
|
||||||
|
...
|
||||||
|
0 0 ... cos(m*theta_{d/2}) sin(m*theta_{d/2})
|
||||||
|
0 0 ... sin(m*theta_{d/2}) cos(m*theta_{d/2})
|
||||||
|
|
||||||
|
Here theta_i = 10000^(-2(i-1)/d), where d is the dimension of the vector and
|
||||||
|
i is the ith index of the vector.
|
||||||
|
|
||||||
|
Applying the rotation matrix R to a vector v is equivalent to rotating every
|
||||||
|
consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle
|
||||||
|
m*theta_i. However in the Gemma implementation we choose to rotate
|
||||||
|
the pairs of dimensions v_{i} and v_{i + d//2} instead.
|
||||||
|
|
||||||
|
pos parameter is deliberately an int because in the backward pass we
|
||||||
|
call this with negative values (for the VJP calculation we need the transpose
|
||||||
|
of this rotation matrix which is simply the same matrix with -pos parameter)
|
||||||
|
*/
|
||||||
|
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
|
||||||
|
size_t dim_qkv, int pos) {
|
||||||
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
|
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||||
|
const float freq_exponents =
|
||||||
|
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
||||||
|
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
||||||
|
const float timescale = powf(10000.0f, freq_exponents);
|
||||||
|
const float theta = StaticCast<float>(pos) / timescale;
|
||||||
|
const float cos_val = cosf(theta);
|
||||||
|
const float sin_val = sinf(theta);
|
||||||
|
const float x0 = x[dim];
|
||||||
|
const float x1 = x[dim + half_dim_qkv];
|
||||||
|
x[dim] = x0 * cos_val - x1 * sin_val;
|
||||||
|
x[dim + half_dim_qkv] = x0 * sin_val + x1 * cos_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
|
||||||
|
float* HWY_RESTRICT x,
|
||||||
|
size_t dim_qkv,
|
||||||
|
int pos) {
|
||||||
|
HWY_DASSERT(dim_qkv % 2 == 0);
|
||||||
|
const size_t half_dim_qkv = dim_qkv / 2;
|
||||||
|
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
|
||||||
|
const float freq_exponents =
|
||||||
|
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
|
||||||
|
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
|
||||||
|
const float timescale = powf(10000.0f, freq_exponents);
|
||||||
|
const float theta = StaticCast<float>(pos) / timescale;
|
||||||
|
const float cos_val = cosf(theta);
|
||||||
|
const float sin_val = sinf(theta);
|
||||||
|
const float x0 = x[dim];
|
||||||
|
const float x1 = x[dim + half_dim_qkv];
|
||||||
|
x[dim] = mul * (x0 * cos_val - x1 * sin_val);
|
||||||
|
x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
|
||||||
|
const float* HWY_RESTRICT other, float* HWY_RESTRICT x, const size_t size) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
using V = hn::Vec<D>;
|
||||||
|
|
||||||
|
hn::Transform1(D(), x, size, other,
|
||||||
|
[](const auto d, const V x, const V other)
|
||||||
|
HWY_ATTR { return hn::Add(x, other); });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple loops unless/until batch sizes are large enough to parallelize.
|
||||||
|
template <typename WeightT, typename OutT>
|
||||||
|
void RMSNormBatched(size_t num_tokens, const float* activations,
|
||||||
|
const WeightT* weights, OutT* out, const size_t model_dim) {
|
||||||
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
|
RMSNorm(activations + token_idx * model_dim, weights,
|
||||||
|
out + token_idx * model_dim, model_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: pass RowVectorBatch argument.
|
||||||
|
template <typename WeightT, typename InOutT>
|
||||||
|
void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
|
||||||
|
InOutT* inout, const size_t model_dim) {
|
||||||
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
|
RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other,
|
||||||
|
float* x, const size_t model_dim) {
|
||||||
|
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
|
AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
|
||||||
|
model_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
|
||||||
|
float* HWY_RESTRICT x, const size_t size,
|
||||||
|
const size_t max_pos) {
|
||||||
|
HWY_DASSERT(max_pos <= size);
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
using V = hn::Vec<D>;
|
||||||
|
|
||||||
|
hn::Transform1(D(), x, max_pos, other,
|
||||||
|
[](const auto d, const V x, const V other)
|
||||||
|
HWY_ATTR { return hn::Mul(x, other); });
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_INLINE HWY_MAYBE_UNUSED void MulBy(const float* HWY_RESTRICT other,
|
||||||
|
float* HWY_RESTRICT x,
|
||||||
|
const size_t size) {
|
||||||
|
return MulBy(other, x, size, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE void MulByConst(const float c, float* HWY_RESTRICT x,
|
||||||
|
const size_t size, const size_t max_pos) {
|
||||||
|
HWY_DASSERT(max_pos <= size);
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
using V = hn::Vec<D>;
|
||||||
|
hn::Transform(D(), x, max_pos, [c](const auto d, const V x) HWY_ATTR {
|
||||||
|
return hn::Mul(x, hn::Set(d, c));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConst(const float c,
|
||||||
|
float* HWY_RESTRICT x,
|
||||||
|
const size_t size) {
|
||||||
|
MulByConst(c, x, size, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE void MulByConstAndAdd(const float c,
|
||||||
|
const float* HWY_RESTRICT x,
|
||||||
|
float* HWY_RESTRICT out,
|
||||||
|
const size_t size,
|
||||||
|
const size_t max_pos) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
using V = hn::Vec<D>;
|
||||||
|
hn::Transform1(D(), out, max_pos, x,
|
||||||
|
[c](const auto d, const V v_out, const V v_x) HWY_ATTR {
|
||||||
|
return hn::MulAdd(v_x, hn::Set(d, c), v_out);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
|
float c, const float* HWY_RESTRICT x, float* HWY_RESTRICT out,
|
||||||
|
size_t size) {
|
||||||
|
MulByConstAndAdd(c, x, out, size, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
|
const size_t mask_pos) {
|
||||||
|
HWY_DASSERT(size != 0);
|
||||||
|
HWY_DASSERT(mask_pos <= size);
|
||||||
|
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
using V = hn::Vec<D>;
|
||||||
|
const D d;
|
||||||
|
|
||||||
|
const V vmin = hn::Set(d, hwy::LowestValue<float>());
|
||||||
|
V vmax = vmin;
|
||||||
|
V* pmax = &vmax; // workaround for SVE: cannot capture &vector directly
|
||||||
|
Foreach(d, x, mask_pos, vmin, [pmax](const auto d, const V value) HWY_ATTR {
|
||||||
|
*pmax = hn::Max(*pmax, value);
|
||||||
|
});
|
||||||
|
vmax = hn::MaxOfLanes(d, vmax);
|
||||||
|
|
||||||
|
// Subtract max (avoid precision loss for large exponents) and exponentiate.
|
||||||
|
hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
|
||||||
|
#if HWY_TARGET & HWY_ALL_SVE
|
||||||
|
// Temporary workaround for buggy SVE codegen: avoid inlined
|
||||||
|
// Exp().
|
||||||
|
return hn::CallExp(d, hn::Sub(value, *pmax));
|
||||||
|
#else
|
||||||
|
return hn::Exp(d, hn::Sub(value, *pmax));
|
||||||
|
#endif
|
||||||
|
});
|
||||||
|
|
||||||
|
V sum = hn::Zero(d);
|
||||||
|
V* psum = ∑
|
||||||
|
Foreach(d, x, mask_pos, sum, [psum](const auto d, const V value) HWY_ATTR {
|
||||||
|
*psum = hn::Add(*psum, value);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Normalize to probability distribution
|
||||||
|
const float mul = 1.0f / hn::ReduceSum(d, sum);
|
||||||
|
MulByConst(mul, x, size, mask_pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_INLINE HWY_MAYBE_UNUSED void Softmax(float* HWY_RESTRICT x,
|
||||||
|
const size_t size) {
|
||||||
|
Softmax(x, size, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x,
|
||||||
|
const size_t size,
|
||||||
|
const size_t max_pos) {
|
||||||
|
HWY_DASSERT(max_pos <= size);
|
||||||
|
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
using V = hn::Vec<D>;
|
||||||
|
|
||||||
|
const float inv_cap = 1.0f / cap;
|
||||||
|
|
||||||
|
hn::Transform(D(), x, max_pos, [cap, inv_cap](D d, V v) HWY_ATTR {
|
||||||
|
return hn::Mul(hn::Set(d, cap),
|
||||||
|
hn::Tanh(d, hn::Mul(v, hn::Set(d, inv_cap))));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap,
|
||||||
|
float* HWY_RESTRICT x,
|
||||||
|
const size_t size) {
|
||||||
|
LogitsSoftCap(cap, x, size, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED size_t
|
||||||
|
SampleArgmax(const float* probabilities, size_t vocab_size) {
|
||||||
|
size_t max_index = 0;
|
||||||
|
float max_prob = probabilities[0];
|
||||||
|
for (size_t i = 1; i < vocab_size; ++i) {
|
||||||
|
if (probabilities[i] > max_prob) {
|
||||||
|
max_index = i;
|
||||||
|
max_prob = probabilities[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max_index;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t k>
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED std::discrete_distribution<int>
|
||||||
|
create_distribution(std::array<float, k>& top_k, float temperature) {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using D = hn::ScalableTag<float>;
|
||||||
|
|
||||||
|
// re-normalize distribution
|
||||||
|
const float temperature_inv = 1.0f / temperature;
|
||||||
|
hn::Transform(D(), top_k.data(), top_k.size(),
|
||||||
|
[temperature_inv](D d, hn::Vec<D> v) HWY_ATTR {
|
||||||
|
return hn::Exp(
|
||||||
|
d, hn::Mul(hn::Log(d, v), hn::Set(d, temperature_inv)));
|
||||||
|
});
|
||||||
|
|
||||||
|
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t k, typename TAcceptToken>
|
||||||
|
static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
|
||||||
|
const float* HWY_RESTRICT probabilities, size_t vocab_size,
|
||||||
|
std::mt19937& gen, float temperature, TAcceptToken& accept_token) {
|
||||||
|
static_assert(k != 0, "");
|
||||||
|
// TODO: Optimize, potentially using new VQSort PartialSort.
|
||||||
|
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
|
||||||
|
std::array<int, k> indices{};
|
||||||
|
for (size_t i = 0; i < vocab_size; ++i) {
|
||||||
|
if (probabilities[i] < top_k[k - 1] &&
|
||||||
|
(!accept_token || accept_token(StaticCast<int>(i), probabilities[i]))) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (size_t j = 0; j < k; ++j) {
|
||||||
|
if (probabilities[i] > top_k[j] &&
|
||||||
|
(!accept_token ||
|
||||||
|
accept_token(StaticCast<int>(i), probabilities[i]))) {
|
||||||
|
// shift elements by 1, insert the new value, move on to next value
|
||||||
|
for (size_t idx = k - 1; idx > j; --idx) {
|
||||||
|
top_k[idx] = top_k[idx - 1];
|
||||||
|
indices[idx] = indices[idx - 1];
|
||||||
|
}
|
||||||
|
top_k[j] = probabilities[i];
|
||||||
|
indices[j] = StaticCast<int>(i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return indices[create_distribution<k>(top_k, temperature)(gen)];
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
} // namespace HWY_NAMESPACE
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
||||||
|
#endif // NOLINT
|
||||||
|
|
@ -0,0 +1,384 @@
|
||||||
|
// Copyright 2023 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// OrderedDemote2To is not supported by HWY_SCALAR.
|
||||||
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <array>
|
||||||
|
#include <cmath>
|
||||||
|
#include <random>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "hwy/aligned_allocator.h"
|
||||||
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
#define HWY_TARGET_INCLUDE "ops/ops_test.cc" // NOLINT
|
||||||
|
// clang-format on
|
||||||
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
|
#include "hwy/highway.h"
|
||||||
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
// After highway.h
|
||||||
|
#include "ops/ops-inl.h"
|
||||||
|
|
||||||
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
namespace gcpp {
|
||||||
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
|
template <class Test>
|
||||||
|
struct ForeachCountAndMisalign {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_NOINLINE void operator()(T /*unused*/, D d) const {
|
||||||
|
hwy::RandomState rng;
|
||||||
|
const size_t N = Lanes(d);
|
||||||
|
const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
|
||||||
|
|
||||||
|
for (size_t count = 0; count < 2 * N; ++count) {
|
||||||
|
for (size_t ma : misalignments) {
|
||||||
|
for (size_t mb : misalignments) {
|
||||||
|
Test()(d, count, ma, mb, rng);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T Random(hwy::RandomState& rng) {
|
||||||
|
const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
|
||||||
|
const double val = (bits - 512) / 64.0;
|
||||||
|
// Clamp negative to zero for unsigned types.
|
||||||
|
return hwy::ConvertScalarTo<T>(
|
||||||
|
HWY_MAX(hwy::ConvertScalarTo<double>(hwy::LowestValue<T>()), val));
|
||||||
|
}
|
||||||
|
|
||||||
|
HWY_NOINLINE void SourceAddFrom(const float* HWY_RESTRICT other,
|
||||||
|
float* HWY_RESTRICT x, size_t size) {
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
x[i] += other[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
HWY_NOINLINE void SourceMulBy(const float* HWY_RESTRICT other,
|
||||||
|
float* HWY_RESTRICT x, size_t size,
|
||||||
|
size_t max_pos) {
|
||||||
|
HWY_DASSERT(max_pos <= size);
|
||||||
|
for (size_t i = 0; i < max_pos; ++i) {
|
||||||
|
x[i] *= other[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
HWY_NOINLINE void SourceMulByConst(float c, float* HWY_RESTRICT x, size_t size,
|
||||||
|
size_t max_pos) {
|
||||||
|
for (size_t i = 0; i < max_pos; ++i) {
|
||||||
|
x[i] *= c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
HWY_NOINLINE void SourceMulByConstAndAdd(float c, const float* HWY_RESTRICT x,
|
||||||
|
float* HWY_RESTRICT out, size_t size,
|
||||||
|
size_t max_pos) {
|
||||||
|
for (size_t i = 0; i < max_pos; ++i) {
|
||||||
|
out[i] += x[i] * c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
HWY_NOINLINE void SourceSoftmax(float* HWY_RESTRICT x, size_t size,
|
||||||
|
size_t mask_pos) {
|
||||||
|
HWY_DASSERT(size != 0);
|
||||||
|
HWY_DASSERT(mask_pos <= size);
|
||||||
|
float sum = 0.0;
|
||||||
|
const float maxval = *std::max_element(x, x + mask_pos);
|
||||||
|
for (size_t i = 0; i < mask_pos; ++i) {
|
||||||
|
x[i] = std::exp(x[i] - maxval);
|
||||||
|
sum += x[i];
|
||||||
|
}
|
||||||
|
const float scale = 1.0f / sum;
|
||||||
|
for (size_t i = 0; i < mask_pos; ++i) {
|
||||||
|
x[i] *= scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <size_t k>
|
||||||
|
HWY_NOINLINE std::discrete_distribution<int> SourceCreateDistribution(
|
||||||
|
std::array<float, k>& top_k, float temperature) {
|
||||||
|
// re-normalize distribution
|
||||||
|
for (size_t i = 0; i < k; ++i) {
|
||||||
|
top_k[i] = exp(log(top_k[i]) / temperature);
|
||||||
|
}
|
||||||
|
float denominator = 0.0f;
|
||||||
|
for (size_t i = 0; i < k; ++i) {
|
||||||
|
denominator += top_k[i];
|
||||||
|
}
|
||||||
|
denominator = 1.0f / denominator;
|
||||||
|
MulByConst(denominator, top_k.data(), k);
|
||||||
|
return std::discrete_distribution<int>(std::begin(top_k), std::end(top_k));
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TestAddFrom {
|
||||||
|
template <class D>
|
||||||
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||||
|
hwy::RandomState& rng) {
|
||||||
|
using T = hn::TFromD<D>;
|
||||||
|
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> po =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
||||||
|
HWY_ASSERT(px && pe && po);
|
||||||
|
|
||||||
|
T* x = px.get() + misalign_a;
|
||||||
|
T* e = pe.get() + misalign_a;
|
||||||
|
T* o = po.get() + misalign_b;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
x[i] = Random<T>(rng);
|
||||||
|
e[i] = x[i];
|
||||||
|
o[i] = Random<T>(rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
SourceAddFrom(o, e, count);
|
||||||
|
AddFrom(o, x, count);
|
||||||
|
|
||||||
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
|
__LINE__);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TestMulBy {
|
||||||
|
template <class D>
|
||||||
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||||
|
hwy::RandomState& rng) {
|
||||||
|
using T = hn::TFromD<D>;
|
||||||
|
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> po =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
||||||
|
HWY_ASSERT(px && pe && po);
|
||||||
|
|
||||||
|
T* x = px.get() + misalign_a;
|
||||||
|
T* e = pe.get() + misalign_a;
|
||||||
|
T* o = po.get() + misalign_b;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
x[i] = Random<T>(rng);
|
||||||
|
e[i] = x[i];
|
||||||
|
o[i] = Random<T>(rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
SourceMulBy(o, e, count, count);
|
||||||
|
MulBy(o, x, count, count);
|
||||||
|
|
||||||
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
|
__LINE__);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TestMulByConstAndAdd {
|
||||||
|
template <class D>
|
||||||
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||||
|
hwy::RandomState& rng) {
|
||||||
|
using T = hn::TFromD<D>;
|
||||||
|
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> po =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
|
||||||
|
HWY_ASSERT(px && pe && po);
|
||||||
|
|
||||||
|
T* x = px.get() + misalign_a;
|
||||||
|
T* e = pe.get() + misalign_a;
|
||||||
|
T* o = po.get() + misalign_b;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
x[i] = Random<T>(rng);
|
||||||
|
e[i] = x[i];
|
||||||
|
o[i] = Random<T>(rng);
|
||||||
|
}
|
||||||
|
T constant = Random<T>(rng);
|
||||||
|
|
||||||
|
SourceMulByConstAndAdd(constant, o, e, count, count);
|
||||||
|
MulByConstAndAdd(constant, o, x, count, count);
|
||||||
|
|
||||||
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
|
__LINE__);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TestMulByConst {
|
||||||
|
template <class D>
|
||||||
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||||
|
hwy::RandomState& rng) {
|
||||||
|
if (misalign_b == 0) return;
|
||||||
|
using T = hn::TFromD<D>;
|
||||||
|
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
HWY_ASSERT(px && pe);
|
||||||
|
|
||||||
|
T* x = px.get() + misalign_a;
|
||||||
|
T* e = pe.get() + misalign_a;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
x[i] = Random<T>(rng);
|
||||||
|
e[i] = x[i];
|
||||||
|
}
|
||||||
|
T constant = Random<T>(rng);
|
||||||
|
|
||||||
|
SourceMulByConst(constant, e, count, count);
|
||||||
|
MulByConst(constant, x, count, count);
|
||||||
|
|
||||||
|
hwy::AssertArraySimilar(e, x, count, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
|
__LINE__);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TestSoftmax {
|
||||||
|
template <class D>
|
||||||
|
void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
|
||||||
|
hwy::RandomState& rng) {
|
||||||
|
if (count == 0) return; // *Softmax would assert
|
||||||
|
if (misalign_b == 0) return;
|
||||||
|
using T = hn::TFromD<D>;
|
||||||
|
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> px =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
hwy::AlignedFreeUniquePtr<T[]> pe =
|
||||||
|
hwy::AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
|
||||||
|
HWY_ASSERT(px && pe);
|
||||||
|
|
||||||
|
T* x = px.get() + misalign_a;
|
||||||
|
T* e = pe.get() + misalign_a;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
x[i] = Random<T>(rng);
|
||||||
|
e[i] = x[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
SourceSoftmax(e, count, count);
|
||||||
|
Softmax(x, count, count);
|
||||||
|
|
||||||
|
T sum = 0.0f;
|
||||||
|
for (size_t i = 0; i < count; ++i) {
|
||||||
|
sum += x[i];
|
||||||
|
double rel = std::abs(x[i] - e[i]) / e[i];
|
||||||
|
ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of "
|
||||||
|
<< count;
|
||||||
|
}
|
||||||
|
ASSERT_NEAR(sum, 1.0, 1e-6);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <size_t k>
|
||||||
|
struct TestCreateDistribution {
|
||||||
|
void operator()(hwy::RandomState& rng) {
|
||||||
|
std::array<float, k> x;
|
||||||
|
std::array<float, k> e;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < k; ++i) {
|
||||||
|
x[i] = Random<float>(rng);
|
||||||
|
e[i] = x[i];
|
||||||
|
}
|
||||||
|
const float constant = Random<float>(rng);
|
||||||
|
auto expected = SourceCreateDistribution(e, constant);
|
||||||
|
auto output = create_distribution(x, constant);
|
||||||
|
|
||||||
|
AssertEqual(expected, output, hwy::TargetName(HWY_TARGET), __FILE__,
|
||||||
|
__LINE__);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestAllAddFrom() {
|
||||||
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestAddFrom>>()(float());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestAllMulBy() {
|
||||||
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulBy>>()(float());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestAllMulByConst() {
|
||||||
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConst>>()(float());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestAllMulByConstAndAdd() {
|
||||||
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestMulByConstAndAdd>>()(
|
||||||
|
float());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestAllSoftmax() {
|
||||||
|
hn::ForPartialVectors<ForeachCountAndMisalign<TestSoftmax>>()(float());
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestAllCreateDistribution() {
|
||||||
|
TestCreateDistribution<2048>();
|
||||||
|
TestCreateDistribution<5000>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSigmoid() {
|
||||||
|
std::vector<float> values;
|
||||||
|
for (int i = -150; i <= 150; ++i) {
|
||||||
|
values.push_back(.1f * i);
|
||||||
|
}
|
||||||
|
std::vector<float> result = values;
|
||||||
|
Sigmoid(result.data(), result.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < values.size(); i++) {
|
||||||
|
const float max_error = 0.00007;
|
||||||
|
float value = values[i];
|
||||||
|
float approx = result[i];
|
||||||
|
float expected = (1 / (1 + std::exp(-values[i])));
|
||||||
|
EXPECT_NEAR(approx, expected, max_error) << "Input: " << value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
} // namespace HWY_NAMESPACE
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
||||||
|
#if HWY_ONCE
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
HWY_BEFORE_TEST(OpsTest);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
|
||||||
|
HWY_EXPORT_AND_TEST_P(OpsTest, TestSigmoid);
|
||||||
|
HWY_AFTER_TEST();
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif
|
||||||
Loading…
Reference in New Issue