mirror of https://github.com/google/gemma.cpp.git
Fix SFP/NUQ for bf16 rounding in Highway
SFP: Avoid rounding twice, and more robust TestDot. NUQ: also more robust SNR, minor touchups to header. PiperOrigin-RevId: 618030096
This commit is contained in:
parent a135bc1e47
commit 24add61dd9
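For context on the "avoid rounding twice" part of the commit message: Enc4F previously demoted f32 to bf16 with OrderedDemote2To, which rounds to nearest, and the SFP encoder (EncBytes) then rounds again when it narrows bf16 to the 8-bit SFP form. The new Truncate2To instead chops off the low 16 mantissa bits, so only the final encoding step rounds. Below is a minimal scalar sketch of the difference at the bf16 level; the helper names are illustrative and not gemma.cpp APIs, and it assumes IEEE-754 binary32.

// Illustrative scalar helpers, not gemma.cpp APIs.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Keep the upper 16 bits of the f32 pattern: truncation to bf16, which is
// what the new Truncate2To does per lane.
static uint16_t TruncBf16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// Round-to-nearest-even f32 -> bf16, which is what OrderedDemote2To did.
static uint16_t RoundBf16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFF + ((bits >> 16) & 1);  // rounding bias
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  const float x = 1.005859375f;  // bit pattern 0x3F80C000
  // Truncation keeps 0x3F80 (1.0); rounding yields 0x3F81 (1.0078125).
  std::printf("trunc=%04X round=%04X\n",
              static_cast<unsigned>(TruncBf16Bits(x)),
              static_cast<unsigned>(RoundBf16Bits(x)));
  // EncBytes rounds once more when it narrows bf16 to 8 bits, so rounding
  // here as well could move a value across an SFP rounding boundary twice;
  // truncating leaves that single decision to the final step.
  return 0;
}
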
@@ -40,6 +40,7 @@ set(SOURCES
  compression/nuq-inl.h
  compression/sfp.h
  compression/sfp-inl.h
  compression/test_util.h
  util/app.h
  util/args.h
)
@@ -24,20 +24,36 @@ cc_library(
    ],
)

# Deprecated because it is also implemented in Highway; will be removed once
# that Highway version is sufficiently widespread.
cc_library(
    name = "stats",
    srcs = [
        "stats.cc",
    ],
    hdrs = [
        "distortion.h",
        "stats.h",
    ],
    srcs = ["stats.cc"],
    hdrs = ["stats.h"],
    deps = [
        "@hwy//:hwy",
    ],
)

cc_library(
    name = "distortion",
    hdrs = ["distortion.h"],
    deps = [
        "@hwy//:hwy",
    ],
)

cc_library(
    name = "test_util",
    hdrs = ["test_util.h"],
    deps = [
        ":distortion",
        ":stats",
        "@hwy//:hwy",
        "@hwy//:hwy_test_util",
    ],
)

cc_library(
    name = "sfp",
    hdrs = [
@@ -62,12 +78,11 @@ cc_test(
    tags = ["hwy_ops_test"],
    deps = [
        ":sfp",
        ":stats",
        ":test_util",
        "@googletest//:gtest_main",
        "@hwy//:hwy",
        "@hwy//:hwy_test_util",
        "@hwy//:nanobenchmark",
        "@hwy//:thread_pool",
    ],
)
@@ -98,7 +113,7 @@ cc_test(
    deps = [
        ":nuq",
        ":sfp",
        ":stats",
        ":test_util",
        "@googletest//:gtest_main",
        "@hwy//:hwy",
        "@hwy//:hwy_test_util",
@@ -118,6 +133,7 @@ cc_library(
    ],
    deps = [
        ":blob_store",
        ":distortion",
        ":nuq",
        ":sfp",
        ":stats",
@@ -134,6 +150,7 @@ cc_library(
        "analyze.h",
    ],
    deps = [
        ":distortion",
        ":nuq",
        ":sfp",
        ":stats",
@@ -42,6 +42,10 @@
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"

#ifndef HWY_IF_CONSTEXPR
#define HWY_IF_CONSTEXPR if
#endif

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
@@ -124,7 +128,7 @@ class NuqClustering {
  }

 private:
  // Float has enough precision for our relatively small kGroupSize (128).
  // Float has enough precision for our relatively small kGroupSize (256).
  float cumsum_[kGroupSize + 1];
  float cumsum2_[kGroupSize + 1];
  float inv_len_[kGroupSize + 1];
@@ -168,8 +172,8 @@ class NuqClustering {
  // `centers`; prior centers are zero-initialized.
  //
  // O(kClusters * kGroupSize * kGroupSize), but the constant factors are so low
  // that this is about 10 times as fast as the O(kClusters * kGroupSize) SMAWK
  // as implemented in FAISS, for our kGroupSize <= 128.
  // that this is about 5 times as fast as the O(kClusters * kGroupSize) SMAWK
  // as implemented in FAISS, for our kGroupSize of 256.
  template <class DF>
  static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x,
                                            ClusterBuf& buf,
@@ -228,7 +232,7 @@ class NuqClustering {
      // Center = mean, O(1) thanks to cumulative sums.
      const float sum = cc.SumOfSorted(start, last);
      const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
      HWY_DASSERT(0 < size && size <= kGroupSize);
      HWY_DASSERT(0 < size && size <= static_cast<int>(kGroupSize));
      centers[k] = sum / static_cast<float>(size);

      // We know the range inside sorted_and_i[]; translate to original indices,
@@ -427,7 +431,7 @@ class NuqCodec {
    // instead of TableLookupBytes, which requires extra interleaving of lo/hi.
    HWY_DASSERT(hn::Lanes(du) >= 8);

    if (NumTables(du) == 2) {
    HWY_IF_CONSTEXPR(NumTables(du) == 2) {
      // Reduce cap for second half to avoid loading past the end of the table.
      const hn::CappedTag<hwy::bfloat16_t, kClusters / 2> d_table2;
      *tbl1 = hn::ResizeBitCast(du, hn::LoadU(d_table2, table + kClusters / 2));
@@ -449,11 +453,12 @@ class NuqCodec {
    const auto indices0 = hn::IndicesFromVec(du, idx0);
    const auto indices1 = hn::IndicesFromVec(du, idx1);

    if (NumTables(du) == 1) {
    HWY_IF_CONSTEXPR(NumTables(du) == 1) {
      (void)tbl1;
      c0 = hn::TableLookupLanes(tbl0, indices0);
      c1 = hn::TableLookupLanes(tbl0, indices1);
    } else {
    }
    HWY_IF_CONSTEXPR(NumTables(du) == 2) {  // `else` is poorly formatted.
      c0 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0);
      c1 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices1);
    }
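A note on the two HWY_IF_CONSTEXPR blocks above, hedged because it relies on how the macro is defined rather than on text from this commit: recent Highway versions provide HWY_IF_CONSTEXPR (as `if constexpr` where available), and the fallback added near the top of nuq-inl.h expands it to a plain `if` on older versions. Under that fallback both guarded blocks are compiled for every target and the choice happens at run time, which is why each path stays self-contained, including the `(void)tbl1;` that silences the unused argument in the single-table path. A tiny sketch of the pattern:

// Illustrative only; mirrors the fallback from this commit.
#ifndef HWY_IF_CONSTEXPR
#define HWY_IF_CONSTEXPR if  // older Highway: plain runtime branch
#endif

template <size_t kNumTables>
size_t SelectedTables() {
  size_t result = 0;
  // With `if constexpr`, the untaken block is discarded at compile time;
  // with the plain-`if` fallback, both blocks compile and one of them runs.
  HWY_IF_CONSTEXPR(kNumTables == 1) { result = 1; }
  HWY_IF_CONSTEXPR(kNumTables == 2) { result = 2; }
  return result;
}
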
@@ -521,8 +526,8 @@ class NuqCodec {
  // Decodes `num` values from the stream `in`, starting at the offset `in_ofs`
  // (in units of values), to bf16 in `out`. `in_capacity`, `in_ofs` and `num`
  // must all be multiples of `kGroupSize`.
  template <class DF, HWY_IF_BF16_D(DF)>
  static HWY_INLINE void Dec(DF dbf, const size_t in_capacity,
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void Dec(DBF dbf, const size_t in_capacity,
                             const NuqStream* const in, const size_t in_ofs,
                             hwy::bfloat16_t* const out, const size_t num) {
    const hn::RebindToUnsigned<decltype(dbf)> d16;
@@ -27,6 +27,7 @@

#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/timer.h"

// clang-format off
#undef HWY_TARGET_INCLUDE
@@ -35,15 +36,14 @@
#include "hwy/foreach_target.h"  // IWYU pragma: keep
// Other headers that include Highway must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/nuq.h"
// copybara:import_next_line:gemma_cpp
#include "compression/test_util.h"
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h"
#include "hwy/timer.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
@@ -181,12 +181,14 @@ struct TestNormal {
    auto in = hwy::AllocateAligned<float>(kGroupSize);
    HWY_ASSERT(in);

    std::mt19937 rng(123);
    std::normal_distribution<float> dist{0.001f, 0.3f};
    hwy::RandomState rng;
    Stats in_stats;
    for (size_t i = 0; i < kGroupSize; ++i) {
      in[i] = dist(rng);
      const double r = RandomGaussian(rng);
      in_stats.Notify(r);
      in[i] = hwy::ConvertScalarTo<T>(r);
    }
    std::shuffle(in.get(), in.get() + kGroupSize, rng);
    VerifyGaussian(in_stats);

    ClusterBuf buf;
    float centers[kClusters];
@@ -212,9 +214,9 @@ struct TestNormal {
    const float snr = stats.GeomeanValueDivL1();
    fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr,
            stats.MaxIndex(), stats.MaxL1());
    static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
    const float expected_pnorm = kGroupSize == 128 ? 3E-2f : 3.4E-2f;
    const float expected_snr = kGroupSize == 128 ? 17.4f : 13.1f;
    static_assert(kGroupSize == 256, "Update expected");
    const float expected_pnorm = 3.68E-2f;
    const float expected_snr = 12.7f;
    HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm);
    HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
  }
@@ -345,21 +347,27 @@ struct TestDot {
    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
    HWY_ASSERT(in && dec && vec && nuq);

    std::mt19937 rng(123);
    std::normal_distribution<float> dist{0.001f, 0.3f};
    // Generate inputs and verify their distribution.
    hwy::RandomState rng;
    Stats in_stats;
    for (size_t i = 0; i < num; ++i) {
      in[i] = dist(rng);
      vec[i] = hwy::ConvertScalarTo<T>(dist(rng));
      const float r = static_cast<float>(RandomGaussian(rng));
      in_stats.Notify(r);
      in[i] = r;
    }
    // This changes the correlation between in and vec, which considerably
    // affects the error of the result.
    std::shuffle(in.get(), in.get() + num, rng);
    for (size_t i = 0; i < num; ++i) {
      const float r = static_cast<float>(RandomGaussian(rng));
      in_stats.Notify(r);
      vec[i] = hwy::ConvertScalarTo<T>(r);
    }
    VerifyGaussian(in_stats);

    ClusterBuf buf;
    const size_t unused_clusters =
        NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
    HWY_ASSERT(unused_clusters == 0);

    // Compute dot product without decompression.
    double actual = 0.0;
    double elapsed = hwy::HighestValue<double>();
    for (size_t rep = 0; rep < 20; ++rep) {
@@ -380,24 +388,39 @@ struct TestDot {
    fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T),
            num * sizeof(in[0]) * 1E-6 / elapsed);

    double expected = 0.0;   // using original input
    double expected2 = 0.0;  // using decoded NUQ
    // Exact and decompressed dot products for comparison.
    double exact = 0.0;     // using original input
    double expected = 0.0;  // using decoded NUQ
    DistortionStats dec_stats;
    Stats ratios;
    for (size_t i = 0; i < num; ++i) {
      expected += in[i] * hwy::ConvertScalarTo<double>(vec[i]);
      expected2 += dec[i] * hwy::ConvertScalarTo<double>(vec[i]);
      dec_stats.Notify(in[i], dec[i]);
      const float v1 = hwy::ConvertScalarTo<float>(vec[i]);
      exact += in[i] * v1;
      expected += dec[i] * v1;
      if (expected != 0.0f) {
        ratios.Notify(exact / expected);
      }
    const double l1 = hwy::ScalarAbs(expected - actual);
    const double snr = 1.0 + hwy::ScalarAbs(expected) / l1;
    fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n",
            expected, expected2, actual, l1, snr);
    HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4);
    static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected");
    const double expected_l1 = kGroupSize == 128 ? 7.3E-2 : 4.34E-2;
    const double expected_snr = kGroupSize == 128 ? 9.7f
                                : sizeof(T) == 2 ? 14.5f
                                                 : 14.9f;
    HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1);
    HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
    }
    const double dec_snr = dec_stats.GeomeanValueDivL1();
    const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
    // exact and actual fluctuate due to the combination of NUQ imprecision,
    // and whether vec[i] is negative or positive, so this is quite loose.
    const float final_ratio = HWY_MIN(exact / actual, actual / exact);
    fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
    fprintf(stderr,
            "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
            "dot_snr %.2f\n",
            exact, expected, actual, final_ratio, dec_snr, dot_snr);
    // Final values are not too far apart.
    HWY_ASSERT(0.88f <= final_ratio && final_ratio <= 1.0f);
    // Decompressed and uncompressed dot should match exactly.
    HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
    // dec[] is close to in[], but we already check that in TestStream.
    HWY_ASSERT(dec_snr >= 13.0);
    // Geomean of ratios for each i is an approximation of the actual SNR.
    HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 17.0 : 14.0));
    static_assert(kGroupSize == 256, "Update expected*");
  }
};
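A reading of the two SNR figures checked above (inferred from the code and names, not from documentation): dec_stats compares each original value with its decoded value, so dec_snr is, per its name, the geometric mean of |value| / |value - decoded value|. dot_snr instead tracks how far the running NUQ dot product drifts from the running exact one:

  dot_snr = 1 / |1 - geomean_i(exact_i / expected_i)|

For example, a geometric mean of 0.95 gives dot_snr = 1 / 0.05 = 20, so the bf16 threshold of 17 asks the running ratio to stay within roughly 6% of 1 on average, and the f32 threshold of 14 within roughly 7%. The SFP test below uses the same construction with much tighter thresholds (70 and 1000), reflecting SFP's smaller quantization error.
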
@@ -449,6 +449,18 @@ class SfpCodec {
    return Enc2U(d16, w0, w1);
  }

  // Truncates two f32 to bf16, in lane order, without rounding (see Enc4F).
  template <class DBF, class DF = hn::RepartitionToWide<DBF>>
  static HWY_INLINE hn::Vec<DBF> Truncate2To(DBF dbf, hn::Vec<DF> f0,
                                             hn::Vec<DF> f1) {
    const hn::RebindToUnsigned<DBF> d16;
    using V16 = hn::Vec<decltype(d16)>;
    const V16 u0 = BitCast(d16, f0);
    const V16 u1 = BitCast(d16, f1);
    return BitCast(DBF(), HWY_IS_LITTLE_ENDIAN ? ConcatOdd(d16, u1, u0)
                                               : ConcatEven(d16, u1, u0));
  }

  template <class DF, HWY_IF_F32_D(DF),
            class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
  static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) {
@@ -462,9 +474,10 @@ class SfpCodec {
    const VF f1 = hn::LoadU(df, in + NF * 1);
    const VF f2 = hn::LoadU(df, in + NF * 2);
    const VF f3 = hn::LoadU(df, in + NF * 3);
    // Chop off the lower 16 bits; EncBytes still rounds properly.
    const V16 w0 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f0, f1));
    const V16 w1 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f2, f3));
    // Chop off the lower 16 bits instead of OrderedDemote2To, which rounds to
    // the nearest bf16, because EncBytes will round again.
    const V16 w0 = hn::BitCast(d16, Truncate2To(dbf, f0, f1));
    const V16 w1 = hn::BitCast(d16, Truncate2To(dbf, f2, f3));
    return Enc2U(d16, w0, w1);
  }
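The ConcatOdd/ConcatEven pair in Truncate2To is just lane bookkeeping for the truncation. An illustrative scalar model (not a gemma.cpp API): viewing little-endian f32 storage as pairs of uint16_t, element 2*i is the low half and element 2*i+1 the high half of float i, so collecting the odd elements yields the truncated bf16 of each float, in order, which is what ConcatOdd(d16, u1, u0) produces on little-endian targets.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Scalar picture of Truncate2To: keep the high 16 bits of each f32.
std::vector<uint16_t> TruncateAllToBf16(const std::vector<float>& in) {
  std::vector<uint16_t> halves(in.size() * 2);
  std::memcpy(halves.data(), in.data(), in.size() * sizeof(float));
  std::vector<uint16_t> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[i] = halves[2 * i + 1];  // high half == bf16 truncation (little-endian)
  }
  return out;
}
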
@@ -25,12 +25,11 @@
#include <stdint.h>
#include <stdio.h>

#include <algorithm>
#include <random>
#include <set>

#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/timer.h"

// clang-format off
#undef HWY_TARGET_INCLUDE
@@ -39,13 +38,12 @@
#include "hwy/foreach_target.h"  // IWYU pragma: keep
// Any highway.h must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/sfp-inl.h"
// copybara:import_next_line:gemma_cpp
#include "compression/test_util.h"
#include "hwy/highway.h"
#include "hwy/tests/hwy_gtest.h"
#include "hwy/tests/test_util-inl.h"
#include "hwy/timer.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
@@ -358,25 +356,31 @@ struct TestDot {
  template <typename T, class D>
  HWY_INLINE void operator()(T /*unused*/, D d) {
    const hn::Repartition<float, D> df;
    const size_t num = 384;
    const size_t num = 1024;  // not too many for GeometricMean overflow.
    auto in = hwy::AllocateAligned<T>(num);
    auto dec = hwy::AllocateAligned<T>(num);
    auto vec = hwy::AllocateAligned<T>(num);
    auto sfp = hwy::AllocateAligned<SfpStream>(num);
    HWY_ASSERT(in && dec && vec && sfp);

    std::mt19937 rng(123);
    std::normal_distribution<float> dist{0.001f, 0.3f};
    // Generate inputs and verify their distribution.
    hwy::RandomState rng;
    Stats in_stats;
    for (size_t i = 0; i < num; ++i) {
      in[i] = hwy::ConvertScalarTo<T>(dist(rng));
      vec[i] = hwy::ConvertScalarTo<T>(dist(rng));
      const float r = static_cast<float>(RandomGaussian(rng));
      in_stats.Notify(r);
      in[i] = hwy::ConvertScalarTo<T>(r);
    }
    // This changes the correlation between in and vec, which considerably
    // affects the error of the result.
    std::shuffle(in.get(), in.get() + num, rng);
    for (size_t i = 0; i < num; ++i) {
      const float r = static_cast<float>(RandomGaussian(rng));
      in_stats.Notify(r);
      vec[i] = hwy::ConvertScalarTo<T>(r);
    }
    VerifyGaussian(in_stats);

    SfpCodec::Enc(d, in.get(), num, sfp.get());

    // Compute dot product without decompression.
    double actual = 0.0;
    double elapsed = hwy::HighestValue<double>();
    for (size_t rep = 0; rep < 200; ++rep) {
@@ -393,26 +397,44 @@ struct TestDot {
    }

    SfpCodec::Dec(d, sfp.get(), num, dec.get());
    fprintf(stderr, "Vec %zu Dot %.2f MB/s\n", Lanes(d) * sizeof(T),
            num * sizeof(T) * 1E-6 / elapsed);
    fprintf(stderr, "Vec %zu Dot %zu-bit %.2f MB/s\n", Lanes(d) * sizeof(T),
            sizeof(T) * 8, num * sizeof(T) * 1E-6 / elapsed);

    double expected = 0.0;   // using original input
    double expected2 = 0.0;  // using decoded SFP
    // Exact and decompressed dot products for comparison.
    float exact = 0.0f;     // using original input
    float expected = 0.0f;  // using decoded SFP
    DistortionStats dec_stats;
    Stats ratios;
    for (size_t i = 0; i < num; ++i) {
      expected += hwy::ConvertScalarTo<double>(in[i]) *
                  hwy::ConvertScalarTo<double>(vec[i]);
      expected2 += hwy::ConvertScalarTo<double>(dec[i]) *
                   hwy::ConvertScalarTo<double>(vec[i]);
      const float in1 = hwy::ConvertScalarTo<float>(in[i]);
      const float dec1 = hwy::ConvertScalarTo<float>(dec[i]);
      const float vec1 = hwy::ConvertScalarTo<float>(vec[i]);
      dec_stats.Notify(in1, dec1);

      exact += in1 * vec1;
      expected += dec1 * vec1;
      if (expected != 0.0f) {
        ratios.Notify(exact / expected);
      }
    const double l1 = hwy::ScalarAbs(expected - actual);
    const double snr = 1.0 + hwy::ScalarAbs(expected) / l1;
    fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n",
            expected, expected2, actual, l1, snr);
    HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4);
    const double expected_l1 = sizeof(T) == 2 ? 1.52E-2 : 1.15E-2;
    const double expected_snr = sizeof(T) == 2 ? 80.1f : 104.9f;
    HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1);
    HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr);
    }
    const double dec_snr = dec_stats.GeomeanValueDivL1();
    const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
    // exact and actual fluctuate due to the combination of SFP imprecision,
    // and whether vec[i] is negative or positive, so this is quite loose.
    const float final_ratio = HWY_MIN(exact / actual, actual / exact);
    fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
    fprintf(stderr,
            "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
            "dot_snr %.2f\n",
            exact, expected, actual, final_ratio, dec_snr, dot_snr);
    // Final values are not too far apart.
    HWY_ASSERT(0.87f <= final_ratio && final_ratio <= 1.0f);
    // Decompressed and uncompressed dot should match exactly.
    HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f);
    // dec[] is close to in[], but we already check that in TestEncDec.
    HWY_ASSERT(dec_snr >= 50.0);
    // Geomean of ratios for each i should be very close to one.
    HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 70.0 : 1000.0));
  }
};
@@ -0,0 +1,64 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_

#include <stddef.h>
#include <stdint.h>

#include <cmath>

#include "hwy/base.h"

// IWYU pragma: begin_exports
// copybara:import_next_line:gemma_cpp
#include "compression/distortion.h"
// copybara:import_next_line:gemma_cpp
#include "compression/stats.h"
#include "hwy/tests/test_util.h"  // RandomState
// IWYU pragma: end_exports

namespace gcpp {

// Returns random Gaussian (mean=0, stddev=1/3 similar to expected weights)
// using the central limit theorem. Avoid std::normal_distribution for
// consistent cross-platform output.
HWY_INLINE double RandomGaussian(hwy::RandomState& rng) {
  uint64_t sum = 0;
  constexpr int kReps = 40;
  for (int rep = 0; rep < kReps; ++rep) {
    sum += hwy::Random32(&rng) & 0xFFFFF;
  }
  const double sum_f =
      static_cast<double>(sum) / static_cast<double>(0xFFFFF * kReps);
  HWY_ASSERT(0.0 <= sum_f && sum_f <= 1.0);
  const double plus_minus_1 = 2.0 * sum_f - 1.0;
  HWY_ASSERT(-1.0 <= plus_minus_1 && plus_minus_1 <= 1.0);
  // Normalize by stddev of sum of uniform random scaled to [-1, 1].
  return plus_minus_1 * std::sqrt(kReps / 3.0);
};

HWY_INLINE void VerifyGaussian(Stats& stats) {
  const double stddev = stats.StandardDeviation();
  HWY_ASSERT(-0.01 <= stats.Mean() && stats.Mean() <= 0.01);
  HWY_ASSERT(0.30 <= stddev && stddev <= 0.35);
  HWY_ASSERT(-1.1 <= stats.Min() && stats.Min() <= -0.9);
  HWY_ASSERT(0.9 <= stats.Max() && stats.Max() <= 1.1);
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_
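As a closing sanity check of the scaling in RandomGaussian above (my arithmetic, not part of the commit): each draw `hwy::Random32(&rng) & 0xFFFFF` is uniform on [0, 0xFFFFF], so sum_f is essentially the mean of kReps = 40 i.i.d. uniform variables on [0, 1], and plus_minus_1 = 2 * sum_f - 1 has variance 4 * (1/12) / 40 = 1/120. Multiplying by sqrt(kReps / 3) = sqrt(40 / 3) scales the variance to (40/3) * (1/120) = 1/9, i.e. a standard deviation of 1/3, which is exactly the [0.30, 0.35] band that VerifyGaussian enforces. The min/max checks near +/-1 correspond to extremes of roughly three standard deviations, about what a few hundred to a thousand samples of an approximately Gaussian variable typically produce.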