gemma.cpp/ops/dot_test.cc

842 lines
30 KiB
C++

// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32.
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::swap, std::sort
#include <array>
#include <cmath>
#include <random>
#include "compression/compress.h"
#include "compression/shared.h"
#include "util/allocator.h"
#include "util/test_util.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/stats.h"
#include "hwy/timer.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/dot_test.cc"
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/test_util-inl.h"
#include "ops/dot-inl.h"
#include "hwy/profiler.h" // also uses SIMD
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
//------------------------------------------------------------------------------
// Dot product variants
// All combinations of {*, TwoProducts} x {+, FastTwoSums, TwoSums}.
// Baseline kernel: plain multiply-accumulate into four independent vector
// accumulators, without any error compensation. The comp* parameters exist
// only so that this kernel has the same interface as the compensated kernels
// below; they are never read or written here.
struct DotKernelNaive {
// For K = 0..3: sumK = MulAdd(wK, vK, sumK), i.e. sumK += wK * vK.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& /*comp0*/, VF& /*comp1*/, VF& /*comp2*/,
VF& /*comp3*/) const {
sum0 = hn::MulAdd(w0, v0, sum0);
sum1 = hn::MulAdd(w1, v1, sum1);
sum2 = hn::MulAdd(w2, v2, sum2);
sum3 = hn::MulAdd(w3, v3, sum3);
}
// Single-vector variant of Update4: sum0 += w0 * v0.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& /*comp0*/) const {
sum0 = hn::MulAdd(w0, v0, sum0);
}
// Returns the horizontal (across-lane) sum of (sum0 + sum1) + (sum2 + sum3).
// Overwrites sum0 and sum2 with the intermediate results.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& /*comp0*/, VF& /*comp1*/, VF& /*comp2*/,
VF& /*comp3*/) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, sum2);
return hn::ReduceSum(df, sum0);
}
};
// Dot product of `num` elements of `w` (starting at element `w_ofs`) with
// `vec_aligned`, using the uncompensated DotKernelNaive.
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotNaive(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
                          const VecT* HWY_RESTRICT vec_aligned, size_t num) {
  const DotKernelNaive kernel{};
  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, kernel);
}
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum.
// Kahan-style kernel: the previous step's error term compK is folded into the
// next product via MulAdd, and FastTwoSums then overwrites compK with the
// rounding error of the new addition. CallDot selects this kernel for kKahan.
struct DotKernelKahan {
// For K = 0..3: prodK = wK * vK + compK, then sumK = FastTwoSums(sumK, prodK)
// with the new error term written to compK.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
// Add compensation from last iteration, which is an approximation of the
// running error.
const VF prod0 = hn::MulAdd(w0, v0, comp0);
const VF prod1 = hn::MulAdd(w1, v1, comp1);
const VF prod2 = hn::MulAdd(w2, v2, comp2);
const VF prod3 = hn::MulAdd(w3, v3, comp3);
sum0 = FastTwoSums(df, sum0, prod0, comp0);
sum1 = FastTwoSums(df, sum1, prod1, comp1);
sum2 = FastTwoSums(df, sum2, prod2, comp2);
sum3 = FastTwoSums(df, sum3, prod3, comp3);
}
// Single-vector variant of Update4.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& comp0) const {
const VF prod0 = hn::MulAdd(w0, v0, comp0);
sum0 = FastTwoSums(df, sum0, prod0, comp0);
}
// Adds the four error terms into sum_err, then merges sum1, sum3 and sum2 into
// sum0 via UpdateCascadedSums (each call also receives sum_err), and returns
// ReduceCascadedSums(sum0, sum_err). Modifies its arguments in place.
// NOTE(review): the cascaded-sum helpers are not defined in this file
// (presumably ops/dot-inl.h); confirm their exact semantics there.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
comp0 = hn::Add(comp0, comp1);
comp2 = hn::Add(comp2, comp3);
VF sum_err = hn::Add(comp0, comp2);
UpdateCascadedSums(df, sum1, sum0, sum_err);
UpdateCascadedSums(df, sum3, sum2, sum_err);
UpdateCascadedSums(df, sum2, sum0, sum_err);
return ReduceCascadedSums(df, sum0, sum_err);
}
};
// Dot product of `num` elements of `w` (starting at element `w_ofs`) with
// `vec_aligned`, using the Kahan/FastTwoSum kernel DotKernelKahan.
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotKahan(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
                          const VecT* HWY_RESTRICT vec_aligned, size_t num) {
  const DotKernelKahan kernel{};
  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, kernel);
}
// Dot product of `num` elements of `w` (starting at element `w_ofs`) with
// `vec_aligned`, using DotKernelCompensated (defined outside this file).
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
                                size_t w_ofs,
                                const VecT* HWY_RESTRICT vec_aligned,
                                size_t num) {
  const DotKernelCompensated kernel{};
  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, kernel);
}
// Like Compensated, but FastTwoSum instead of TwoSum.
// Each step splits the product into prodK plus its error perrK (TwoProducts),
// adds prodK to sumK with FastTwoSums (error serrK), and accumulates
// perrK + serrK into compK. CallDot selects this kernel for kAddTwoProd.
struct DotKernelTwoProdFast {
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
// Product and its error term.
VF perr0, perr1, perr2, perr3;
const VF prod0 = TwoProducts(df, w0, v0, perr0);
const VF prod1 = TwoProducts(df, w1, v1, perr1);
const VF prod2 = TwoProducts(df, w2, v2, perr2);
const VF prod3 = TwoProducts(df, w3, v3, perr3);
// Sum and its error term.
VF serr0, serr1, serr2, serr3;
sum0 = FastTwoSums(df, sum0, prod0, serr0);
sum1 = FastTwoSums(df, sum1, prod1, serr1);
sum2 = FastTwoSums(df, sum2, prod2, serr2);
sum3 = FastTwoSums(df, sum3, prod3, serr3);
// Both error terms go into the running compensation.
comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
comp1 = hn::Add(comp1, hn::Add(perr1, serr1));
comp2 = hn::Add(comp2, hn::Add(perr2, serr2));
comp3 = hn::Add(comp3, hn::Add(perr3, serr3));
}
// Single-vector variant of Update4.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& comp0) const {
VF perr0;
const VF prod0 = TwoProducts(df, w0, v0, perr0);
VF serr0;
sum0 = FastTwoSums(df, sum0, prod0, serr0);
comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
}
// Merges (sum1, comp1) into (sum0, comp0) and (sum3, comp3) into
// (sum2, comp2), then (sum2, comp2) into (sum0, comp0) via
// AssimilateCascadedSums, and returns ReduceCascadedSums(sum0, comp0).
// Modifies its arguments in place.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
return ReduceCascadedSums(df, sum0, comp0);
}
};
// Dot product of `num` elements of `w` (starting at element `w_ofs`) with
// `vec_aligned`, using DotKernelTwoProdFast (TwoProducts + FastTwoSums).
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
                                size_t w_ofs,
                                const VecT* HWY_RESTRICT vec_aligned,
                                size_t num) {
  const DotKernelTwoProdFast kernel{};
  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, kernel);
}
// Like Compensated, but without TwoProducts. Vs Kahan, upgrades FastTwoSums
// to TwoSums.
// Products are rounded normally (hn::Mul); only the addition error serrK
// reported by TwoSums is accumulated into compK. CallDot selects this kernel
// for kAddTwoSum.
struct DotKernelMulTwoSum {
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
const VF prod0 = hn::Mul(w0, v0);
const VF prod1 = hn::Mul(w1, v1);
const VF prod2 = hn::Mul(w2, v2);
const VF prod3 = hn::Mul(w3, v3);
VF serr0, serr1, serr2, serr3;
sum0 = TwoSums(df, prod0, sum0, serr0);
sum1 = TwoSums(df, prod1, sum1, serr1);
sum2 = TwoSums(df, prod2, sum2, serr2);
sum3 = TwoSums(df, prod3, sum3, serr3);
comp0 = hn::Add(comp0, serr0);
comp1 = hn::Add(comp1, serr1);
comp2 = hn::Add(comp2, serr2);
comp3 = hn::Add(comp3, serr3);
}
// Single-vector variant of Update4.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& comp0) const {
const VF prod0 = hn::Mul(w0, v0);
VF serr0;
sum0 = TwoSums(df, prod0, sum0, serr0);
comp0 = hn::Add(comp0, serr0);
}
// Same reduction as DotKernelTwoProdFast::Reduce: pairwise
// AssimilateCascadedSums into (sum0, comp0), then ReduceCascadedSums.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
return ReduceCascadedSums(df, sum0, comp0);
}
};
// Dot product of `num` elements of `w` (starting at element `w_ofs`) with
// `vec_aligned`, using DotKernelMulTwoSum (plain Mul + TwoSums).
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
                              size_t w_ofs,
                              const VecT* HWY_RESTRICT vec_aligned,
                              size_t num) {
  const DotKernelMulTwoSum kernel{};
  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, kernel);
}
// Like Compensated, but only TwoProducts, no [Fast]TwoSums. This is only 10%
// better (mul) than naive.
// The product error perrK from TwoProducts is accumulated into compK, but
// prodK is added to sumK with a plain (uncompensated) hn::Add. CallDot selects
// this kernel for kOnlyTwoProd.
struct DotKernelTwoProdAdd {
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, const VF v0, const VF v1, const VF v2,
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
VF perr0, perr1, perr2, perr3;
const VF prod0 = TwoProducts(df, w0, v0, perr0);
const VF prod1 = TwoProducts(df, w1, v1, perr1);
const VF prod2 = TwoProducts(df, w2, v2, perr2);
const VF prod3 = TwoProducts(df, w3, v3, perr3);
sum0 = hn::Add(sum0, prod0);
sum1 = hn::Add(sum1, prod1);
sum2 = hn::Add(sum2, prod2);
sum3 = hn::Add(sum3, prod3);
comp0 = hn::Add(comp0, perr0);
comp1 = hn::Add(comp1, perr1);
comp2 = hn::Add(comp2, perr2);
comp3 = hn::Add(comp3, perr3);
}
// Single-vector variant of Update4.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& comp0) const {
VF perr0;
const VF prod0 = TwoProducts(df, w0, v0, perr0);
sum0 = hn::Add(sum0, prod0);
comp0 = hn::Add(comp0, perr0);
}
// Same reduction as DotKernelTwoProdFast::Reduce: pairwise
// AssimilateCascadedSums into (sum0, comp0), then ReduceCascadedSums.
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
return ReduceCascadedSums(df, sum0, comp0);
}
};
// Dot product of `num` elements of `w` (starting at element `w_ofs`) with
// `vec_aligned`, using DotKernelTwoProdAdd (TwoProducts + plain Add).
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
                               size_t w_ofs,
                               const VecT* HWY_RESTRICT vec_aligned,
                               size_t num) {
  const DotKernelTwoProdAdd kernel{};
  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, kernel);
}
// Indices of the Dot variants under test, in alphabetical order of their
// names. Used to index per-variant arrays; `kVariants` is their count and is
// not itself a variant. The comment on each entry names the function that
// CallDot dispatches to.
enum {
  kAddTwoProd = 0,   // DotTwoProdFast
  kAddTwoSum = 1,    // DotMulTwoSum
  kCompensated = 2,  // DotCompensated
  kKahan = 3,        // DotKahan
  kNaive = 4,        // DotNaive
  kOnlyTwoProd = 5,  // DotTwoProdAdd
  kVariants = 6
};
const char* VariantName(size_t variant) {
switch (variant) {
case kAddTwoProd:
return "add2prod";
case kAddTwoSum:
return "add2sum";
case kCompensated:
return "comp";
case kKahan:
return "kahan";
case kNaive:
return "naive";
case kOnlyTwoProd:
return "only2prod";
default:
HWY_ABORT("Unknown variant %zu", variant);
return "?";
}
}
template <class D, typename WeightT, typename VecT>
float CallDot(D d, size_t variant, const PackedSpan<const WeightT>& w,
size_t w_ofs, const VecT* HWY_RESTRICT v, size_t num) {
switch (variant) {
case kAddTwoProd:
return DotTwoProdFast(d, w, 0, v, num);
case kAddTwoSum:
return DotMulTwoSum(d, w, 0, v, num);
case kCompensated:
return DotCompensated(d, w, 0, v, num);
case kKahan:
return DotKahan(d, w, 0, v, num);
case kNaive:
return DotNaive(d, w, 0, v, num);
case kOnlyTwoProd:
return DotTwoProdAdd(d, w, 0, v, num);
default:
HWY_ABORT("Unknown variant %zu", variant);
return 0.0f;
}
}
// Returns sum_i w[i] * v[i] accurate to 1.5 ulp, assuming `num` < 2^(52-23),
// no overflow, and round to nearest. See "Accurate and efficient floating
// point summation". `buf` is caller-provided scratch space for `num` doubles;
// on return it holds the products sorted by decreasing magnitude.
// Much too slow to be useful. Kept separate from the above kernels because it
// is used to compute their error.
template <typename WeightT, typename VecT>
float ExactDot(const WeightT* HWY_RESTRICT w, const VecT* HWY_RESTRICT v,
               size_t num, double* HWY_RESTRICT buf) {
  PROFILER_FUNC;
  // For float (or narrower) inputs, each product is exact in double.
  for (size_t i = 0; i < num; ++i) {
    const double wi = hwy::ConvertScalarTo<double>(w[i]);
    const double vi = hwy::ConvertScalarTo<double>(v[i]);
    buf[i] = wi * vi;
  }

  // Accumulate the largest magnitudes first (VQSort does not support this
  // custom ordering).
  const auto by_decreasing_magnitude = [](double a, double b) {
    return std::abs(a) > std::abs(b);
  };
  double* const end = buf + num;
  std::sort(buf, end, by_decreasing_magnitude);

  double total = 0.0;
  for (const double* p = buf; p != end; ++p) {
    total += *p;
  }
  return static_cast<float>(total);
}
//------------------------------------------------------------------------------
// Collects accuracy (and optionally timing) statistics for every Dot variant
// and checks them against expected bounds. TestAllDot keeps one instance per
// worker thread and merges them with Assimilate.
class DotStats {
  // Returns max(|a/b|, |b/a|), the factor by which `a` and `b` differ, or 1 if
  // either is zero.
  static float Ratio(float a, float b) {
    // If 0, we would return infinity, which messes up the statistics.
    if (a == 0.0f || b == 0.0f) return 1.0f;
    // Absolute value because a sign change and 4x difference would
    // otherwise return the smaller ratio 0.25.
    return HWY_MAX(std::abs(a / b), std::abs(b / a));
  }

 public:
  DotStats() {
    for (size_t i = 0; i < kVariants; ++i) {
      max_muls[i] = 1.0f;
    }
  }

  // Prints `stats` to stderr, labeled with `caption` and the variant's name.
  static void PrintStats(const char* caption, size_t variant,
                         const hwy::Stats& stats) {
    fprintf(stderr, "%s %9s %s\n", caption, VariantName(variant),
            stats.ToString(/*exclude=*/0).c_str());
  }

  // Call once per rep. Records the condition number, per-variant ratio to the
  // exact result (tracked as a running maximum), absolute error and ulp
  // distance. Prints details for reps whose Kahan/Compensated results look
  // suspicious.
  void NotifyRep(size_t num, double cond, float dot_exact,
                 float dots[kVariants]) {
    s_cond.Notify(cond);
    const float mul_tol = cond > 1E8 ? 1.5f : cond > 1E7 ? 1.1f : 1.01f;
    float muls[kVariants];
    float l1s[kVariants];
    uint32_t ulps[kVariants];
    for (size_t i = 0; i < kVariants; ++i) {
      muls[i] = Ratio(dots[i], dot_exact);
      max_muls[i] = HWY_MAX(max_muls[i], muls[i]);
      l1s[i] = std::abs(dots[i] - dot_exact);
      s_l1s[i].Notify(l1s[i]);
      ulps[i] = hwy::detail::ComputeUlpDelta(dots[i], dot_exact);
      s_ulps[i].Notify(ulps[i]);
    }
    if (muls[kKahan] > mul_tol || l1s[kKahan] > 0.1f ||
        muls[kNaive] + 1E-3f < muls[kKahan] || ulps[kCompensated] > 10) {
      fprintf(stderr, "num %2zu cond %.1E exact %.8f\n", num, cond, dot_exact);
      for (size_t i = 0; i < kVariants; ++i) {
        fprintf(stderr, " %9s dot %11.8f mul %.8f\n", VariantName(i), dots[i],
                muls[i]);
      }
    }
  }

  // Adds each variant's maximum ratio seen so far (updated by NotifyRep) to
  // s_muls. NOTE(review): this was documented as "call after all reps", but
  // TestAllDot calls it after every rep, so s_muls holds running maxima. The
  // CheckMuls bounds presumably reflect that usage.
  void NotifyRatios() {
    for (size_t i = 0; i < kVariants; ++i) {
      s_muls[i].Notify(max_muls[i]);
    }
  }

  // Records one timing sample per variant.
  void NotifyTimes(double times[kVariants]) {
    for (size_t i = 0; i < kVariants; ++i) {
      s_times[i].Notify(times[i]);
    }
  }

  // Merges the statistics of `other` into this instance. `max_muls` is not
  // merged; only the s_* statistics are.
  void Assimilate(const DotStats& other) {
    s_cond.Assimilate(other.s_cond);
    for (size_t i = 0; i < kVariants; ++i) {
      s_muls[i].Assimilate(other.s_muls[i]);
      s_l1s[i].Assimilate(other.s_l1s[i]);
      s_ulps[i].Assimilate(other.s_ulps[i]);
      s_times[i].Assimilate(other.s_times[i]);
    }
  }

  void Print() const {
    PrintStats("cond", 0, s_cond);
    for (size_t variant = 0; variant < kVariants; ++variant) {
      PrintStats("mul", variant, s_muls[variant]);
    }
    for (size_t variant = 0; variant < kVariants; ++variant) {
      PrintStats(" l1", variant, s_l1s[variant]);
    }
    for (size_t variant = 0; variant < kVariants; ++variant) {
      PrintStats("ulp", variant, s_ulps[variant]);
    }
    if (s_times[0].Count()) {
      for (size_t variant = 0; variant < kVariants; ++variant) {
        PrintStats("time", variant, s_times[variant]);
      }
    }
  }

  void Check() const {
    CheckMuls();
    CheckL1();
    CheckUlps();
    // We do not check times because they can be noisy/nonportable, but
    // `kAddTwoProd` is only about 10% slower than `kKahan`, and about 1.5 times
    // as fast as `kCompensated`.
  }

 private:
  // Factor by which the approximate result is off; larger is worse.
  void CheckMuls() const {
    // Compensated is very accurate.
    HWY_ASSERT(s_muls[kCompensated].Min() <= 1.0f + 2E-6f);
    HWY_ASSERT(s_muls[kCompensated].Max() <= 1.0f + 2E-5f);
    // Naive and OnlyTwoProd are considerably worse. >10x is for narrower
    // vectors, compared to AVX-512. GeometricMean overflows, must use Mean.
    HWY_ASSERT(gcpp::IsInside(1.01, 16.0, s_muls[kNaive].Mean()));
    HWY_ASSERT(gcpp::IsInside(1.01, 13.0, s_muls[kOnlyTwoProd].Mean()));
    // Kahan (FastTwoSum) is decent:
    HWY_ASSERT(gcpp::IsInside(1.001, 4.1, s_muls[kKahan].Mean()));
    HWY_ASSERT(gcpp::IsInside(1.001f, 14.1f, s_muls[kKahan].Max()));
    HWY_ASSERT(gcpp::IsInside(1.0, 1.6, s_muls[kKahan].GeometricMean()));
    // But can be considerably improved via TwoProducts:
    HWY_ASSERT(gcpp::IsInside(1.0005, 1.5, s_muls[kAddTwoProd].Mean()));
    HWY_ASSERT(gcpp::IsInside(1.001f, 2.3f, s_muls[kAddTwoProd].Max()));
    HWY_ASSERT(gcpp::IsInside(1.0, 1.2, s_muls[kAddTwoProd].GeometricMean()));
    // Updating Kahan's FastTwoSums to TwoSums is not quite as helpful.
    HWY_ASSERT(gcpp::IsInside(1.0005, 2.2, s_muls[kAddTwoSum].Mean()));
    // This bound belongs to kAddTwoSum. kAddTwoProd's geometric mean is
    // already checked above with the tighter bound 1.2.
    HWY_ASSERT(gcpp::IsInside(1.0, 1.3, s_muls[kAddTwoSum].GeometricMean()));
  }

  // Absolute error; larger is worse.
  void CheckL1() const {
    // Compensated is very accurate.
    HWY_ASSERT(s_l1s[kCompensated].Min() == 0.0f);
    HWY_ASSERT(s_l1s[kCompensated].Max() <= 3E-7f);
    // Naive and OnlyTwoProd are considerably higher, but not huge.
    HWY_ASSERT(gcpp::IsInside(1E-3, 2E-2, s_l1s[kNaive].Mean()));
    HWY_ASSERT(gcpp::IsInside(1E-3, 2E-2, s_l1s[kOnlyTwoProd].Mean()));
    // Kahan (FastTwoSum) is decent:
    HWY_ASSERT(gcpp::IsInside(4.5E-4, 1E-3, s_l1s[kKahan].Mean()));
    HWY_ASSERT(gcpp::IsInside(1.1E-3f, 3.2E-3f, s_l1s[kKahan].Max()));
    // But can be nearly halved via TwoProducts:
    HWY_ASSERT(gcpp::IsInside(2.5E-4, 8E-4, s_l1s[kAddTwoProd].Mean()));
    HWY_ASSERT(gcpp::IsInside(4E-4f, 2.0E-3f, s_l1s[kAddTwoProd].Max()));
    // Updating Kahan's FastTwoSums to TwoSums does help a bit.
    HWY_ASSERT(gcpp::IsInside(1.5E-4, 5.2E-4, s_l1s[kAddTwoSum].Mean()));
  }

  // Units in the last place; larger is worse.
  void CheckUlps() const {
    HWY_ASSERT(s_ulps[kCompensated].Max() <= 250.0f);
    HWY_ASSERT(s_ulps[kNaive].Max() <= 4E9f);
    HWY_ASSERT(s_ulps[kOnlyTwoProd].Max() <= 3E9f);
    HWY_ASSERT(s_ulps[kKahan].Max() <= 4E7f);
    HWY_ASSERT(s_ulps[kAddTwoProd].Max() <= 1E7f);
    HWY_ASSERT(s_ulps[kAddTwoSum].Max() <= 2.5E7f);
  }

  hwy::Stats s_cond;
  // Relative error
  float max_muls[kVariants];
  hwy::Stats s_muls[kVariants];
  hwy::Stats s_l1s[kVariants];   // Absolute error
  hwy::Stats s_ulps[kVariants];  // Only relevant for small cond
  hwy::Stats s_times[kVariants];
};
// Returns a random float in [-1, 1) from a single draw of `rng`. The low
// mantissa bits of the draw are combined with the exponent of 1.0f to yield a
// value in [1, 2). That value is then mapped affinely onto [-1, 1), giving one
// of 2^23 evenly spaced values.
float RandomFloat(std::mt19937& rng) {
  const uint32_t one_bits = hwy::BitCastScalar<uint32_t>(1.0f);
  const uint32_t random_mantissa = rng() & hwy::MantissaMask<float>();
  const float in_one_two = hwy::BitCastScalar<float>(one_bits | random_mantissa);
  // Exponent is 2^0, so only the mantissa varies.
  HWY_DASSERT(1.0f <= in_one_two && in_one_two < 2.0f);
  // [1, 2) -> [0, 2) -> [-1, 1).
  const float result = 2.0f * (in_one_two - 1.0f) - 1.0f;
  HWY_DASSERT(-1.0f <= result && result < 1.0f);
  return result;
}
// Fills `raw[0, num)` with random values in [-64, 64) and compresses them into
// `packed`. If `Packed` is a compressed type, the values are rescaled first via
// ScaleWeights. `raw` is then overwritten with the decompressed values, so that
// the test measures only the error from the Dot algorithms, not the
// compression. NOTE(review): DecompressAndZeroPad presumably writes whole
// vectors, so `raw` should be padded to a multiple of the vector length, as
// the caller does.
template <typename Packed>
void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
                                   std::mt19937& rng,
                                   const PackedSpan<Packed>& packed,
                                   CompressWorkingSet& work) {
  std::uniform_int_distribution<int> e_dist(0, 6);
  for (size_t i = 0; i < num; ++i) {
    // Both draws advance `rng`. Within a single expression, their evaluation
    // order would be unspecified and could differ between compilers, so they
    // are sequenced explicitly: mantissa first, then exponent.
    const float mantissa = RandomFloat(rng);
    const int e = e_dist(rng);
    raw[i] = mantissa * (1 << e);
  }
  if (IsCompressed<Packed>()) {
    // Don't care about the original range.
    (void)ScaleWeights(raw, num);
  }
  hwy::ThreadPool pool(0);  // num is too small for parallelization
  const size_t packed_ofs = 0;
  Compress(raw, num, work, packed, packed_ofs, pool);
  const hn::ScalableTag<float> df;
  DecompressAndZeroPad(df, MakeConst(packed), packed_ofs, raw, num);
}
// Returns the actual condition number. Based on Algorithm 6.1 from "Accurate
// Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf.
template <typename WeightT, typename VecT>
double GenerateIllConditionedInputs(const size_t num, WeightT* w,
VecT* HWY_RESTRICT v, std::mt19937& rng) {
PROFILER_FUNC;
const size_t half = HWY_MAX(1, num / 2); // generate at least one random
HWY_DASSERT(half != 0);
const hn::ScalableTag<float> df;
const PackedSpan<WeightT> w_span(w, num);
// Regardless of WeightT and VecT, we will accumulate into float. Multiplying
// two maximal inputs and accumulating `num` times is enough for some loss of
// precision and condition numbers between 1E6-1E9, which is what we see for
// Attention Dot and `RMSNormMul`.
const int max_exp = 5;
std::uniform_int_distribution<int> e_dist(0, max_exp);
// First half: random exponents and mantissas
for (size_t i = 0; i < half; ++i) {
// Ensure the min and max exponents are used.
const int e = i == 0 ? 0 : i == 1 ? max_exp : e_dist(rng);
w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e));
v[i] = hwy::ConvertScalarTo<VecT>(RandomFloat(rng) * (1 << e));
}
const float a_exp_step =
num == half ? 0.0f : static_cast<float>(max_exp) / (num - half);
float a_exp = max_exp; // max_exp downto 0
for (size_t i = half; i < num; ++i, a_exp -= a_exp_step) {
const int e = static_cast<int>(a_exp);
HWY_DASSERT(e >= 0);
w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e));
const float r = RandomFloat(rng) * (1 << e);
if (hwy::ConvertScalarTo<float>(w[i]) == 0.0f) {
v[i] = hwy::ConvertScalarTo<VecT>(0.0f);
} else {
// This is called >100K times. DotCompensated is much faster than ExactDot
// and just about as accurate.
const float exact =
DotCompensated(df, MakeConst(w_span), /*w_ofs=*/0, v, i);
v[i] = hwy::ConvertScalarTo<VecT>(
r - exact / hwy::ConvertScalarTo<float>(w[i]));
}
}
// Fisher-Yates shuffle of both a and b simultaneously - std::shuffle only
// shuffles one array, and we want the same permutation for both.
for (size_t i = num - 1; i != 0; --i) {
std::uniform_int_distribution<size_t> dist(0, i);
const size_t j = dist(rng);
std::swap(w[i], w[j]);
std::swap(v[i], v[j]);
}
return ConditionNumber(w, v, num);
}
// Runs all Dot algorithms for all short lengths and all Packed/raw types
// on well-conditioned inputs, and ensures the results are close to exact.
// `Packed` is the storage type of the weights `w`; `T` is the type of the
// second operand `v`. Lengths 1..5*N are tested, where N is the lane count of
// `d`, with hn::AdjustedReps(20) random reps per length.
template <typename Packed>
struct TestShortDotsT {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t N = hn::Lanes(d);
const hn::ScalableTag<float> df; // for CallDot
CompressWorkingSet work;
std::mt19937 rng;
rng.seed(12345);
// Absolute error |dot - exact| of each variant, over all lengths and reps.
hwy::Stats s_l1[kVariants];
for (size_t num = 1; num <= 5 * N; ++num) {
// GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`,
// hence they require padding to one vector.
const size_t padded_num = hwy::RoundUpTo(num, N);
const size_t packed_num = CompressedArrayElements<Packed>(num);
RowVectorBatch<float> raw_w(1, padded_num);
RowVectorBatch<float> raw_v(1, padded_num);
RowVectorBatch<Packed> weights(1, packed_num);
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
RowVectorBatch<T> vectors(1, num);
const PackedSpan<T> v(vectors.Batch(0), num);
// Scratch space for ExactDot.
RowVectorBatch<double> bufs(1, num);
double* HWY_RESTRICT buf = bufs.Batch(0);
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
GenerateWellConditionedInputs(num, raw_w.All(), rng, w, work);
GenerateWellConditionedInputs(num, raw_v.All(), rng, v, work);
// Reference from the decompressed values, so compression error is excluded.
const float dot_exact = ExactDot(raw_w.All(), raw_v.All(), num, buf);
float dots[kVariants];
for (size_t variant = 0; variant < kVariants; ++variant) {
dots[variant] = CallDot(df, variant, MakeConst(w), 0, v.ptr, num);
const float l1 = hwy::ScalarAbs(dots[variant] - dot_exact);
s_l1[variant].Notify(l1);
}
}
}
// Avoid extra output for partial vectors.
if (hn::detail::IsFull(d)) {
for (size_t variant = 0; variant < kVariants; ++variant) {
DotStats::PrintStats("l1", variant, s_l1[variant]);
}
}
// Verify the dot products are plausible. This is only to verify
// correctness, not to differentiate between the variants.
double expected_l1[kVariants];
// Tolerances are much lower for compressed inputs: the more limited set of
// values seems to reduce roundoff.
constexpr bool kCompressed = IsCompressed<Packed>();
expected_l1[kAddTwoProd] = kCompressed ? 1.5E-6 : 5E-5;
expected_l1[kAddTwoSum] = kCompressed ? 1.5E-6 : 6E-5;
expected_l1[kCompensated] = kCompressed ? 1.5E-6 : 4E-5;
expected_l1[kKahan] = kCompressed ? 1.5E-6 : 7E-5;
expected_l1[kNaive] = kCompressed ? 4E-6 : 1.5E-4;
expected_l1[kOnlyTwoProd] = kCompressed ? 1.5E-6 : 6E-5;
// Every variant must stay within a shared absolute bound (1.5E-3), and its
// mean error must not exceed the per-variant tolerance above.
for (size_t variant = 0; variant < kVariants; ++variant) {
HWY_ASSERT(s_l1[variant].Min() >= 0.0f);
HWY_ASSERT(s_l1[variant].Max() <= 1.5E-3f);
if (s_l1[variant].Mean() > expected_l1[variant]) {
HWY_ABORT("%s -> %s: %s mean l1 %.5E > %.5E\n", TypeName<Packed>(),
TypeName<T>(), VariantName(variant), s_l1[variant].Mean(),
expected_l1[variant]);
}
}
}
};
// Test entry point: runs TestShortDotsT for each Packed/raw type combination
// enumerated by ForeachPackedAndRawType.
void TestAllShortDots() {
  ForeachPackedAndRawType<TestShortDotsT>();
}
// Returns the mean of the sorted samples with indices in [num / 4, num / 2),
// which excludes slow outliers; we might not have enough samples for a
// reliable mode. Sorts `seconds[0, num)` in place. If that index range is
// empty (num == 1), the smallest sample is used instead. Returns 0.0 if `num`
// is 0.
double TrimmedMean(double* seconds, size_t num) {
  if (num == 0) return 0.0;
  std::sort(seconds, seconds + num);
  const size_t begin = num / 4;
  // Guarantee at least one sample: otherwise num == 1 (for example when
  // AdjustedReps reduces the timing reps to 1) would return 0 / 0 = NaN.
  // For num >= 2, this still equals num / 2.
  const size_t end = std::max(num / 2, begin + 1);
  double sum = 0.0;
  for (size_t i = begin; i < end; ++i) {
    sum += seconds[i];
  }
  return sum / static_cast<double>(end - begin);
}
// Tests W=float, V=float for one large size and many reps on ill-conditioned
// inputs. Also includes benchmarking. Reps are distributed over a thread pool;
// each worker thread has its own RNG, buffers and DotStats, which are merged
// at the end and checked.
void TestAllDot() {
  // Skip EMU128 and old x86, include SSE4 because it tests the non-FMA path.
  if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 ||
      HWY_TARGET == HWY_SSE2) {
    return;
  }
  const hn::ScalableTag<float> df;
  constexpr size_t kMaxWorkers = 15;
  std::mt19937 rngs[kMaxWorkers];
  for (size_t i = 0; i < kMaxWorkers; ++i) {
    rngs[i].seed(12345 + 65537 * i);
  }
  constexpr size_t kReps = hn::AdjustedReps(40);
  const size_t num = 24 * 1024;
  PerClusterPools pools(/*max_clusters=*/1, kMaxWorkers - 1, /*pin=*/1);
  RowVectorBatch<float> a(kMaxWorkers, num);
  RowVectorBatch<float> b(kMaxWorkers, num);
  RowVectorBatch<double> bufs(kMaxWorkers, num);
  std::array<DotStats, kMaxWorkers> all_stats;
  pools.Inner(0).Run(0, kReps, [&](const uint32_t /*rep*/, size_t thread) {
    float* HWY_RESTRICT pa = a.Batch(thread);
    float* HWY_RESTRICT pb = b.Batch(thread);
    double* HWY_RESTRICT buf = bufs.Batch(thread);
    const PackedSpan<const float> a_span(pa, num);
    DotStats& stats = all_stats[thread];
    const double cond = GenerateIllConditionedInputs(num, pa, pb, rngs[thread]);
    const float dot_exact = ExactDot(pa, pb, num, buf);
    float dots[kVariants] = {};
    double times[kVariants] = {};
    for (size_t variant = 0; variant < kVariants; ++variant) {
      constexpr size_t kTimeReps = hn::AdjustedReps(10);
      std::array<double, kTimeReps> elapsed;
      // size_t index: matches the type of kTimeReps (no signed/unsigned mix).
      for (size_t time_rep = 0; time_rep < kTimeReps; ++time_rep) {
        const double start = hwy::platform::Now();
        dots[variant] += CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
        hwy::PreventElision(*pa);
        elapsed[time_rep] = hwy::platform::Now() - start;
      }
      // Every time rep computes the same dot product; average them.
      dots[variant] /= kTimeReps;
      times[variant] = TrimmedMean(elapsed.data(), kTimeReps);
    }
    stats.NotifyTimes(times);
    stats.NotifyRep(num, cond, dot_exact, dots);
    stats.NotifyRatios();
  });
  DotStats& stats = all_stats[0];
  for (size_t i = 1; i < kMaxWorkers; ++i) {
    stats.Assimilate(all_stats[i]);
  }
  // Print only for the first target that runs this test.
  static bool once = true;
  if (once) {
    once = false;
    stats.Print();
  }
  stats.Check();
  PROFILER_PRINT_RESULTS();
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
// Compiled once rather than once per SIMD target (Highway's HWY_ONCE pass).
// Registers TestAllShortDots and TestAllDot under the DotTest suite via
// HWY_EXPORT_AND_TEST_P.
namespace gcpp {
HWY_BEFORE_TEST(DotTest);
HWY_EXPORT_AND_TEST_P(DotTest, TestAllShortDots);
HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot);
HWY_AFTER_TEST();
} // namespace gcpp
#endif