mirror of https://github.com/google/gemma.cpp.git
Major compression update: arbitrary-length unpack + new Dot
Compression:
* Implement {any packed} x {bf16, f32} 'Load2' and DecompressAndZeroPad
* New compression test for all packed formats, add to GEMMA_TEST_FILES, remove from sfp/nuq_test
* Decompress->DecompressAndZeroPad, use PackedSpan for args with bounds checking
* NUQ: support arbitrary-length enc/dec
* New compression/shared, remove sfp.h and nuq.h
* Move Store2 into Traits and provide Compress2 wrapper
* Remove unused Decompress()-with-pool overload
* Simplify CompressedArrayLen, rename to CompressedArrayElements
* Remove unused DistortionStats b_l1_
Misc:
* Add compensated and Kahan dot, support any length
* Use same Dot function everywhere
* Move exact arithmetic functions into fp_arith
* Use FloatPtr and MatPtr typedefs in tests; less stack usage
* Rename args to packed/raw
* Remove Traits::Name, instead TypeName<T>()
* Move kMaxSFP and kClusters/kGroupSize into Sfp/NuqStream
PiperOrigin-RevId: 672868468
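The "compensated and Kahan dot" item refers to error-compensated summation inside the dot product; together with DecompressAndZeroPad it lets a single Dot implementation handle inputs of any length. Below is a minimal scalar sketch of the idea, with illustrative names and plain float inputs; the actual ops/dot-inl.h code is Highway-vectorized and carries four sum/compensation vector pairs, as seen in DecompressAndCall in the diff below.

#include <cstddef>

// Kahan-compensated dot product: `comp` accumulates the low-order bits that
// would otherwise be lost when adding each product into `sum`.
// Illustrative sketch only; not the repository's actual function.
float CompensatedDot(const float* a, const float* b, size_t num) {
  float sum = 0.0f;
  float comp = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    const float y = a[i] * b[i] - comp;  // apply previous compensation
    const float t = sum + y;             // low-order bits of y may be lost
    comp = (t - sum) - y;                // recover the lost part
    sum = t;
  }
  return sum;
}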
parent 5c0da8c8c3
commit 8c0a8834c1

BUILD.bazel: 15
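For orientation before the diff: the central new primitive is DecompressAndZeroPad, which decodes an arbitrary number of packed values and then zero-pads the destination up to the next vector multiple, so SIMD callers never need a scalar remainder loop. A minimal scalar sketch follows, assuming the packed format is bf16 stored as uint16_t; the real compression/compress-inl.h code is Highway-vectorized and also covers f32, SFP and NUQ.

#include <cstddef>
#include <cstdint>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 float32.
static float DecodeBF16(uint16_t packed) {
  const uint32_t bits = static_cast<uint32_t>(packed) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Sketch of the DecompressAndZeroPad contract (names are illustrative):
// decode `num` values starting at `packed_ofs`, then write zeros up to the
// next multiple of `lanes`. Callers must size `raw` for the padded length.
void DecompressAndZeroPadSketch(const uint16_t* packed, size_t packed_ofs,
                                float* raw, size_t num, size_t lanes) {
  for (size_t i = 0; i < num; ++i) {
    raw[i] = DecodeBF16(packed[packed_ofs + i]);
  }
  const size_t padded = (num + lanes - 1) / lanes * lanes;
  for (size_t i = num; i < padded; ++i) {
    raw[i] = 0.0f;  // zero padding for whole-vector reads
  }
}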
@@ -48,15 +48,6 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
# Avoids circular dependency: fp_arith-inl -> compress-inl -> ops-inl
|
||||
cc_library(
|
||||
name = "fp_arith",
|
||||
textual_hdrs = ["ops/fp_arith-inl.h"],
|
||||
deps = [
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ops",
|
||||
hdrs = [
|
||||
|
|
@@ -64,13 +55,13 @@ cc_library(
|
|||
],
|
||||
textual_hdrs = [
|
||||
"ops/dot-inl.h",
|
||||
"ops/fp_arith-inl.h",
|
||||
"ops/matmul-inl.h",
|
||||
"ops/matvec-inl.h",
|
||||
"ops/ops-inl.h",
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":fp_arith",
|
||||
":threading",
|
||||
"//compression:compress",
|
||||
"//compression:sfp",
|
||||
|
|
@@ -97,14 +88,17 @@ cc_test(
|
|||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":test_util",
|
||||
":threading",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"//compression:test_util",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
"@hwy//:nanobenchmark", #buildcleaner: keep
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:stats",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@@ -468,6 +462,7 @@ cc_library(
|
|||
":ops",
|
||||
":prompt",
|
||||
":weights",
|
||||
"@hwy//:dot",
|
||||
"@hwy//:hwy", # base.h
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
|
|
|
|||
|
|
@@ -44,10 +44,10 @@ set(SOURCES
|
|||
compression/io_win.cc
|
||||
compression/io.cc
|
||||
compression/io.h
|
||||
compression/nuq.h
|
||||
compression/nuq-inl.h
|
||||
compression/sfp-inl.h
|
||||
compression/shared.h
|
||||
compression/test_util-inl.h
|
||||
compression/weights_raw.h
|
||||
backprop/activations.h
|
||||
backprop/backward.cc
|
||||
|
|
|
|||
|
|
@@ -46,6 +46,7 @@
|
|||
// After highway.h
|
||||
#include "ops/matmul-inl.h"
|
||||
#include "ops/ops-inl.h"
|
||||
#include "hwy/contrib/dot/dot-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
|
|||
|
|
@@ -52,11 +52,14 @@ namespace HWY_NAMESPACE {
|
|||
template <typename ArrayT>
|
||||
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
|
||||
const float scaling, float* HWY_RESTRICT output,
|
||||
size_t model_dim) {
|
||||
size_t model_dim, size_t vocab_size) {
|
||||
const hn::ScalableTag<float> df;
|
||||
HWY_ASSERT(!prompt.empty());
|
||||
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
|
||||
int token = prompt[pos];
|
||||
Decompress(weights, token * model_dim, output + pos * model_dim, model_dim);
|
||||
DecompressAndZeroPad(df, MakeSpan(weights.data(), model_dim * vocab_size),
|
||||
token * model_dim, output + pos * model_dim,
|
||||
model_dim);
|
||||
MulByConst(scaling, output + pos * model_dim, model_dim);
|
||||
}
|
||||
}
|
||||
|
|
@@ -245,7 +248,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
|
|||
const size_t num_tokens = prompt.size() - 1;
|
||||
|
||||
InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling,
|
||||
forward.layers[0].input.data(), kModelDim);
|
||||
forward.layers[0].input.data(), kModelDim, kVocabSize);
|
||||
|
||||
for (size_t layer = 0; layer < kLayers; ++layer) {
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
|
|
|
|||
|
|
@@ -50,7 +50,10 @@ cc_library(
|
|||
|
||||
cc_library(
|
||||
name = "distortion",
|
||||
hdrs = ["distortion.h"],
|
||||
hdrs = [
|
||||
"distortion.h",
|
||||
"shared.h",
|
||||
],
|
||||
deps = [
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:stats",
|
||||
|
|
@@ -64,29 +67,43 @@ cc_test(
|
|||
srcs = ["distortion_test.cc"],
|
||||
deps = [
|
||||
":distortion",
|
||||
":shared",
|
||||
"@googletest//:gtest_main",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:test_util",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
"@hwy//:nanobenchmark", # Unpredictable1
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "shared",
|
||||
name = "sfp",
|
||||
hdrs = ["shared.h"],
|
||||
textual_hdrs = ["sfp-inl.h"],
|
||||
deps = [
|
||||
"@hwy//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sfp",
|
||||
textual_hdrs = ["sfp-inl.h"],
|
||||
name = "nuq",
|
||||
hdrs = ["shared.h"],
|
||||
textual_hdrs = ["nuq-inl.h"],
|
||||
deps = [
|
||||
":shared",
|
||||
":sfp",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_util",
|
||||
textual_hdrs = [
|
||||
"test_util-inl.h",
|
||||
],
|
||||
deps = [
|
||||
":compress",
|
||||
":distortion",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@@ -102,9 +119,7 @@ cc_test(
|
|||
deps = [
|
||||
":distortion",
|
||||
":sfp",
|
||||
":shared",
|
||||
"@googletest//:gtest_main",
|
||||
"//:ops",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:test_util",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
|
|
@@ -112,18 +127,6 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "nuq",
|
||||
hdrs = ["nuq.h"],
|
||||
textual_hdrs = ["nuq-inl.h"],
|
||||
deps = [
|
||||
":sfp",
|
||||
":shared",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//hwy/contrib/sort:vqsort",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "nuq_test",
|
||||
size = "small",
|
||||
|
|
@@ -138,8 +141,7 @@ cc_test(
|
|||
":distortion",
|
||||
":nuq",
|
||||
":sfp",
|
||||
":shared",
|
||||
"@googletest//:gtest_main",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:test_util",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
|
|
@@ -149,19 +151,20 @@ cc_test(
|
|||
|
||||
cc_library(
|
||||
name = "compress",
|
||||
hdrs = ["compress.h"],
|
||||
textual_hdrs = [
|
||||
"compress-inl.h",
|
||||
hdrs = [
|
||||
"compress.h",
|
||||
"shared.h",
|
||||
],
|
||||
textual_hdrs = ["compress-inl.h"],
|
||||
deps = [
|
||||
":blob_store",
|
||||
":distortion",
|
||||
":io",
|
||||
":nuq",
|
||||
":sfp",
|
||||
":shared",
|
||||
"//:fp_arith",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:stats",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
|
|
@@ -170,6 +173,7 @@ cc_library(
|
|||
cc_test(
|
||||
name = "compress_test",
|
||||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["compress_test.cc"],
|
||||
features = ["fully_static_link"],
|
||||
linkstatic = True,
|
||||
|
|
@@ -179,11 +183,11 @@ cc_test(
|
|||
deps = [
|
||||
":compress",
|
||||
":distortion",
|
||||
"@googletest//:gtest_main",
|
||||
":test_util",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//:test_util",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:hwy_test_util",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
@@ -193,11 +197,9 @@ cc_library(
|
|||
name = "analyze",
|
||||
textual_hdrs = ["analyze.h"],
|
||||
deps = [
|
||||
":distortion",
|
||||
":nuq",
|
||||
":sfp",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark", # timer
|
||||
"@hwy//:stats",
|
||||
"@hwy//:thread_pool",
|
||||
"@hwy//hwy/contrib/sort:vqsort",
|
||||
|
|
@@ -210,7 +212,6 @@ cc_library(
|
|||
deps = [
|
||||
"//:allocator",
|
||||
"//:common",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
|
|
@@ -221,15 +222,13 @@ cc_binary(
|
|||
srcs = ["compress_weights.cc"],
|
||||
deps = [
|
||||
":compress",
|
||||
":shared",
|
||||
":io",
|
||||
":weights_raw",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//:allocator",
|
||||
"//:args",
|
||||
"//:common",
|
||||
"//:gemma_lib",
|
||||
"//:weights",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
|
|
|
|||
|
|
@@ -26,12 +26,10 @@
|
|||
#include <cstdlib> // std::abs
|
||||
#include <vector>
|
||||
|
||||
#include "compression/distortion.h"
|
||||
#include "compression/nuq.h"
|
||||
#include "compression/shared.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/stats.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
|
||||
|
||||
|
|
@@ -55,6 +53,7 @@ namespace HWY_NAMESPACE {
|
|||
class PerThread {
|
||||
public:
|
||||
void NotifyGroup(const float* group) {
|
||||
constexpr size_t kGroupSize = NuqStream::kGroupSize;
|
||||
hwy::Stats s_group;
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
// Skip zero so we can see the lowest actual magnitude
|
||||
|
|
@@ -158,7 +157,7 @@ class PerThread {
|
|||
class PerLayer {
|
||||
public:
|
||||
void NotifyGroup(const float* group) {
|
||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
for (size_t i = 0; i < NuqStream::kGroupSize; ++i) {
|
||||
s_layer_.Notify(group[i]);
|
||||
}
|
||||
}
|
||||
|
|
@@ -197,8 +196,8 @@ static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
|
|||
const float* layer = &mat[idx_layer * weights_per_layer];
|
||||
// For each whole group in the layer
|
||||
for (size_t group_start = 0;
|
||||
group_start + kGroupSize <= weights_per_layer;
|
||||
group_start += kGroupSize) {
|
||||
group_start + NuqStream::kGroupSize <= weights_per_layer;
|
||||
group_start += NuqStream::kGroupSize) {
|
||||
const float* group = layer + group_start;
|
||||
per_layer[idx_layer].NotifyGroup(group);
|
||||
self.NotifyGroup(group);
|
||||
|
|
@@ -210,7 +209,7 @@ static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
|
|||
const int skip = hwy::Stats::kNoGeomean;
|
||||
fprintf(stderr, "\n------------%s\n", caption);
|
||||
|
||||
for (size_t i = 1; i < pool.NumThreads(); ++i) {
|
||||
for (size_t i = 1; i < pool.NumWorkers(); ++i) {
|
||||
tls[0].Assimilate(tls[i]);
|
||||
}
|
||||
tls[0].PrintAll();
|
||||
|
|
|
|||
|
|
@@ -21,7 +21,6 @@
|
|||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <array>
|
||||
#include <cmath> // lroundf, only if COMPRESS_STATS
|
||||
|
||||
#include "compression/blob_store.h"
|
||||
|
|
@@ -42,133 +41,146 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/nuq-inl.h"
|
||||
#include "compression/sfp-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/profiler.h" // also uses SIMD
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Adapters to store two f32 vectors to f32 or bf16; avoids duplicating
|
||||
// RMSNorm and RMSNormInplace for the two output types.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
void Store2(DF df, hn::Vec<DF> v0, hn::Vec<DF> v1, float* HWY_RESTRICT out) {
|
||||
const size_t NF = hn::Lanes(df);
|
||||
hn::StoreU(v0, df, out);
|
||||
hn::StoreU(v1, df, out + NF);
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
void Store2(DF df, hn::Vec<DF> v0, hn::Vec<DF> v1, BF16* HWY_RESTRICT out) {
|
||||
const hn::Repartition<BF16, decltype(df)> dbf;
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, v0, v1), dbf, out);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Enables generic code independent of compression type.
|
||||
template <typename T> // primary, must specialize
|
||||
struct CompressTraits {};
|
||||
|
||||
// Useful for backprop/, where weights are currently f32.
|
||||
// Used by backprop/, where weights are currently f32; also MatMul for f32
|
||||
// weights or activations, if native `ReorderWidenMulAccumulate` is available.
|
||||
template <>
|
||||
struct CompressTraits<float> {
|
||||
using MatT = float;
|
||||
static const char* Name() { return "f32"; }
|
||||
static constexpr bool kSupportsEvenOdd = false; // unnecessary
|
||||
using Packed = float;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||
size_t num, CompressPerThread& tls,
|
||||
size_t /*out_capacity*/,
|
||||
MatT* HWY_RESTRICT out, size_t out_ofs) {
|
||||
template <class DF, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void Compress(DF /*df*/, const float* HWY_RESTRICT raw,
|
||||
size_t num, CompressPerThread& /*tls*/,
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
hwy::CopyBytes(raw, packed.ptr + packed_ofs, num * sizeof(raw[0]));
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
|
||||
static void Store2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
const size_t NF = hn::Lanes(df);
|
||||
hn::StoreU(raw0, df, packed.ptr + packed_ofs);
|
||||
hn::StoreU(raw1, df, packed.ptr + packed_ofs + NF);
|
||||
}
|
||||
|
||||
template <class DBF16, HWY_IF_BF16_D(DBF16), class VBF16 = hn::Vec<DBF16>>
|
||||
static HWY_INLINE void Load2(DBF16 dbf16,
|
||||
const PackedSpan<const Packed>& packed,
|
||||
const size_t packed_ofs, VBF16& raw0,
|
||||
VBF16& raw1) {
|
||||
const hn::Repartition<float, decltype(dbf16)> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N = hn::Lanes(df);
|
||||
HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0);
|
||||
const size_t NF = hn::Lanes(df);
|
||||
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF);
|
||||
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF);
|
||||
const VF f2 = hn::LoadU(df, packed.ptr + packed_ofs + 2 * NF);
|
||||
const VF f3 = hn::LoadU(df, packed.ptr + packed_ofs + 3 * NF);
|
||||
raw0 = hn::OrderedDemote2To(dbf16, f0, f1);
|
||||
raw1 = hn::OrderedDemote2To(dbf16, f2, f3);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < num; i += 2 * N) {
|
||||
const VF in0 = hn::LoadU(df, in + i);
|
||||
const VF in1 = hn::LoadU(df, in + i + N);
|
||||
hn::StoreU(in0, df, out + out_ofs + i);
|
||||
hn::StoreU(in1, df, out + out_ofs + i + N);
|
||||
template <class DF, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
|
||||
const size_t packed_ofs, VF& raw0, VF& raw1) {
|
||||
const size_t N = hn::Lanes(df);
|
||||
raw0 = hn::LoadU(df, packed.ptr + packed_ofs);
|
||||
raw1 = hn::LoadU(df, packed.ptr + packed_ofs + N);
|
||||
}
|
||||
|
||||
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
|
||||
BF16* HWY_RESTRICT raw, size_t num) {
|
||||
const hn::Repartition<float, decltype(dbf)> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 2 * NF) {
|
||||
for (; i <= num - 2 * NF; i += 2 * NF) {
|
||||
const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + i);
|
||||
const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + i + NF);
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
|
||||
}
|
||||
}
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * NF);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
const size_t remaining2 = remaining - HWY_MIN(remaining, NF);
|
||||
const VF f0 = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
|
||||
const VF f1 = hn::LoadN(df, packed.ptr + packed_ofs + i + NF, remaining2);
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i);
|
||||
}
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs, hn::Vec<DF>& f0,
|
||||
hn::Vec<DF>& f1) {
|
||||
const size_t N = hn::Lanes(df);
|
||||
f0 = hn::LoadU(df, in + in_ofs);
|
||||
f1 = hn::LoadU(df, in + in_ofs + N);
|
||||
}
|
||||
|
||||
// Called by MatMul for f32 weights or activations if native
|
||||
// `ReorderWidenMulAccumulate` is available.
|
||||
template <class DBF16, HWY_IF_BF16_D(DBF16), class VBF16 = hn::Vec<DBF16>>
|
||||
static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs, VBF16& v0, VBF16& v1) {
|
||||
const hn::Repartition<float, decltype(dbf16)> df;
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
DF df, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
|
||||
float* HWY_RESTRICT raw, size_t num) {
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
const VF f0 = hn::LoadU(df, in + in_ofs + 0 * NF);
|
||||
const VF f1 = hn::LoadU(df, in + in_ofs + 1 * NF);
|
||||
const VF f2 = hn::LoadU(df, in + in_ofs + 2 * NF);
|
||||
const VF f3 = hn::LoadU(df, in + in_ofs + 3 * NF);
|
||||
v0 = hn::OrderedDemote2To(dbf16, f0, f1);
|
||||
v1 = hn::OrderedDemote2To(dbf16, f2, f3);
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
|
||||
const MatT* HWY_RESTRICT in, size_t in_ofs,
|
||||
float* HWY_RESTRICT out, size_t num) {
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N = hn::Lanes(df);
|
||||
|
||||
for (size_t i = 0; i < num; i += N) {
|
||||
const VF v = hn::LoadU(df, in + in_ofs + i);
|
||||
hn::StoreU(v, df, out + i);
|
||||
size_t i = 0;
|
||||
if (num >= NF) {
|
||||
for (; i <= num - NF; i += NF) {
|
||||
const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i);
|
||||
hn::StoreU(vf, df, raw + i);
|
||||
}
|
||||
}
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < NF);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining);
|
||||
hn::StoreU(vf, df, raw + i); // adds zero padding
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CompressTraits<hwy::bfloat16_t> {
|
||||
using MatT = hwy::bfloat16_t;
|
||||
static const char* Name() { return "bf16"; }
|
||||
static constexpr bool kSupportsEvenOdd = true;
|
||||
struct CompressTraits<BF16> {
|
||||
using Packed = BF16;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||
// Note: it is fine for the lower 16 mantissa bits of `raw` to be nonzero
|
||||
// because we round rather than truncate.
|
||||
template <class DF, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
|
||||
size_t num, CompressPerThread& tls,
|
||||
size_t /*out_capacity*/,
|
||||
MatT* HWY_RESTRICT out, size_t out_ofs) {
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
const hn::RebindToUnsigned<decltype(df)> du;
|
||||
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N = hn::Lanes(df);
|
||||
|
||||
hn::Vec<decltype(du)> or_sum = hn::Zero(du);
|
||||
const hn::Repartition<BF16, decltype(df)> dbf;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 2 * N) {
|
||||
for (; i <= num - 2 * N; i += 2 * N) {
|
||||
const VF in0 = hn::LoadU(df, in + i);
|
||||
const VF in1 = hn::LoadU(df, in + i + N);
|
||||
if (num >= 2 * NF) {
|
||||
for (; i <= num - 2 * NF; i += 2 * NF) {
|
||||
const VF raw0 = hn::LoadU(df, raw + i);
|
||||
const VF raw1 = hn::LoadU(df, raw + i + NF);
|
||||
|
||||
// Sticky bits so we can warn if any lower bits were set.
|
||||
or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1));
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i);
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, raw0, raw1), dbf,
|
||||
packed.ptr + packed_ofs + i);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
DistortionStats stats;
|
||||
for (size_t j = 0; j < 2 * N; ++j) {
|
||||
stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j]));
|
||||
for (size_t j = 0; j < 2 * NF; ++j) {
|
||||
stats.Notify(raw[i + j],
|
||||
hwy::F32FromBF16(packed.ptr[packed_ofs + i + j]));
|
||||
}
|
||||
tls.stats.Notify(stats);
|
||||
}
|
||||
|
|
@@ -176,270 +188,248 @@ struct CompressTraits<hwy::bfloat16_t> {
|
|||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * NF);
|
||||
if (remaining != 0) {
|
||||
const VF in0 = hn::LoadN(df, in + i, remaining);
|
||||
const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2);
|
||||
const VF in1 = hn::LoadN(df, in + i + N, remaining1);
|
||||
const VF raw0 = hn::LoadN(df, raw + i, remaining);
|
||||
const size_t remaining1 = remaining - HWY_MIN(remaining, NF);
|
||||
const VF raw1 = hn::LoadN(df, raw + i + NF, remaining1);
|
||||
|
||||
// Sticky bits so we can warn if any lower bits were set.
|
||||
or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1));
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i);
|
||||
hn::StoreN(hn::OrderedDemote2To(dbf, raw0, raw1), dbf,
|
||||
packed.ptr + packed_ofs + i, remaining);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
DistortionStats stats;
|
||||
for (size_t j = 0; j < remaining; ++j) {
|
||||
stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j]));
|
||||
stats.Notify(raw[i + j],
|
||||
hwy::F32FromBF16(packed.ptr[packed_ofs + i + j]));
|
||||
}
|
||||
tls.stats.Notify(stats);
|
||||
}
|
||||
}
|
||||
|
||||
// If the lower 16 bits are not zero, we should implement rounding.
|
||||
or_sum = hn::And(or_sum, hn::Set(du, 0xFFFF));
|
||||
if (!hn::AllTrue(du, hn::Eq(or_sum, hn::Zero(du)))) {
|
||||
// fprintf(stderr, "Warning: Lossy truncation.");
|
||||
}
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs, hn::Vec<DF>& f0,
|
||||
hn::Vec<DF>& f1) {
|
||||
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
const VBF in16 = hn::LoadU(dbf, in + in_ofs);
|
||||
f0 = hn::PromoteLowerTo(df, in16);
|
||||
f1 = hn::PromoteUpperTo(df, in16);
|
||||
template <class DF, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
|
||||
static void Store2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
const hn::Repartition<BF16, decltype(df)> dbf;
|
||||
hn::StoreU(hn::OrderedDemote2To(dbf, raw0, raw1), dbf,
|
||||
packed.ptr + packed_ofs);
|
||||
}
|
||||
|
||||
template <class DBF16, HWY_IF_BF16_D(DBF16)>
|
||||
static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs, hn::Vec<DBF16>& v0,
|
||||
hn::Vec<DBF16>& v1) {
|
||||
v0 = hn::LoadU(dbf16, in + in_ofs);
|
||||
v1 = hn::LoadU(dbf16, in + in_ofs + hn::Lanes(dbf16));
|
||||
static HWY_INLINE void Load2(DBF16 dbf16,
|
||||
const PackedSpan<const Packed>& packed,
|
||||
const size_t packed_ofs, hn::Vec<DBF16>& raw0,
|
||||
hn::Vec<DBF16>& raw1) {
|
||||
const size_t N16 = hn::Lanes(dbf16);
|
||||
raw0 = hn::LoadU(dbf16, packed.ptr + packed_ofs);
|
||||
raw1 = hn::LoadU(dbf16, packed.ptr + packed_ofs + N16);
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
|
||||
const MatT* HWY_RESTRICT in, size_t in_ofs,
|
||||
float* HWY_RESTRICT out, size_t num) {
|
||||
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
|
||||
static HWY_INLINE void Load2(DF df, const PackedSpan<const Packed>& packed,
|
||||
const size_t packed_ofs, hn::Vec<DF>& raw0,
|
||||
hn::Vec<DF>& raw1) {
|
||||
const hn::Repartition<BF16, decltype(df)> dbf;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs);
|
||||
raw0 = hn::PromoteLowerTo(df, packed0);
|
||||
raw1 = hn::PromoteUpperTo(df, packed0);
|
||||
}
|
||||
|
||||
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
DBF dbf, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
|
||||
BF16* HWY_RESTRICT raw, size_t num) {
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N16 = hn::Lanes(dbf);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= N16) {
|
||||
for (i = 0; i <= num - N16; i += N16) {
|
||||
VF in0, in1;
|
||||
Decompress2(df, in, in_ofs + i, in0, in1);
|
||||
hn::StoreU(in0, df, out + i);
|
||||
hn::StoreU(in1, df, out + i + N16 / 2);
|
||||
const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs + i);
|
||||
hn::StoreU(packed0, dbf, raw + i);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
if (remaining != 0) {
|
||||
const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining);
|
||||
const VF in0 = hn::PromoteLowerTo(df, in16);
|
||||
const VF in1 = hn::PromoteUpperTo(df, in16);
|
||||
hn::StoreN(in0, df, out + i, remaining);
|
||||
// Avoid wraparound, potentially store nothing.
|
||||
const size_t remaining1 = remaining - HWY_MIN(remaining, N16 / 2);
|
||||
hn::StoreN(in1, df, out + i + N16 / 2, remaining1);
|
||||
HWY_DASSERT(remaining < N16);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
const VBF packed0 =
|
||||
hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining);
|
||||
hn::StoreU(packed0, dbf, raw + i);
|
||||
}
|
||||
}
|
||||
|
||||
// Computes the dot product of an even-odd deinterleaved, f32 `vec_aligned`
|
||||
// and a column- major matrix `in`. `vec_aligned` should be aligned and
|
||||
// alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed
|
||||
// `hn::Lanes(df32)` elements.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE float DotEO(
|
||||
const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs,
|
||||
const float* HWY_RESTRICT vec_aligned, size_t num) {
|
||||
HWY_DASSERT(num >= (hn::Lanes(df32) * 2) &&
|
||||
(num % (hn::Lanes(df32) * 2)) == 0);
|
||||
HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0);
|
||||
HWY_DASSERT(hn::IsAligned(df32, vec_aligned));
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
DF df, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
|
||||
float* HWY_RESTRICT raw, size_t num) {
|
||||
const hn::Repartition<BF16, decltype(df)> dbf;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
|
||||
using VF32 = decltype(Zero(df32));
|
||||
const size_t N = Lanes(dbf16);
|
||||
|
||||
VF32 sum0 = Zero(df32);
|
||||
VF32 sum1 = Zero(df32);
|
||||
VF32 sum2 = Zero(df32);
|
||||
VF32 sum3 = Zero(df32);
|
||||
|
||||
for (size_t i = 0; i < num; /* i += 2 * N */) {
|
||||
const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||
const VF32 ae0 = Load(df32, vec_aligned + i);
|
||||
const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
|
||||
sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0);
|
||||
sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1);
|
||||
i += N;
|
||||
|
||||
const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||
const VF32 ae1 = Load(df32, vec_aligned + i);
|
||||
const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
|
||||
sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2);
|
||||
sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3);
|
||||
i += N;
|
||||
size_t i = 0;
|
||||
if (num >= 2 * NF) {
|
||||
for (i = 0; i <= num - 2 * NF; i += 2 * NF) {
|
||||
VF raw0, raw1;
|
||||
Load2(df, packed, packed_ofs + i, raw0, raw1);
|
||||
hn::StoreU(raw0, df, raw + i);
|
||||
hn::StoreU(raw1, df, raw + i + NF);
|
||||
}
|
||||
}
|
||||
|
||||
sum0 = hn::Add(sum0, sum1);
|
||||
sum2 = hn::Add(sum2, sum3);
|
||||
sum0 = hn::Add(sum0, sum2);
|
||||
return hn::ReduceSum(df32, sum0);
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * NF);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
const VBF packed0 =
|
||||
hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining);
|
||||
const VF raw0 = hn::PromoteLowerTo(df, packed0);
|
||||
const VF raw1 = hn::PromoteUpperTo(df, packed0);
|
||||
// If at most one vector, the first store adds zero padding. Check before
|
||||
// storing the second, because callers only pad to one vector.
|
||||
hn::StoreU(raw0, df, raw + i);
|
||||
if (remaining >= NF) hn::StoreU(raw1, df, raw + i + NF);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Switching floating point: 8-bit, 2..3 mantissa bits.
|
||||
template <>
|
||||
struct CompressTraits<SfpStream> {
|
||||
using MatT = SfpStream;
|
||||
static const char* Name() { return "sfp"; }
|
||||
static constexpr bool kSupportsEvenOdd = true;
|
||||
using Packed = SfpStream;
|
||||
|
||||
// Callers are responsible for scaling `in` such that its magnitudes do not
|
||||
// exceed 1.875. See CompressedArray::scale().
|
||||
// Callers are responsible for scaling `raw` such that its magnitudes do not
|
||||
// exceed `SfpStream::kMax`. See CompressedArray::scale().
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
|
||||
size_t num, CompressPerThread& tls,
|
||||
size_t /*out_capacity*/,
|
||||
MatT* HWY_RESTRICT out, size_t out_ofs) {
|
||||
SfpCodec::Enc(df, in, num, out + out_ofs);
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
auto distorted = hwy::AllocateAligned<hwy::bfloat16_t>(num);
|
||||
SfpCodec::Dec(dbf, out + out_ofs, num, distorted.get());
|
||||
const hn::Repartition<BF16, DF> dbf;
|
||||
auto distorted =
|
||||
hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, hn::Lanes(dbf)));
|
||||
SfpCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs,
|
||||
distorted.get(), num);
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
stats.Notify(in[i], hwy::F32FromBF16(distorted[i]));
|
||||
stats.Notify(raw[i], hwy::F32FromBF16(distorted[i]));
|
||||
}
|
||||
tls.stats.Notify(stats);
|
||||
}
|
||||
}
|
||||
|
||||
template <class D> // f32 or bf16
|
||||
static HWY_INLINE void Decompress2(D d, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs, hn::Vec<D>& v0,
|
||||
hn::Vec<D>& v1) {
|
||||
template <class D> // Caller checks this is f32 or bf16
|
||||
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
|
||||
const size_t packed_ofs, hn::Vec<D>& raw0,
|
||||
hn::Vec<D>& raw1) {
|
||||
const hn::Twice<hn::Rebind<uint8_t, D>> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
const V8 packed = hn::LoadU(d8, &in->byte + in_ofs);
|
||||
SfpCodec::Dec2(d, packed, v0, v1);
|
||||
const V8 v8 = hn::LoadU(d8, &packed.ptr->byte + packed_ofs);
|
||||
SfpCodec::Dec2(d, v8, raw0, raw1);
|
||||
}
|
||||
|
||||
template <class D, typename OutT>
|
||||
static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/,
|
||||
const MatT* HWY_RESTRICT in, size_t in_ofs,
|
||||
OutT* HWY_RESTRICT out, size_t num) {
|
||||
SfpCodec::Dec(d, in + in_ofs, num, out);
|
||||
}
|
||||
// Store2 is not yet implemented.
|
||||
|
||||
// Computes the dot product of an even-odd deinterleaved, f32 or bf16
|
||||
// `vec_aligned` and a column-major matrix `in`. `vec_aligned` should be
|
||||
// aligned and alternate even-indexed `hn::Lanes(df)` elements followed by
|
||||
// odd-indexed `hn::Lanes(df)` elements.
|
||||
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE float DotEO(const DF df, const MatT* HWY_RESTRICT in,
|
||||
size_t in_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
size_t num) {
|
||||
HWY_DASSERT(num >= (hn::Lanes(df) * 2) && (num % (hn::Lanes(df) * 2)) == 0);
|
||||
HWY_DASSERT((in_ofs % (hn::Lanes(df) * 2)) == 0);
|
||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
VF sum0 = hn::Zero(df);
|
||||
VF sum1 = hn::Zero(df);
|
||||
VF sum2 = hn::Zero(df);
|
||||
VF sum3 = hn::Zero(df);
|
||||
|
||||
SfpCodec::DotEO(df, in + in_ofs, num, vec_aligned, sum0, sum1, sum2, sum3);
|
||||
|
||||
// Reduction tree: sum of all accumulators, then their lanes
|
||||
sum0 = hn::Add(sum0, sum1);
|
||||
sum2 = hn::Add(sum2, sum3);
|
||||
sum0 = hn::Add(sum0, sum2);
|
||||
return hn::ReduceSum(df, sum0);
|
||||
template <class D, typename Raw>
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
|
||||
Raw* HWY_RESTRICT raw, const size_t num) {
|
||||
SfpCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
|
||||
}
|
||||
};
|
||||
|
||||
// Nonuniform quantization, 4.5 bits per element, two separate streams.
|
||||
template <>
|
||||
struct CompressTraits<NuqStream> {
|
||||
using MatT = NuqStream;
|
||||
static const char* Name() { return "nuq"; }
|
||||
static constexpr bool kSupportsEvenOdd = false;
|
||||
using Packed = NuqStream;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
|
||||
CompressPerThread& tls, size_t out_capacity,
|
||||
MatT* out, size_t out_ofs) {
|
||||
NuqCodec::Enc(df, in, num, tls.buf, out_capacity, out, out_ofs);
|
||||
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
|
||||
size_t num, CompressPerThread& tls,
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
tls.stats.NotifyIn(static_cast<int>(lroundf(in[i] * 100.0f + 500.0f)));
|
||||
tls.stats.NotifyIn(static_cast<int>(lroundf(raw[i] * 100.0f + 500.0f)));
|
||||
}
|
||||
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
auto distorted = hwy::AllocateAligned<hwy::bfloat16_t>(num);
|
||||
NuqCodec::Dec(dbf, out_capacity, out, out_ofs, distorted.get(), num);
|
||||
const hn::Repartition<BF16, DF> dbf;
|
||||
const size_t N16 = hn::Lanes(dbf);
|
||||
auto distorted = hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, N16));
|
||||
NuqCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs,
|
||||
distorted.get(), num);
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
stats.Notify(in[i], hwy::F32FromBF16(distorted[i]));
|
||||
stats.Notify(raw[i], hwy::F32FromBF16(distorted[i]));
|
||||
}
|
||||
tls.stats.Notify(stats);
|
||||
}
|
||||
}
|
||||
|
||||
template <class D, typename OutT>
|
||||
static HWY_INLINE void Decompress(D d, size_t in_capacity, const MatT* in,
|
||||
size_t in_ofs, OutT* out, size_t num) {
|
||||
NuqCodec::Dec(d, in_capacity, in, in_ofs, out, num);
|
||||
template <class D> // Caller checks this is f32 or bf16
|
||||
static HWY_INLINE void Load2(D d, const PackedSpan<const Packed>& packed,
|
||||
const size_t packed_ofs, hn::Vec<D>& raw0,
|
||||
hn::Vec<D>& raw1) {
|
||||
const hn::Twice<hn::Rebind<uint8_t, D>> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1);
|
||||
}
|
||||
|
||||
// Store2 is not yet implemented.
|
||||
|
||||
template <class D, typename Raw>
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
D d, const PackedSpan<const Packed>& packed, const size_t packed_ofs,
|
||||
Raw* raw, const size_t num) {
|
||||
NuqCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num);
|
||||
}
|
||||
};
|
||||
|
||||
// Compresses `num` inputs to `out` starting at `out_ofs`. This can be used for
|
||||
// compressing sub-regions of an array.
|
||||
template <typename MatT>
|
||||
HWY_NOINLINE void Compress(const float* in, size_t num,
|
||||
CompressWorkingSet& work, size_t out_capacity,
|
||||
MatT* out, size_t out_ofs, hwy::ThreadPool& pool) {
|
||||
HWY_DASSERT(out_ofs + num <= out_capacity);
|
||||
work.tls.resize(pool.NumThreads());
|
||||
// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`,
|
||||
// which is useful for compressing sub-regions of an array.
|
||||
template <typename Packed>
|
||||
HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
|
||||
CompressWorkingSet& work,
|
||||
const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs, hwy::ThreadPool& pool) {
|
||||
packed.BoundsCheck(packed_ofs, num);
|
||||
work.tls.resize(pool.NumWorkers());
|
||||
if (COMPRESS_STATS) {
|
||||
for (auto& tls : work.tls) {
|
||||
tls.stats.Reset();
|
||||
}
|
||||
}
|
||||
|
||||
const double t0 = hwy::platform::Now();
|
||||
const bool want_bench = num > 1024 * 1024 || COMPRESS_STATS;
|
||||
const double t0 = want_bench ? hwy::platform::Now() : 0.0;
|
||||
|
||||
using Traits = CompressTraits<MatT>;
|
||||
using Traits = CompressTraits<Packed>;
|
||||
constexpr size_t kBatch = 8192;
|
||||
const size_t num_batches = hwy::DivCeil(num, kBatch);
|
||||
pool.Run(0, num_batches,
|
||||
[&](const uint32_t idx_batch, size_t thread) HWY_ATTR {
|
||||
const hn::ScalableTag<float> df;
|
||||
|
||||
const size_t in_ofs = idx_batch * kBatch;
|
||||
const size_t my_pos = idx_batch * kBatch;
|
||||
const size_t my_num =
|
||||
idx_batch == num_batches - 1 ? (num - in_ofs) : kBatch;
|
||||
Traits::Compress(df, in + in_ofs, my_num, work.tls[thread],
|
||||
out_capacity, out, out_ofs + in_ofs);
|
||||
idx_batch == num_batches - 1 ? (num - my_pos) : kBatch;
|
||||
Traits::Compress(df, raw + my_pos, my_num, work.tls[thread],
|
||||
packed, packed_ofs + my_pos);
|
||||
});
|
||||
|
||||
const double t1 = hwy::platform::Now();
|
||||
const double mb = static_cast<double>(num) * sizeof(in[0]) * 1E-6;
|
||||
const double mbps = mb / (t1 - t0);
|
||||
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
|
||||
if (want_bench) { // Avoids log spam in tests
|
||||
const double t1 = hwy::platform::Now();
|
||||
const double mb = static_cast<double>(num) * sizeof(raw[0]) * 1E-6;
|
||||
const double mbps = mb / (t1 - t0);
|
||||
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
|
||||
}
|
||||
|
||||
if (COMPRESS_STATS) {
|
||||
for (size_t i = 1; i < work.tls.size(); ++i) {
|
||||
|
|
@@ -449,53 +439,182 @@ HWY_NOINLINE void Compress(const float* in, size_t num,
|
|||
}
|
||||
}
|
||||
|
||||
// Compresses an entire std::array into `out`, which is assumed to have exactly
|
||||
// that much capacity.
|
||||
template <size_t kCapacity, typename MatT>
|
||||
HWY_INLINE void Compress(const std::array<float, kCapacity>& in,
|
||||
CompressWorkingSet& work,
|
||||
CompressedArray<MatT, kCapacity>& compressed,
|
||||
hwy::ThreadPool& pool) {
|
||||
Compress(in.data(), kCapacity, work, kCapacity, compressed.data(), 0, pool);
|
||||
// Adapter that compresses into `CompressedArray`. `raw` must already be scaled
|
||||
// to fit the value range, if `Packed` is `SfpStream`.
|
||||
template <typename Packed, size_t kCapacity>
|
||||
HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num,
|
||||
CompressWorkingSet& work,
|
||||
CompressedArray<Packed, kCapacity>& compressed,
|
||||
hwy::ThreadPool& pool) {
|
||||
Compress(raw, num, work, MakeSpan(compressed.data(), kCapacity),
|
||||
/*packed_ofs=*/0, pool);
|
||||
}
|
||||
|
||||
// Decompresses `num` values from `compressed` starting at `compressed_ofs`.
|
||||
template <typename ArrayT, typename OutT>
|
||||
HWY_NOINLINE void Decompress(const ArrayT& compressed, size_t compressed_ofs,
|
||||
OutT* out, size_t num) {
|
||||
HWY_DASSERT(compressed_ofs + num <= compressed.size());
|
||||
const hn::ScalableTag<OutT> d;
|
||||
using Traits = CompressTraits<typename ArrayT::value_type>;
|
||||
Traits::Decompress(d, compressed.size(), compressed.data(), compressed_ofs,
|
||||
out, num);
|
||||
// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and
|
||||
// RMSNormInplace for the two output types.
|
||||
template <class DF, typename Packed, HWY_IF_F32_D(DF), class VF = hn::Vec<DF>>
|
||||
void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs) {
|
||||
static_assert(hwy::IsSameEither<Packed, float, BF16>());
|
||||
packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df));
|
||||
using Traits = CompressTraits<Packed>;
|
||||
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
|
||||
}
|
||||
|
||||
// As above, but with threading and benchmarking.
|
||||
template <typename MatT, size_t kCapacity, typename OutT>
|
||||
HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
|
||||
size_t compressed_ofs, OutT* out, size_t num,
|
||||
hwy::ThreadPool& pool) {
|
||||
HWY_DASSERT(compressed_ofs + num <= compressed.size());
|
||||
const double t0 = hwy::platform::Now();
|
||||
// Decompresses from any type of `packed`, to two float or BF16 vectors.
|
||||
template <class DRaw, typename Packed, class VRaw = hn::Vec<DRaw>>
|
||||
HWY_INLINE void Decompress2(DRaw d, const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs, VRaw& raw0, VRaw& raw1) {
|
||||
using TRaw = hn::TFromD<DRaw>;
|
||||
static_assert(hwy::IsSameEither<TRaw, float, BF16>());
|
||||
packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(d));
|
||||
using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
|
||||
Traits::Load2(d, MakeConst(packed), packed_ofs, raw0, raw1);
|
||||
}
|
||||
|
||||
using Traits = CompressTraits<MatT>;
|
||||
constexpr size_t kBatch = 8192;
|
||||
const size_t num_batches = hwy::DivCeil(num, kBatch);
|
||||
pool.Run(
|
||||
0, num_batches, [&](const uint32_t idx_batch, size_t thread) HWY_ATTR {
|
||||
const hn::ScalableTag<OutT> d;
|
||||
// Decompresses from any type of `packed`, starting at (any) `packed_ofs`, to
|
||||
// (any) `num` elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as
|
||||
// required to round `num` up to one vector, if it is not already. The caller is
|
||||
// responsible for scaling `raw` to the original range because `EmbedToken`
|
||||
// also wants to scale the decompressed elements.
|
||||
template <class DRaw, typename Packed, typename TRaw = hn::TFromD<DRaw>>
|
||||
HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
|
||||
const size_t packed_ofs, TRaw* raw,
|
||||
size_t num) {
|
||||
static_assert(hwy::IsSameEither<TRaw, float, BF16>());
|
||||
using Traits = CompressTraits<hwy::RemoveCvRef<Packed>>;
|
||||
packed.BoundsCheck(packed_ofs, num);
|
||||
Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
|
||||
}
|
||||
|
||||
const size_t ofs = idx_batch * kBatch;
|
||||
const size_t batch =
|
||||
idx_batch == num_batches - 1 ? (num - ofs) : kBatch;
|
||||
Traits::Decompress(d, compressed.size(), compressed.data(),
|
||||
compressed_ofs + ofs, out + ofs, batch);
|
||||
});
|
||||
// Decompresses to the type specified by `D` from each of two arrays in groups
|
||||
// of four vectors, passes them to `kernel.Update4`, zero-pads to a vector
|
||||
// multiple, then calls `kernel.Update1` for the remaining vectors. Returns
|
||||
// `kernel.Reduce`.
|
||||
//
|
||||
// This is useful for implementing dot products, and similar to
|
||||
// `hwy/contrib/unroller`, but also supports compressed types with simpler
|
||||
// remainder handling thanks to `DecompressAndZeroPad`.
|
||||
//
|
||||
// `w` can be any packed type, including NUQ, which requires a separate `w_ofs`
|
||||
// rather than pointer arithmetic. `vec_aligned` can also be any type, but
|
||||
// typically float or BF16. We omit a `v_ofs` because it is 0 in our use cases.
|
||||
// `num`, the number of elements to process, need not be a vector multiple.
|
||||
//
|
||||
// `kernel` is const& so we can pass an rvalue argument, but can contain
|
||||
// mutable state, though not vectors (see highway.h). We pass in the four
|
||||
// loaded vectors plus eight *f32* state vectors, independent of `D`.
|
||||
template <class D, typename WeightT, typename VecT, class Kernel>
|
||||
HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
|
||||
const size_t w_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
const size_t num, const Kernel& kernel) {
|
||||
PROFILER_FUNC;
|
||||
|
||||
const double t1 = hwy::platform::Now();
|
||||
const double mb = num * sizeof(MatT) * 1E-6;
|
||||
const double mbps = mb / (t1 - t0);
|
||||
fprintf(stderr, "Decompress %.1f MB/s\n", mbps);
|
||||
HWY_DASSERT(hn::IsAligned(hn::Repartition<VecT, D>(), vec_aligned));
|
||||
const auto v_span = MakeSpan(vec_aligned, num);
|
||||
|
||||
// Decompressed inputs
|
||||
using V = hn::Vec<decltype(d)>;
|
||||
V w0, w1, w2, w3, v0, v1, v2, v3;
|
||||
|
||||
// State for Kernel
|
||||
const hn::Repartition<float, D> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
VF sum0 = hn::Zero(df);
|
||||
VF sum1 = hn::Zero(df);
|
||||
VF sum2 = hn::Zero(df);
|
||||
VF sum3 = hn::Zero(df);
|
||||
VF comp0 = hn::Zero(df);
|
||||
VF comp1 = hn::Zero(df);
|
||||
VF comp2 = hn::Zero(df);
|
||||
VF comp3 = hn::Zero(df);
|
||||
|
||||
const size_t N = hn::Lanes(d);
|
||||
size_t i = 0;
|
||||
if (num >= 4 * N) {
|
||||
for (; i <= num - 4 * N; i += 4 * N) {
|
||||
Decompress2(d, w, w_ofs + i + 0 * N, w0, w1);
|
||||
Decompress2(d, w, w_ofs + i + 2 * N, w2, w3);
|
||||
Decompress2(d, v_span, i + 0 * N, v0, v1);
|
||||
Decompress2(d, v_span, i + 2 * N, v2, v3);
|
||||
|
||||
kernel.Update4(d, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
|
||||
comp0, comp1, comp2, comp3);
|
||||
}
|
||||
}
|
||||
|
||||
size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 4 * N);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
using T = hn::TFromD<D>;
|
||||
HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)];
|
||||
HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
|
||||
DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining);
|
||||
DecompressAndZeroPad(d, v_span, i, padded_v, remaining);
|
||||
|
||||
// 1..4 whole vectors, possibly zero-padded.
|
||||
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
|
||||
const V w0 = hn::Load(d, padded_w + padded_pos);
|
||||
const V v0 = hn::Load(d, padded_v + padded_pos);
|
||||
kernel.Update1(d, w0, v0, sum0, comp0);
|
||||
}
|
||||
}
|
||||
|
||||
return kernel.Reduce(df, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
|
||||
}
|
||||
|
||||
// Same as above, but single input array. Used by RMSNorm.
|
||||
template <class D, typename VecT, class Kernel>
|
||||
HWY_INLINE float DecompressAndCall(D d, const VecT* HWY_RESTRICT vec_aligned,
|
||||
const size_t num, const Kernel& kernel) {
|
||||
PROFILER_FUNC;
|
||||
|
||||
HWY_DASSERT(hn::IsAligned(hn::Repartition<VecT, D>(), vec_aligned));
|
||||
const auto v_span = MakeSpan(vec_aligned, num);
|
||||
|
||||
// Decompressed inputs
|
||||
using V = hn::Vec<decltype(d)>;
|
||||
V v0, v1, v2, v3;
|
||||
|
||||
// State for Kernel
|
||||
const hn::Repartition<float, D> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
VF sum0 = hn::Zero(d);
|
||||
VF sum1 = hn::Zero(d);
|
||||
VF sum2 = hn::Zero(d);
|
||||
VF sum3 = hn::Zero(d);
|
||||
VF comp0 = hn::Zero(d);
|
||||
VF comp1 = hn::Zero(d);
|
||||
VF comp2 = hn::Zero(d);
|
||||
VF comp3 = hn::Zero(d);
|
||||
|
||||
const size_t N = hn::Lanes(d);
|
||||
size_t i = 0;
|
||||
if (num >= 4 * N) {
|
||||
for (; i <= num - 4 * N; i += 4 * N) {
|
||||
Decompress2(d, v_span, i + 0 * N, v0, v1);
|
||||
Decompress2(d, v_span, i + 2 * N, v2, v3);
|
||||
|
||||
kernel.Update4(d, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
|
||||
comp0, comp1, comp2, comp3);
|
||||
}
|
||||
}
|
||||
|
||||
size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 4 * N);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
HWY_ALIGN float padded_v[4 * hn::MaxLanes(d)];
|
||||
DecompressAndZeroPad(d, v_span, i, padded_v, remaining);
|
||||
|
||||
// 1..4 whole vectors, possibly zero-padded.
|
||||
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
|
||||
const VF v0 = hn::Load(d, padded_v + padded_pos);
|
||||
kernel.Update1(d, v0, v0, sum0, comp0);
|
||||
}
|
||||
}
|
||||
|
||||
return kernel.Reduce(d, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3);
|
||||
}
|
||||
|
||||
// Functor called for each tensor, which compresses and stores them along with
|
||||
|
|
@@ -504,21 +623,22 @@ class Compressor {
|
|||
public:
|
||||
explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {}
|
||||
|
||||
template <typename MatT, size_t kCapacity>
|
||||
template <typename Packed, size_t kCapacity>
|
||||
void operator()(const char* name, const float* weights,
|
||||
CompressedArray<MatT, kCapacity>& compressed) {
|
||||
CompressedArray<Packed, kCapacity>& compressed) {
|
||||
Insert(name, weights, kCapacity, work_, compressed.CompressedSize(),
|
||||
compressed.data(), 0, pool_);
|
||||
}
|
||||
|
||||
template <typename MatT>
|
||||
template <typename Packed>
|
||||
void Insert(const char* name, const float* weights, size_t weights_count,
|
||||
CompressWorkingSet& work, size_t out_capacity, MatT* out,
|
||||
size_t out_ofs, hwy::ThreadPool& pool) {
|
||||
CompressWorkingSet& work, size_t out_capacity, Packed* packed,
|
||||
size_t packed_ofs, hwy::ThreadPool& pool) {
|
||||
fprintf(stderr, "Regenerating %s (%zuM), please wait\n", name,
|
||||
weights_count / (1000 * 1000));
|
||||
Compress(weights, weights_count, work_, weights_count, out, 0, pool_);
|
||||
writer_.Add(CacheKey<MatT>(name), out, out_capacity);
|
||||
Compress(weights, weights_count, work_,
|
||||
PackedSpan<Packed>{packed, weights_count}, 0, pool_);
|
||||
writer_.Add(CacheKey<Packed>(name), packed, out_capacity);
|
||||
}
|
||||
|
||||
void AddScales(const float* scales, size_t len) {
|
||||
|
|
|
|||
|
|
@@ -29,7 +29,6 @@
|
|||
// IWYU pragma: begin_exports
|
||||
#include "compression/blob_store.h"
|
||||
#include "compression/io.h"
|
||||
#include "compression/nuq.h"
|
||||
#include "compression/shared.h"
|
||||
// IWYU pragma: end_exports
|
||||
#include "compression/distortion.h"
|
||||
|
|
@@ -41,22 +40,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
static inline const char* TypeName(float) { return "f32"; }
|
||||
static inline const char* TypeName(BF16) { return "b16"; }
|
||||
static inline const char* TypeName(SfpStream) { return "sfp"; }
|
||||
static inline const char* TypeName(NuqStream) { return "nuq"; }
|
||||
|
||||
// Returns the number of `MatT` elements required to store `capacity` values,
|
||||
// which must not be zero.
|
||||
template <typename MatT>
|
||||
constexpr size_t CompressedArrayElements(size_t capacity) {
|
||||
if constexpr (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
|
||||
return NuqStream::PackedEnd(capacity);
|
||||
} else {
|
||||
return capacity;
|
||||
}
|
||||
}
|
||||
|
||||
// Compressed representation of floating-point elements. The array length may
|
||||
// differ from the number of elements. Associated operations such as Dot are
|
||||
// implemented in SIMD code and are thus non-member functions.
|
||||
|
|
@@ -152,8 +135,8 @@ struct CompressStats {
|
|||
#endif // COMPRESS_STATS
|
||||
|
||||
struct CompressPerThread {
|
||||
NuqStream::ClusterBuf buf;
|
||||
CompressStats stats;
|
||||
ClusterBuf buf;
|
||||
};
|
||||
|
||||
struct CompressWorkingSet {
|
||||
|
|
|
|||
|
|
@@ -12,3 +12,198 @@
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests.
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
|
||||
#endif
|
||||
|
||||
#include "compression/compress.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "compression/distortion.h"
|
||||
#include "util/test_util.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "compression/compress_test.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/test_util-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// Calls Compress and Decompress2 and verifies the distortion/error.
|
||||
template <typename Packed>
|
||||
struct TestDecompress2T {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const size_t N = hn::Lanes(d);
|
||||
CompressWorkingSet work;
|
||||
hwy::ThreadPool pool(0);
|
||||
hwy::RandomState rng;
|
||||
|
||||
const size_t num = 2 * N;
|
||||
const size_t packed_num = CompressedArrayElements<Packed>(num);
|
||||
auto raw = hwy::AllocateAligned<float>(num); // Compress requires f32
|
||||
auto packed = hwy::AllocateAligned<Packed>(packed_num);
|
||||
auto dec = hwy::AllocateAligned<T>(num);
|
||||
HWY_ASSERT(raw && packed && dec);
|
||||
const auto packed_span = MakeSpan(packed.get(), packed_num);
|
||||
|
||||
hwy::Stats in_stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
raw[i] = static_cast<float>(RandomGaussian(rng));
|
||||
in_stats.Notify(raw[i]);
|
||||
}
|
||||
// Short inputs fail VerifyGaussian.
|
||||
|
||||
const size_t packed_ofs = 0;
|
||||
Compress(raw.get(), num, work, packed_span, packed_ofs, pool);
|
||||
hn::Vec<D> raw0, raw1;
|
||||
Decompress2(d, MakeConst(packed_span), packed_ofs, raw0, raw1);
|
||||
hn::Store(raw0, d, dec.get());
|
||||
hn::Store(raw1, d, dec.get() + N);
|
||||
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
stats.Notify(raw[i], hwy::ConvertScalarTo<float>(dec[i]));
|
||||
}
|
||||
|
||||
if constexpr (false) {
|
||||
fprintf(stderr, "%s %s: %zu: %f %f %f %f\n", TypeName<Packed>(),
|
||||
TypeName<T>(), num, stats.SumL1(), stats.GeomeanValueDivL1(),
|
||||
stats.WeightedAverageL1(), stats.L1().Max());
|
||||
}
|
||||
|
||||
constexpr bool kFromFloat = hwy::IsSame<Packed, float>();
|
||||
constexpr bool kToFloat = hwy::IsSame<T, float>();
|
||||
if constexpr (kFromFloat && kToFloat) { // Lossless
|
||||
HWY_ASSERT(stats.NumExact() == num);
|
||||
HWY_ASSERT(stats.SumL1() == 0.0f);
|
||||
HWY_ASSERT(stats.L1().Max() == 0.0f);
|
||||
} else if constexpr (hwy::IsSame<Packed, BF16>() ||
|
||||
(kFromFloat && hwy::IsSame<T, BF16>())) {
|
||||
// Small roundoff error. BF16 to float is not lossless because the
|
||||
// comparison is with float `raw`, prior to the Compress to BF16.
|
||||
HWY_ASSERT(stats.L1().Max() <= 2E-3f);
|
||||
HWY_ASSERT(IsInside(3E-4, 2E-3, stats.WeightedAverageL1()));
|
||||
HWY_ASSERT(IsInside(600.0, 900.0, stats.GeomeanValueDivL1()));
|
||||
} else if constexpr (hwy::IsSame<Packed, SfpStream>()) {
|
||||
HWY_ASSERT(stats.SumL1() <= 0.4f);
|
||||
HWY_ASSERT(stats.L1().Max() <= 0.04f);
|
||||
HWY_ASSERT(IsInside(0.01, 0.03, stats.WeightedAverageL1()));
|
||||
HWY_ASSERT(IsInside(48.0, 72.0, stats.GeomeanValueDivL1()));
|
||||
} else if constexpr (hwy::IsSame<Packed, NuqStream>()) {
|
||||
static_assert(NuqStream::kGroupSize == 256, "Update expected");
|
||||
HWY_ASSERT(stats.SumL1() <= 1.2f);
|
||||
HWY_ASSERT(stats.L1().Max() <= 0.08f);
|
||||
HWY_ASSERT(IsInside(0.02, 0.05, stats.WeightedAverageL1()));
|
||||
HWY_ASSERT(IsInside(18.0, 62.0, stats.GeomeanValueDivL1()));
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllDecompress2() { ForeachPackedAndRawType<TestDecompress2T>(); }
|
||||
|
||||
// Calls Compress and DecompressAndZeroPad for all short lengths and verifies
|
||||
// the distortion/error.
|
||||
template <typename Packed>
|
||||
struct TestShortLengthsT {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const size_t N = hn::Lanes(d);
|
||||
CompressWorkingSet work;
|
||||
hwy::ThreadPool pool(0);
|
||||
hwy::RandomState rng;
|
||||
|
||||
for (size_t num = 1; num < 5 * hn::Lanes(d); ++num) {
|
||||
const size_t packed_num = CompressedArrayElements<Packed>(num);
|
||||
|
||||
auto raw = hwy::AllocateAligned<float>(num); // Compress requires f32
|
||||
auto packed = hwy::AllocateAligned<Packed>(packed_num);
|
||||
auto dec = hwy::AllocateAligned<T>(hwy::RoundUpTo(num, N));
|
||||
HWY_ASSERT(raw && packed && dec);
|
||||
const auto packed_span = MakeSpan(packed.get(), packed_num);
|
||||
|
||||
hwy::Stats in_stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
raw[i] = static_cast<float>(RandomGaussian(rng));
|
||||
in_stats.Notify(raw[i]);
|
||||
}
|
||||
// Short inputs fail VerifyGaussian.
|
||||
|
||||
const size_t packed_ofs = 0;
|
||||
Compress(raw.get(), num, work, packed_span, packed_ofs, pool);
|
||||
DecompressAndZeroPad(d, MakeConst(packed_span), packed_ofs, dec.get(),
|
||||
num);
|
||||
|
||||
DistortionStats stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
stats.Notify(raw[i], hwy::ConvertScalarTo<float>(dec[i]));
|
||||
}
|
||||
|
||||
if constexpr (false) {
|
||||
fprintf(stderr, "%s %s: %zu: %f %f %f %f\n", TypeName<Packed>(),
|
||||
TypeName<T>(), num, stats.SumL1(), stats.GeomeanValueDivL1(),
|
||||
stats.WeightedAverageL1(), stats.L1().Max());
|
||||
}
|
||||
|
||||
constexpr bool kFromFloat = hwy::IsSame<Packed, float>();
|
||||
constexpr bool kToFloat = hwy::IsSame<T, float>();
|
||||
if constexpr (kFromFloat && kToFloat) { // Lossless
|
||||
HWY_ASSERT(stats.NumExact() == num);
|
||||
HWY_ASSERT(stats.SumL1() == 0.0f);
|
||||
HWY_ASSERT(stats.L1().Max() == 0.0f);
|
||||
} else if constexpr (hwy::IsSame<Packed, BF16>() ||
|
||||
(kFromFloat && hwy::IsSame<T, BF16>())) {
|
||||
// Small roundoff error. BF16 to float is not lossless because the
|
||||
// comparison is with float `raw`, prior to the Compress to BF16.
|
||||
HWY_ASSERT(stats.L1().Max() <= 4E-3f);
|
||||
HWY_ASSERT(IsInside(1E-5, 3E-3, stats.WeightedAverageL1()));
|
||||
HWY_ASSERT(IsInside(300.0, 2200.0, stats.GeomeanValueDivL1()));
|
||||
} else if constexpr (hwy::IsSame<Packed, SfpStream>()) {
|
||||
HWY_ASSERT(stats.SumL1() <= 1.3f);
|
||||
HWY_ASSERT(stats.L1().Max() <= 0.08f);
|
||||
HWY_ASSERT(IsInside(7E-5, 0.05, stats.WeightedAverageL1()));
|
||||
HWY_ASSERT(IsInside(28.0, 200.0, stats.GeomeanValueDivL1()));
|
||||
} else if constexpr (hwy::IsSame<Packed, NuqStream>()) {
|
||||
static_assert(NuqStream::kGroupSize == 256, "Update expected");
|
||||
HWY_ASSERT(stats.SumL1() <= 4.6f);
|
||||
HWY_ASSERT(stats.L1().Max() <= 0.14f);
|
||||
HWY_ASSERT(IsInside(7E-5, 0.06, stats.WeightedAverageL1()));
|
||||
HWY_ASSERT(IsInside(11.0, 180.0, stats.GeomeanValueDivL1()));
|
||||
} else {
|
||||
HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType");
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllShortLengths() { ForeachPackedAndRawType<TestShortLengthsT>(); }
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(CompressTest);
|
||||
HWY_EXPORT_AND_TEST_P(CompressTest, TestAllDecompress2);
|
||||
HWY_EXPORT_AND_TEST_P(CompressTest, TestAllShortLengths);
|
||||
HWY_AFTER_TEST();
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -54,21 +54,6 @@ namespace gcpp {
|
|||
constexpr bool kDryRunFread = false;
|
||||
|
||||
namespace {
|
||||
float ScaleWeights(float* data, size_t len) {
|
||||
float maxabs = 0.0;
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
maxabs = std::max(maxabs, std::abs(data[i]));
|
||||
}
|
||||
if (maxabs <= kMaxSFP) {
|
||||
return 1.0f;
|
||||
}
|
||||
const float scale = maxabs / kMaxSFP;
|
||||
const float inv_scale = 1.0f / scale;
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
data[i] *= inv_scale;
|
||||
}
|
||||
return scale;
|
||||
}
|
||||
|
||||
#define READ_WEIGHTS(name) \
|
||||
do { \
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
#include <stdio.h>
|
||||
|
||||
#include "compression/shared.h"
|
||||
#include "compression/shared.h" // SfpStream::kMax
|
||||
#include "util/test_util.h"
|
||||
#include "hwy/nanobenchmark.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
|
@ -75,13 +75,15 @@ TEST(DistortionTest, TestDilution) {
|
|||
HWY_ASSERT(IsNear(0.001, stats.WeightedAverageL1()));
|
||||
|
||||
// Now add a large difference:
|
||||
stats.Notify(kMaxSFP - 0.0625f, kMaxSFP); // max magnitude, 3-bit mantissa
|
||||
stats.Notify(SfpStream::kMax - 0.0625f,
|
||||
SfpStream::kMax); // max magnitude, 3-bit mantissa
|
||||
// .. WeightedAverageL1 is closer to it.
|
||||
HWY_ASSERT(IsInside(0.020, 0.025, stats.WeightedAverageL1()));
|
||||
|
||||
// Add a small and large difference:
|
||||
stats.Notify((1.75f - 0.125f) / 1024, 1.75f / 1024); // small, 2-bit mantissa
|
||||
stats.Notify(-kMaxSFP + 0.0625f, -kMaxSFP); // larger negative
|
||||
stats.Notify(-SfpStream::kMax + 0.0625f,
|
||||
-SfpStream::kMax); // larger negative
|
||||
// .. SNR is still barely affected.
|
||||
HWY_ASSERT(IsInside(890.0, 900.0, stats.GeomeanValueDivL1()));
|
||||
// .. WeightedAverageL1 is higher after another large error.
|
||||
|
|
|
|||
File diff suppressed because it is too large
|
|
@ -1,112 +0,0 @@
|
|||
// Copyright 2023 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_
|
||||
|
||||
// Non-uniform quantization: a compressed representation of f32 inputs that
|
||||
// supports seeking at a granularity of kGroupSize, decoding to bf16/f32, and a
|
||||
// fused decode/dot product with bf16/f32 vectors.
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // HWY_INLINE
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// 4-bit indices are a sweet spot in terms of quality per size.
|
||||
static constexpr size_t kClusters = 16;
|
||||
|
||||
// Number of weights that share a table. Larger = slower encode, higher error,
|
||||
// smaller size (table amortized over more weights). This is the minimum
|
||||
// granularity for seeking/decoding in the stream, and must be at least four
|
||||
// times the number of bf16 elements per vector.
|
||||
static constexpr size_t kGroupSize = 256;
|
||||
|
||||
// Points to the *start* of a NUQ stream. Aligning the allocation (see
|
||||
// aligned_allocator.h) may be speed up decoding but is not required.
|
||||
//
|
||||
// Layout: first one table of kClusters entries per group, in ascending order
|
||||
// of group index, then two packed indices per byte.
|
||||
//
|
||||
// Indices are stored in-order to enable vector-length agnostic decode, because
|
||||
// streams may be persisted to disk and used by other CPUs.
|
||||
//
|
||||
// To enable parallel encoding and decoding, Enc/Dec have `offset` parameters
|
||||
// which refer to the stream, NOT the raw from/to pointers, which point directly
|
||||
// to the source/destination. Offsets are in units of values, NOT compressed
|
||||
// bytes within the stream.
|
||||
#pragma pack(push, 1)
|
||||
struct NuqStream {
|
||||
// Returns offset of packed indices from the start of the stream. This matches
|
||||
// the (padded) total table size because table entries are bytes.
|
||||
static constexpr size_t PackedStart(size_t capacity) {
|
||||
// Round up to avoid cache-line splits when loading indices. No effect on
|
||||
// size as long as capacity / kGroupSize is a multiple of 4.
|
||||
return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64);
|
||||
}
|
||||
|
||||
// Returns number of NuqStream to allocate for the stream, which matches its
|
||||
// size in bytes. `capacity` is already a multiple of `kGroupSize`.
|
||||
static constexpr size_t PackedEnd(size_t capacity) {
|
||||
return PackedStart(capacity) + hwy::DivCeil(capacity, 2); // 2x 4-bit/byte
|
||||
}
|
||||
|
||||
uint8_t byte;
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
// Storage for dynamic programming. There are two matrices; we use separate
|
||||
// allocations to avoid type punning.
|
||||
template <class T>
|
||||
class AlignedMatrix {
|
||||
public:
|
||||
AlignedMatrix() : mem_(hwy::AllocateAligned<T>(kClusters * kGroupSize)) {}
|
||||
|
||||
HWY_INLINE const T& operator()(size_t row, size_t col) const {
|
||||
return mem_[row * kGroupSize + col];
|
||||
}
|
||||
|
||||
HWY_INLINE T& operator()(size_t row, size_t col) {
|
||||
return mem_[row * kGroupSize + col];
|
||||
}
|
||||
|
||||
private:
|
||||
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||
};
|
||||
|
||||
// Reuse memory across calls to Enc to avoid per-call allocations.
|
||||
struct ClusterBuf {
|
||||
void Resize(size_t new_num) {
|
||||
if (new_num < num) return;
|
||||
|
||||
num = new_num;
|
||||
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
|
||||
centers = hwy::AllocateAligned<float>(num_groups * kClusters);
|
||||
idx = hwy::AllocateAligned<uint16_t>(hwy::RoundUpTo(num, kGroupSize));
|
||||
}
|
||||
|
||||
AlignedMatrix<float> d;
|
||||
AlignedMatrix<int32_t> t;
|
||||
|
||||
size_t num = 0;
|
||||
hwy::AlignedFreeUniquePtr<float[]> centers;
|
||||
hwy::AlignedFreeUniquePtr<uint16_t[]> idx;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_
|
||||
|
|
@ -18,8 +18,6 @@
|
|||
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE)
|
||||
#endif
|
||||
|
||||
#include "compression/nuq.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
|
@ -28,9 +26,11 @@
|
|||
#include <random>
|
||||
|
||||
#include "compression/distortion.h"
|
||||
#include "compression/shared.h"
|
||||
#include "util/test_util.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
#include "hwy/tests/test_util.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
|
|
@ -39,10 +39,9 @@
|
|||
#define HWY_TARGET_INCLUDE "compression/nuq_test.cc" // NOLINT
|
||||
// clang-format on
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// Other headers that include Highway must come after foreach_target.h
|
||||
#include "compression/nuq-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
// After highway.h
|
||||
#include "compression/nuq-inl.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
|
|
@ -50,6 +49,8 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
|
||||
static constexpr size_t kTimingReps = hn::AdjustedReps(3);
|
||||
static constexpr size_t kClusters = NuqStream::kClusters;
|
||||
static constexpr size_t kGroupSize = NuqStream::kGroupSize;
|
||||
|
||||
// All-equal inputs: only one cluster
|
||||
struct TestFlat {
|
||||
|
|
@ -65,7 +66,7 @@ struct TestFlat {
|
|||
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||
in[i] = 0.5f;
|
||||
}
|
||||
ClusterBuf buf;
|
||||
NuqStream::ClusterBuf buf;
|
||||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
const size_t unused_clusters = NuqClustering::ClusterExactL2(
|
||||
|
|
@ -107,7 +108,7 @@ struct TestPlateaus {
|
|||
std::mt19937 rng(rd());
|
||||
std::shuffle(in.get(), in.get() + kGroupSize, rng);
|
||||
|
||||
ClusterBuf buf;
|
||||
NuqStream::ClusterBuf buf;
|
||||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
const size_t unused_clusters = NuqClustering::ClusterExactL2(
|
||||
|
|
@ -154,7 +155,7 @@ struct TestRamp {
|
|||
std::mt19937 rng(rd());
|
||||
std::shuffle(in.get(), in.get() + kGroupSize, rng);
|
||||
|
||||
ClusterBuf buf;
|
||||
NuqStream::ClusterBuf buf;
|
||||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
const size_t unused_clusters = NuqClustering::ClusterExactL2(
|
||||
|
|
@ -199,7 +200,7 @@ struct TestNormal {
|
|||
}
|
||||
VerifyGaussian(in_stats);
|
||||
|
||||
ClusterBuf buf;
|
||||
NuqStream::ClusterBuf buf;
|
||||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
|
|
@ -239,7 +240,7 @@ struct TestOffset {
|
|||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const size_t total = 10 * kGroupSize;
|
||||
const size_t total = 10 * kGroupSize; // already padded
|
||||
const size_t kMidLen = 2 * kGroupSize; // length of middle piece
|
||||
|
||||
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
|
||||
|
|
@ -247,6 +248,7 @@ struct TestOffset {
|
|||
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
|
||||
HWY_ASSERT(in && dec1 && dec2 && nuq);
|
||||
const auto nuq_span = MakeSpan(nuq.get(), total);
|
||||
|
||||
hwy::RandomState rng;
|
||||
for (size_t i = 0; i < total; ++i) {
|
||||
|
|
@ -254,53 +256,72 @@ struct TestOffset {
|
|||
}
|
||||
|
||||
// Encode + decode everything
|
||||
ClusterBuf buf;
|
||||
(void)NuqCodec::Enc(df, in.get(), total, buf, total, nuq.get(), 0);
|
||||
NuqCodec::Dec(d, total, nuq.get(), 0, dec1.get(), total);
|
||||
NuqStream::ClusterBuf buf;
|
||||
(void)NuqCodec::Enc(df, in.get(), total, buf, nuq_span, 0);
|
||||
NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec1.get(),
|
||||
total);
|
||||
|
||||
// Overwrite middle with first inputs
|
||||
const size_t offset = 5 * kGroupSize;
|
||||
(void)NuqCodec::Enc(df, in.get(), kMidLen, buf, total, nuq.get(), offset);
|
||||
(void)NuqCodec::Enc(df, in.get(), kMidLen, buf, nuq_span, offset);
|
||||
|
||||
// Decoded middle now matches previously decoded first
|
||||
NuqCodec::Dec(d, total, nuq.get(), offset, dec2.get(), kMidLen);
|
||||
NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), offset, dec2.get(),
|
||||
kMidLen);
|
||||
for (size_t i = 0; i < kMidLen; ++i) {
|
||||
HWY_ASSERT(dec1[i] == dec2[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllOffsetF32() {
|
||||
const hn::ForGEVectors<128, TestOffset> test;
|
||||
test(float());
|
||||
}
|
||||
|
||||
void TestAllOffsetBF16() {
|
||||
const hn::ForGEVectors<128, TestOffset> test;
|
||||
test(hwy::bfloat16_t());
|
||||
}
|
||||
void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); }
|
||||
void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); }
|
||||
|
||||
struct TestNibble {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<uint8_t, D> d8;
|
||||
const hn::Half<decltype(d8)> d8h;
|
||||
using V = hn::Vec<decltype(d)>;
|
||||
const size_t N = hn::Lanes(d);
|
||||
const size_t num = 4 * N;
|
||||
auto bytes = hwy::AllocateAligned<uint8_t>(num / 2);
|
||||
HWY_ASSERT(bytes);
|
||||
const V v0 = hn::And(hn::Iota(d, 0), hn::Set(d, 15));
|
||||
const V v1 = hn::Set(d, 1);
|
||||
const V v2 = hn::OddEven(v1, hn::Zero(d));
|
||||
const V v3 = hn::Reverse(d, v0);
|
||||
NibbleCodec::OrderedPackU16(d, v0, v1, v2, v3, bytes.get());
|
||||
const V out0 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 0 * N / 2);
|
||||
const V out1 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 1 * N / 2);
|
||||
const V out2 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 2 * N / 2);
|
||||
const V out3 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 3 * N / 2);
|
||||
HWY_ASSERT_VEC_EQ(d, v0, out0);
|
||||
HWY_ASSERT_VEC_EQ(d, v1, out1);
|
||||
HWY_ASSERT_VEC_EQ(d, v2, out2);
|
||||
HWY_ASSERT_VEC_EQ(d, v3, out3);
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using V8H = hn::Vec<decltype(d8h)>;
|
||||
const V mask = hn::Set(d, 15);
|
||||
|
||||
{
|
||||
const V v0 = hn::And(hn::Iota(d, 0), mask);
|
||||
const V v1 = hn::Set(d, 1);
|
||||
const V v2 = hn::OddEven(v1, hn::Zero(d));
|
||||
const V v3 = hn::Reverse(d, v0);
|
||||
const V8 nibbles = NibbleCodec::OrderedPackU16(d, v0, v1, v2, v3);
|
||||
const V8H nibbles0 = hn::LowerHalf(d8h, nibbles);
|
||||
const V8H nibbles1 = hn::UpperHalf(d8h, nibbles);
|
||||
const V out0 = NibbleCodec::OrderedUnpackU16<0>(d, nibbles0);
|
||||
const V out1 = NibbleCodec::OrderedUnpackU16<1>(d, nibbles0);
|
||||
const V out2 = NibbleCodec::OrderedUnpackU16<0>(d, nibbles1);
|
||||
const V out3 = NibbleCodec::OrderedUnpackU16<1>(d, nibbles1);
|
||||
HWY_ASSERT_VEC_EQ(d, v0, out0);
|
||||
HWY_ASSERT_VEC_EQ(d, v1, out1);
|
||||
HWY_ASSERT_VEC_EQ(d, v2, out2);
|
||||
HWY_ASSERT_VEC_EQ(d, v3, out3);
|
||||
}
|
||||
// Same, but with different values in each lane.
|
||||
{
|
||||
const V v0 = hn::And(hn::Iota(d, 0), mask);
|
||||
const V v1 = hn::And(hn::Iota(d, 1), mask);
|
||||
const V v2 = hn::And(hn::Iota(d, 2), mask);
|
||||
const V v3 = hn::And(hn::Iota(d, 3), mask);
|
||||
const V8 nibbles = NibbleCodec::OrderedPackU16(d, v0, v1, v2, v3);
|
||||
const V8H nibbles0 = hn::LowerHalf(d8h, nibbles);
|
||||
const V8H nibbles1 = hn::UpperHalf(d8h, nibbles);
|
||||
const V out0 = NibbleCodec::OrderedUnpackU16<0>(d, nibbles0);
|
||||
const V out1 = NibbleCodec::OrderedUnpackU16<1>(d, nibbles0);
|
||||
const V out2 = NibbleCodec::OrderedUnpackU16<0>(d, nibbles1);
|
||||
const V out3 = NibbleCodec::OrderedUnpackU16<1>(d, nibbles1);
|
||||
HWY_ASSERT_VEC_EQ(d, v0, out0);
|
||||
HWY_ASSERT_VEC_EQ(d, v1, out1);
|
||||
HWY_ASSERT_VEC_EQ(d, v2, out2);
|
||||
HWY_ASSERT_VEC_EQ(d, v3, out3);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -309,15 +330,19 @@ void TestAllNibble() {
|
|||
test(uint16_t());
|
||||
}
|
||||
|
||||
struct TestStream {
|
||||
// Checks the distortion from an encode and decode round trip. Unlike
|
||||
// `TestShortLengthsT` in compress_test, this covers large `num` and
|
||||
// prints the enc/dec throughput.
|
||||
struct TestEncDec {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const size_t num = 4 * kGroupSize;
|
||||
auto in = hwy::AllocateAligned<float>(num); // Enc() requires f32
|
||||
auto out = hwy::AllocateAligned<T>(num);
|
||||
auto out = hwy::AllocateAligned<T>(num); // already padded
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
|
||||
HWY_ASSERT(in && out && nuq);
|
||||
const auto nuq_span = MakeSpan(nuq.get(), num);
|
||||
|
||||
hwy::RandomState rng;
|
||||
hwy::Stats in_stats;
|
||||
|
|
@ -327,12 +352,12 @@ struct TestStream {
|
|||
}
|
||||
VerifyGaussian(in_stats);
|
||||
|
||||
ClusterBuf buf;
|
||||
NuqStream::ClusterBuf buf;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
const size_t unused_clusters =
|
||||
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
||||
NuqCodec::Enc(df, in.get(), num, buf, nuq_span, 0);
|
||||
HWY_ASSERT(unused_clusters == 0);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed = HWY_MIN(elapsed, t1 - t0);
|
||||
|
|
@ -343,7 +368,7 @@ struct TestStream {
|
|||
elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
|
||||
NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, out.get(), num);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed = HWY_MIN(elapsed, t1 - t0);
|
||||
}
|
||||
|
|
@ -367,129 +392,8 @@ struct TestStream {
|
|||
}
|
||||
};
|
||||
|
||||
void TestAllStreamF32() {
|
||||
const hn::ForGEVectors<128, TestStream> test;
|
||||
test(float());
|
||||
}
|
||||
|
||||
void TestAllStreamBF16() {
|
||||
const hn::ForGEVectors<128, TestStream> test;
|
||||
test(hwy::bfloat16_t());
|
||||
}
|
||||
|
||||
struct TestDot {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const size_t num = 4 * kGroupSize;
|
||||
auto in = hwy::AllocateAligned<float>(num);
|
||||
auto dec = hwy::AllocateAligned<float>(num);
|
||||
auto vec = hwy::AllocateAligned<T>(num);
|
||||
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
|
||||
HWY_ASSERT(in && dec && vec && nuq);
|
||||
|
||||
// Generate inputs and verify their distribution.
|
||||
hwy::RandomState rng;
|
||||
hwy::Stats in_stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||
in_stats.Notify(in[i]);
|
||||
}
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
const float r = static_cast<float>(RandomGaussian(rng));
|
||||
in_stats.Notify(r);
|
||||
vec[i] = hwy::ConvertScalarTo<T>(r);
|
||||
}
|
||||
VerifyGaussian(in_stats);
|
||||
|
||||
ClusterBuf buf;
|
||||
const size_t unused_clusters =
|
||||
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
||||
HWY_ASSERT(unused_clusters == 0);
|
||||
|
||||
// Compute dot product without decompression.
|
||||
float actual = 0.0f;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < kTimingReps; ++rep) {
|
||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum3 = hn::Zero(df);
|
||||
const double t0 = hwy::platform::Now();
|
||||
NuqCodec::Dot(df, num, nuq.get(), 0, vec.get(), num, sum0, sum1, sum2,
|
||||
sum3);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed = HWY_MIN(elapsed, t1 - t0);
|
||||
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
|
||||
actual = hn::ReduceSum(df, sum0);
|
||||
}
|
||||
|
||||
NuqCodec::Dec(df, num, nuq.get(), 0, dec.get(), num);
|
||||
fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T),
|
||||
num * sizeof(in[0]) * 1E-6 / elapsed);
|
||||
|
||||
// Exact and decompressed dot products for comparison.
|
||||
float exact = 0.0f; // using original input
|
||||
float expected = 0.0f; // using decoded NUQ
|
||||
DistortionStats dec_stats;
|
||||
hwy::Stats ratios;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
dec_stats.Notify(in[i], dec[i]);
|
||||
const float v1 = hwy::ConvertScalarTo<float>(vec[i]);
|
||||
exact += in[i] * v1;
|
||||
expected += dec[i] * v1;
|
||||
if (expected != 0.0f) {
|
||||
ratios.Notify(exact / expected);
|
||||
}
|
||||
}
|
||||
const bool isBF = sizeof(T) == 2;
|
||||
const double dec_snr = dec_stats.GeomeanValueDivL1();
|
||||
const double dec_wl1 = dec_stats.WeightedAverageL1();
|
||||
const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
|
||||
// exact and actual fluctuate due to the combination of NUQ imprecision,
|
||||
// and whether vec[i] is negative or positive, so this is quite loose.
|
||||
const float final_ratio = HWY_MIN(exact / actual, actual / exact);
|
||||
if (HWY_ONCE) {
|
||||
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
|
||||
fprintf(stderr,
|
||||
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
|
||||
"dot_snr %.2f dec_wl1 %.4f\n",
|
||||
exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1);
|
||||
}
|
||||
// Final values are not too far apart.
|
||||
HWY_ASSERT(gcpp::IsInside(0.88f, 1.0f, final_ratio));
|
||||
// Decompressed and uncompressed dot should match exactly.
|
||||
HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
|
||||
// Geomean of ratios for each i should be very close to one.
|
||||
HWY_ASSERT(dot_snr >= (isBF ? 17.7 : 14.3));
|
||||
|
||||
// dec[] is close to in[], but we already check that in TestStream with the
|
||||
// same input distribution.
|
||||
HWY_ASSERT(gcpp::IsNear(13.1, dec_snr, 0.1));
|
||||
HWY_ASSERT(gcpp::IsNear(0.034, dec_wl1, 0.001));
|
||||
HWY_ASSERT(gcpp::IsNear(23.5, dec_stats.SumL1(), 0.1));
|
||||
HWY_ASSERT(dec_stats.NumSignFlip() < num / kClusters);
|
||||
HWY_ASSERT_EQ(0, dec_stats.NumExact());
|
||||
HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero());
|
||||
HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded());
|
||||
// Absolute decode errors are in [0, 0.11], and somewhat right-tailed.
|
||||
HWY_ASSERT(gcpp::IsInside(0.0f, 2E-5f, dec_stats.L1().Min()));
|
||||
HWY_ASSERT(gcpp::IsInside(0.09f, 0.11f, dec_stats.L1().Max()));
|
||||
HWY_ASSERT(gcpp::IsInside(0.02, 0.03, dec_stats.L1().Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(1.0, 1.1, dec_stats.L1().Skewness()));
|
||||
HWY_ASSERT(gcpp::IsInside(4.0, 5.0, dec_stats.L1().Kurtosis()));
|
||||
static_assert(kGroupSize == 256, "Update expected*");
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllDotF32() {
|
||||
const hn::ForGEVectors<128, TestDot> test;
|
||||
test(float());
|
||||
}
|
||||
void TestAllDotBF16() {
|
||||
const hn::ForGEVectors<128, TestDot> test;
|
||||
test(hwy::bfloat16_t());
|
||||
}
|
||||
void TestEncDecBF16() { hn::ForGEVectors<128, TestEncDec>()(BF16()); }
|
||||
void TestEncDecF32() { hn::ForGEVectors<128, TestEncDec>()(float()); }
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@ -497,23 +401,17 @@ void TestAllDotBF16() {
|
|||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(NuqTest);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetBF16);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetF32);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotBF16);
|
||||
#ifdef HWY_AFTER_TEST
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);
|
||||
HWY_AFTER_TEST();
|
||||
#endif
|
||||
} // namespace gcpp
|
||||
|
||||
#endif
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -52,9 +52,6 @@ HWY_INLINE hn::Mask<DU> SignedLt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
|
|||
return SignedGt(du, b, a);
|
||||
}
|
||||
|
||||
// Saturated subtraction; returns 0 if the result would be negative.
|
||||
static inline size_t SubOr0(size_t a, size_t b) { return a > b ? a - b : 0; }
|
||||
|
||||
// Encode/decode functions.
|
||||
class SfpCodec {
|
||||
public:
|
||||
|
|
@ -260,9 +257,9 @@ class SfpCodec {
|
|||
}
|
||||
|
||||
// Encodes `num` bf16 values from `in_bf` to `out_packed`. Their magnitude
|
||||
// must be at most 1.875.
|
||||
// must be at most SfpStream::kMax.
|
||||
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||
static HWY_INLINE void Enc(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in_bf,
|
||||
static HWY_INLINE void Enc(DBF dbf, const BF16* HWY_RESTRICT in_bf,
|
||||
size_t num, SfpStream* HWY_RESTRICT out_packed) {
|
||||
const hn::Repartition<uint8_t, DBF> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
|
|
@ -280,7 +277,7 @@ class SfpCodec {
|
|||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * N16);
|
||||
if (remaining != 0) {
|
||||
HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
|
||||
HWY_ALIGN BF16 padded[2 * hn::MaxLanes(dbf)];
|
||||
hwy::ZeroBytes(padded, sizeof(padded));
|
||||
hwy::CopyBytes(in_bf + i, padded, remaining * sizeof(padded[0]));
|
||||
const V8 packed = Enc2B(dbf, padded);
|
||||
|
|
@ -289,7 +286,7 @@ class SfpCodec {
|
|||
}
|
||||
|
||||
// Encodes `num` f32 values from `in_f` to `packed`. Their magnitude
|
||||
// must be at most 1.875.
|
||||
// must be at most SfpStream::kMax.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT in_f, size_t num,
|
||||
SfpStream* HWY_RESTRICT out_packed) {
|
||||
|
|
@ -317,148 +314,112 @@ class SfpCodec {
|
|||
}
|
||||
}
|
||||
|
||||
// Decodes `num` values from `in_packed` to `out_bf`.
|
||||
template <class DBF16, HWY_IF_BF16_D(DBF16),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF16>>>
|
||||
static HWY_INLINE void Dec2(DBF16 dbf16, V8 packed, hn::Vec<DBF16>& raw0,
|
||||
hn::Vec<DBF16>& raw1) {
|
||||
Dec2B(dbf16, packed, raw0, raw1);
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF),
|
||||
class V8 = hn::Vec<hn::Twice<hn::Rebind<uint8_t, DF>>>>
|
||||
static HWY_INLINE void Dec2(DF df, V8 packed, hn::Vec<DF>& raw0,
|
||||
hn::Vec<DF>& raw1) {
|
||||
const hn::Rebind<BF16, DF> dbf; // half-vector
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
raw0 = hn::PromoteTo(df, bf0);
|
||||
raw1 = hn::PromoteTo(df, bf1);
|
||||
}
|
||||
|
||||
// Decompresses to (arbitrary) `num` BF16 elements in `raw_bf`, then appends
|
||||
// `[0, hn::Lanes(dbf))` zeroes as required to round `num` up to one vector,
|
||||
// if it is not already. DBF argument is provided by nuq-inl.h.
|
||||
template <class DBF, HWY_IF_BF16_D(DBF)>
|
||||
static HWY_INLINE void Dec(DBF dbf, const SfpStream* HWY_RESTRICT in_packed,
|
||||
size_t num, hwy::bfloat16_t* HWY_RESTRICT out_bf) {
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
DBF dbf, const PackedSpan<const SfpStream>& packed, size_t packed_ofs,
|
||||
BF16* HWY_RESTRICT raw_bf, size_t num) {
|
||||
const hn::Repartition<uint8_t, DBF> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
const size_t N16 = hn::Lanes(dbf);
|
||||
|
||||
const uint8_t* HWY_RESTRICT base = &packed.ptr->byte + packed_ofs;
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 2 * N16) {
|
||||
HWY_UNROLL(1)
|
||||
for (; i <= num - 2 * N16; i += 2 * N16) {
|
||||
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
|
||||
const V8 packed = hn::LoadU(d8, base + i);
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
hn::StoreU(bf0, dbf, out_bf + i);
|
||||
hn::StoreU(bf1, dbf, out_bf + i + N16);
|
||||
hn::StoreU(bf0, dbf, raw_bf + i);
|
||||
hn::StoreU(bf1, dbf, raw_bf + i + N16);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * N16);
|
||||
if (remaining != 0) {
|
||||
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
|
||||
const V8 packed = hn::LoadN(d8, base + i, remaining);
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
hn::StoreN(bf0, dbf, out_bf + i, remaining);
|
||||
hn::StoreN(bf1, dbf, out_bf + i + N16, SubOr0(remaining, N16));
|
||||
// If at most one vector, the first store adds zero padding. Check before
|
||||
// storing the second, because callers only pad to one vector.
|
||||
hn::StoreU(bf0, dbf, raw_bf + i);
|
||||
if (remaining > N16) hn::StoreU(bf1, dbf, raw_bf + i + N16);
|
||||
}
|
||||
}
|
||||
|
||||
// Decodes `num` values from `in_packed` to `out_f`.
|
||||
// Decompresses to (arbitrary) `num` float elements in `raw_f`, then appends
|
||||
// `[0, hn::Lanes(df))` zeroes as required to round `num` up to one vector,
|
||||
// if it is not already.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void Dec(DF df, const SfpStream* HWY_RESTRICT in_packed,
|
||||
size_t num, float* HWY_RESTRICT out_f) {
|
||||
static HWY_INLINE void DecompressAndZeroPad(
|
||||
DF df, const PackedSpan<const SfpStream>& packed, size_t packed_ofs,
|
||||
float* HWY_RESTRICT raw_f, size_t num) {
|
||||
const hn::Repartition<uint8_t, DF> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
||||
const uint8_t* HWY_RESTRICT base = &packed.ptr->byte + packed_ofs;
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 4 * NF) {
|
||||
HWY_UNROLL(1)
|
||||
for (; i <= num - 4 * NF; i += 4 * NF) {
|
||||
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
|
||||
const V8 packed = hn::LoadU(d8, base + i);
|
||||
VF f0, f1, f2, f3;
|
||||
Dec4F(df, packed, f0, f1, f2, f3);
|
||||
hn::StoreU(f0, df, out_f + i + NF * 0);
|
||||
hn::StoreU(f1, df, out_f + i + NF * 1);
|
||||
hn::StoreU(f2, df, out_f + i + NF * 2);
|
||||
hn::StoreU(f3, df, out_f + i + NF * 3);
|
||||
hn::StoreU(f0, df, raw_f + i + NF * 0);
|
||||
hn::StoreU(f1, df, raw_f + i + NF * 1);
|
||||
hn::StoreU(f2, df, raw_f + i + NF * 2);
|
||||
hn::StoreU(f3, df, raw_f + i + NF * 3);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 4 * NF);
|
||||
if (remaining != 0) {
|
||||
const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
const V8 packed = hn::LoadN(d8, base + i, remaining);
|
||||
VF f0, f1, f2, f3;
|
||||
Dec4F(df, packed, f0, f1, f2, f3);
|
||||
hn::StoreN(f0, df, out_f + i + 0 * NF, remaining);
|
||||
hn::StoreN(f1, df, out_f + i + 1 * NF, SubOr0(remaining, 1 * NF));
|
||||
hn::StoreN(f2, df, out_f + i + 2 * NF, SubOr0(remaining, 2 * NF));
|
||||
hn::StoreN(f3, df, out_f + i + 3 * NF, SubOr0(remaining, 3 * NF));
|
||||
// We are only guaranteed one vector of padding, so cannot unconditionally
|
||||
// store four vectors. `StoreN` would work, at the cost of saturated
|
||||
// subtraction and creating masks. Because we know that `raw_f` is padded
|
||||
// to at least one vector, we can instead store entire vectors and only
|
||||
// make the address conditional, which potentially avoids branches.
|
||||
// Separate per-vector storage may avoid conflicts.
|
||||
HWY_ALIGN float buf[4 * hn::MaxLanes(df)];
|
||||
hn::StoreU(f0, df, raw_f + i);
|
||||
hn::StoreU(f1, df, (remaining > 1 * NF ? (raw_f + i) : buf) + 1 * NF);
|
||||
hn::StoreU(f2, df, (remaining > 2 * NF ? (raw_f + i) : buf) + 2 * NF);
|
||||
hn::StoreU(f3, df, (remaining > 3 * NF ? (raw_f + i) : buf) + 3 * NF);
|
||||
}
|
||||
}
|
||||
|
||||
// Fused decode and dot product with even-odd bf16 into four f32 accumulators.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
|
||||
size_t num,
|
||||
const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
|
||||
hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
|
||||
hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
|
||||
const hn::Repartition<uint8_t, DF> d8;
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
const size_t N16 = hn::Lanes(dbf);
|
||||
HWY_DASSERT(num % (2 * N16) == 0); // whole SFP vector -> 2x bf16
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t i = 0; i < num; i += 2 * N16) {
|
||||
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
|
||||
const VBF ve = hn::LoadU(dbf, vec_aligned + i);
|
||||
const VBF vo = hn::LoadU(dbf, vec_aligned + i + N16);
|
||||
VBF be, bo;
|
||||
DecEvenOdd(dbf, packed, be, bo);
|
||||
sum0 = hn::ReorderWidenMulAccumulate(df, be, ve, sum0, sum1);
|
||||
sum2 = hn::ReorderWidenMulAccumulate(df, bo, vo, sum2, sum3);
|
||||
}
|
||||
}
|
||||
|
||||
// Fused decode and dot product with even-odd f32 into four f32 accumulators.
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
|
||||
size_t num,
|
||||
const float* HWY_RESTRICT vec_aligned,
|
||||
hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
|
||||
hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
|
||||
const hn::Repartition<uint8_t, DF> d8;
|
||||
using V8 = hn::Vec<decltype(d8)>;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
HWY_DASSERT(num % (4 * NF) == 0); // whole SFP vector -> 4x f32
|
||||
|
||||
HWY_UNROLL(1)
|
||||
for (size_t i = 0; i < num; i += 4 * NF) {
|
||||
const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
|
||||
const VF ve0 = hn::LoadU(df, vec_aligned + i + NF * 0);
|
||||
const VF vo0 = hn::LoadU(df, vec_aligned + i + NF * 1);
|
||||
const VF ve1 = hn::LoadU(df, vec_aligned + i + NF * 2);
|
||||
const VF vo1 = hn::LoadU(df, vec_aligned + i + NF * 3);
|
||||
VF fe0, fo0, fe1, fo1;
|
||||
DecEvenOddF(df, packed, fe0, fo0, fe1, fo1);
|
||||
sum0 = hn::MulAdd(fe0, ve0, sum0);
|
||||
sum1 = hn::MulAdd(fo0, vo0, sum1);
|
||||
sum2 = hn::MulAdd(fe1, ve1, sum2);
|
||||
sum3 = hn::MulAdd(fo1, vo1, sum3);
|
||||
}
|
||||
}
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF),
|
||||
class V8 = hn::Vec<hn::Twice<hn::Rebind<uint8_t, DF>>>>
|
||||
static HWY_INLINE void Dec2(DF df, V8 packed, hn::Vec<DF>& f0,
|
||||
hn::Vec<DF>& f1) {
|
||||
const hn::Rebind<hwy::bfloat16_t, DF> dbf;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
f0 = hn::PromoteTo(df, bf0);
|
||||
f1 = hn::PromoteTo(df, bf1);
|
||||
}
|
||||
|
||||
template <class DBF16, HWY_IF_BF16_D(DBF16),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF16>>>
|
||||
static HWY_INLINE void Dec2(DBF16 dbf16, V8 packed, hn::Vec<DBF16>& bf0,
|
||||
hn::Vec<DBF16>& bf1) {
|
||||
Dec2B(dbf16, packed, bf0, bf1);
|
||||
}
|
||||
|
||||
private:
|
||||
// Wrappers to avoid code duplication across float/bf16 input types and
|
||||
// the main loop/remainder.
|
||||
|
|
@ -479,7 +440,7 @@ class SfpCodec {
|
|||
|
||||
template <class DBF, HWY_IF_BF16_D(DBF),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
|
||||
static HWY_INLINE V8 Enc2B(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in) {
|
||||
static HWY_INLINE V8 Enc2B(DBF dbf, const BF16* HWY_RESTRICT in) {
|
||||
const hn::Repartition<uint16_t, DBF> d16;
|
||||
const size_t N16 = hn::Lanes(d16);
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
|
|
@ -505,7 +466,7 @@ class SfpCodec {
|
|||
class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
|
||||
static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) {
|
||||
const hn::Repartition<uint16_t, DF> d16;
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
const hn::Repartition<BF16, DF> dbf;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
using V16 = hn::Vec<decltype(d16)>;
|
||||
const size_t NF = hn::Lanes(df);
|
||||
|
|
@ -549,7 +510,7 @@ class SfpCodec {
|
|||
static HWY_INLINE void Dec4F(DF df, V8 packed, hn::Vec<DF>& f0,
|
||||
hn::Vec<DF>& f1, hn::Vec<DF>& f2,
|
||||
hn::Vec<DF>& f3) {
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
const hn::Repartition<BF16, DF> dbf;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
VBF bf0, bf1;
|
||||
Dec2B(dbf, packed, bf0, bf1);
|
||||
|
|
@ -559,6 +520,7 @@ class SfpCodec {
|
|||
f3 = hn::PromoteUpperTo(df, bf1);
|
||||
}
|
||||
|
||||
// TODO: currently unused, but keep for potential later MatMul packing.
|
||||
template <class DBF, HWY_IF_BF16_D(DBF),
|
||||
class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
|
||||
static HWY_INLINE void DecEvenOdd(DBF dbf, V8 packed, hn::Vec<DBF>& even,
|
||||
|
|
@ -576,7 +538,7 @@ class SfpCodec {
|
|||
static HWY_INLINE void DecEvenOddF(DF df, V8 packed, hn::Vec<DF>& even0,
|
||||
hn::Vec<DF>& odd0, hn::Vec<DF>& even1,
|
||||
hn::Vec<DF>& odd1) {
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
|
||||
const hn::Repartition<BF16, DF> dbf;
|
||||
using VBF = hn::Vec<decltype(dbf)>;
|
||||
VBF even_bf, odd_bf;
|
||||
DecEvenOdd(dbf, packed, even_bf, odd_bf);
|
||||
|
|
|
|||
|
|
@ -39,7 +39,6 @@
|
|||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/sfp-inl.h"
|
||||
#include "ops/dot-inl.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
|
|
@ -128,7 +127,7 @@ void TestAllFastDecode() {
|
|||
|
||||
// Encode
|
||||
HWY_INLINE uint32_t SFP8FromF32(float f) {
|
||||
HWY_ASSERT(-1.875f <= f && f <= 1.875f);
|
||||
HWY_ASSERT(-SfpStream::kMax <= f && f <= SfpStream::kMax);
|
||||
|
||||
constexpr uint32_t kMaskM = hwy::MantissaMask<float>();
|
||||
uint32_t binary32;
|
||||
|
|
@ -182,7 +181,7 @@ struct TestDecEnc {
|
|||
template <class T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::RepartitionToWide<D> d16;
|
||||
const hn::Rebind<hwy::bfloat16_t, decltype(d16)> dbf;
|
||||
const hn::Rebind<BF16, decltype(d16)> dbf;
|
||||
const hn::Repartition<float, D> df;
|
||||
for (uint32_t encoded = 0; encoded < 256; ++encoded) {
|
||||
if (encoded == 0x80) continue; // -0 is reserved
|
||||
|
|
@ -215,7 +214,7 @@ struct TestGolden {
|
|||
template <class T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const hn::Repartition<hwy::bfloat16_t, D> dbf;
|
||||
const hn::Repartition<BF16, D> dbf;
|
||||
const hn::RebindToUnsigned<decltype(dbf)> d16;
|
||||
|
||||
struct Golden {
|
||||
|
|
@ -294,9 +293,53 @@ void TestAllGolden() {
|
|||
TestGolden()(uint8_t(), hn::ScalableTag<uint8_t>());
|
||||
}
|
||||
|
||||
// ------------------------------ Order
|
||||
|
||||
// Store 8-bit iota, decode, encode, check iota == packed. This ensures
|
||||
// Enc/Dec are preserving the order independent of vector length.
|
||||
struct TestOrder {
|
||||
template <class T, class DBF>
|
||||
HWY_INLINE void operator()(T /*unused*/, DBF dbf) {
|
||||
const size_t N16 = hn::Lanes(dbf);
|
||||
|
||||
for (size_t num = 1; num < 6 * N16; ++num) {
|
||||
const size_t padded = hwy::RoundUpTo(num, N16);
|
||||
|
||||
auto iota = hwy::AllocateAligned<SfpStream>(num);
|
||||
auto packed = hwy::AllocateAligned<SfpStream>(num);
|
||||
auto bf = hwy::AllocateAligned<BF16>(padded);
|
||||
HWY_ASSERT(iota && packed && bf);
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
// Clear sign bit so we can also check that bf is in ascending order.
|
||||
iota[i].byte = i & 127;
|
||||
}
|
||||
|
||||
SfpCodec::DecompressAndZeroPad(dbf, MakeConstSpan(iota.get(), num), 0,
|
||||
bf.get(), num);
|
||||
for (size_t i = num; i < padded; ++i) {
|
||||
if (hwy::ConvertScalarTo<float>(bf[i]) != 0.0f) {
|
||||
HWY_ABORT("num %zu padded %zu i %zu: not padded", num, padded, i);
|
||||
}
|
||||
}
|
||||
|
||||
SfpCodec::Enc(dbf, bf.get(), num, packed.get());
|
||||
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
if (iota[i].byte != packed[i].byte) {
|
||||
HWY_ABORT("@%zu: %d %d\n", i, iota[i].byte, packed[i].byte);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllOrder() { hn::ForGEVectors<32, TestOrder>()(BF16()); }
|
||||
|
||||
// ------------------------------ Foreach bf16 input
|
||||
|
||||
// Generate all values, encode, decode back.
|
||||
// Checks the distortion from an encode and decode round trip. Unlike
|
||||
// `TestShortLengthsT` in compress_test, this covers large `num` and
|
||||
// prints the enc/dec throughput.
|
||||
struct TestEncDec {
|
||||
template <class T, class DBF>
|
||||
HWY_INLINE void operator()(T /*unused*/, DBF dbf) {
|
||||
|
|
@ -309,14 +352,14 @@ struct TestEncDec {
|
|||
|
||||
auto in = hwy::AllocateAligned<T>(max);
|
||||
auto packed = hwy::AllocateAligned<SfpStream>(max);
|
||||
auto dec = hwy::AllocateAligned<T>(max);
|
||||
auto dec = hwy::AllocateAligned<T>(max); // already padded
|
||||
HWY_ASSERT(in && packed && dec);
|
||||
size_t num = 0;
|
||||
for (size_t i = 0; i < max; ++i) {
|
||||
const uint16_t bits = i * kStep;
|
||||
const float f = hwy::F32FromBF16(hwy::BitCastScalar<T>(bits));
|
||||
// Keep if within range
|
||||
if (hwy::ScalarIsFinite(f) && f <= 1.875f) {
|
||||
if (hwy::ScalarIsFinite(f) && f <= SfpStream::kMax) {
|
||||
in[num] = hwy::BF16FromF32(f);
|
||||
in[num + 1] = hwy::BF16FromF32(-f);
|
||||
num += 2;
|
||||
|
|
@ -329,7 +372,8 @@ struct TestEncDec {
|
|||
const double t0 = hwy::platform::Now();
|
||||
SfpCodec::Enc(dbf, in.get(), num, packed.get());
|
||||
const double t1 = hwy::platform::Now();
|
||||
SfpCodec::Dec(dbf, packed.get(), num, dec.get());
|
||||
SfpCodec::DecompressAndZeroPad(dbf, MakeConstSpan(packed.get(), num), 0,
|
||||
dec.get(), num);
|
||||
const double t2 = hwy::platform::Now();
|
||||
enc_elapsed = HWY_MIN(enc_elapsed, t1 - t0);
|
||||
dec_elapsed = HWY_MIN(dec_elapsed, t2 - t1);
|
||||
|
|
@ -358,9 +402,10 @@ struct TestEncDec {
|
|||
stats.SumL1Rounded(), snr, wl1);
|
||||
}
|
||||
HWY_ASSERT(stats.Original().Count() == stats.L1().Count());
|
||||
// Inputs are in [-1.875, 1.875], symmetric, and heavy-tailed.
|
||||
HWY_ASSERT(stats.Original().Min() == -1.875f);
|
||||
HWY_ASSERT(stats.Original().Max() == 1.875f);
|
||||
// Inputs are in [-SfpStream::kMax, SfpStream::kMax], symmetric, and
|
||||
// heavy-tailed.
|
||||
HWY_ASSERT(stats.Original().Min() == -SfpStream::kMax);
|
||||
HWY_ASSERT(stats.Original().Max() == SfpStream::kMax);
|
||||
HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Skewness()));
|
||||
HWY_ASSERT(gcpp::IsInside(80.0, 100.0, stats.Original().Kurtosis()));
|
||||
|
|
@ -382,179 +427,7 @@ struct TestEncDec {
|
|||
}
|
||||
};
|
||||
|
||||
void TestAllEncDec() { hn::ForGEVectors<32, TestEncDec>()(hwy::bfloat16_t()); }
|
||||
|
||||
// ------------------------------ Order
|
||||
|
||||
// Store 8-bit iota, decode, encode, check iota == packed. This ensures
|
||||
// Enc/Dec are preserving the order independent of vector length.
|
||||
struct TestOrder {
|
||||
template <class T, class DBF>
|
||||
HWY_INLINE void operator()(T /*unused*/, DBF dbf) {
|
||||
const hn::Repartition<uint8_t, DBF> du8;
|
||||
|
||||
const size_t num = 10 * hn::Lanes(du8) / 3;
|
||||
|
||||
auto iota = hwy::AllocateAligned<SfpStream>(num);
|
||||
auto packed = hwy::AllocateAligned<SfpStream>(num);
|
||||
auto bf = hwy::AllocateAligned<hwy::bfloat16_t>(num);
|
||||
HWY_ASSERT(iota && packed && bf);
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
// Clear sign bit so we can also check that bf is in ascending order.
|
||||
iota[i].byte = i & 127;
|
||||
}
|
||||
|
||||
SfpCodec::Dec(dbf, iota.get(), num, bf.get());
|
||||
SfpCodec::Enc(dbf, bf.get(), num, packed.get());
|
||||
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
if (iota[i].byte != packed[i].byte) {
|
||||
HWY_ABORT("@%zu: %d %d\n", i, iota[i].byte, packed[i].byte);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllOrder() { hn::ForGEVectors<32, TestOrder>()(hwy::bfloat16_t()); }
|
||||
|
||||
// ------------------------------ Dot
|
||||
|
||||
struct TestDot {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
const hn::Repartition<float, D> df;
|
||||
const size_t num = 1024; // not too many for GeometricMean overflow.
|
||||
const size_t N = hn::Lanes(d);
|
||||
auto in = hwy::AllocateAligned<T>(num);
|
||||
auto dec = hwy::AllocateAligned<T>(num);
|
||||
auto vec = hwy::AllocateAligned<T>(num);
|
||||
auto vec_eo = hwy::AllocateAligned<T>(num);
|
||||
auto sfp = hwy::AllocateAligned<SfpStream>(num);
|
||||
HWY_ASSERT(in && dec && vec && vec_eo && sfp);
|
||||
|
||||
// Generate inputs and verify their distribution.
|
||||
hwy::RandomState rng;
|
||||
hwy::Stats in_stats;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
const float r = static_cast<float>(RandomGaussian(rng));
|
||||
in_stats.Notify(r);
|
||||
in[i] = hwy::ConvertScalarTo<T>(r);
|
||||
}
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
const float r = static_cast<float>(RandomGaussian(rng));
|
||||
in_stats.Notify(r);
|
||||
vec[i] = hwy::ConvertScalarTo<T>(r);
|
||||
}
|
||||
VerifyGaussian(in_stats);
|
||||
|
||||
// Convert vec to even/odd for DotEO
|
||||
for (size_t i = 0; i < num; i += 2 * N) {
|
||||
hn::Vec<D> ve, vo;
|
||||
hn::LoadInterleaved2(d, vec.get() + i, ve, vo);
|
||||
hn::Store(ve, d, vec_eo.get() + i + 0);
|
||||
hn::Store(vo, d, vec_eo.get() + i + N);
|
||||
}
|
||||
|
||||
SfpCodec::Enc(d, in.get(), num, sfp.get());
|
||||
|
||||
// Compute dot product without decompression.
|
||||
float actual = 0.0f;
|
||||
float actual_eo = 0.0f;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
double elapsed_eo = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 200; ++rep) {
|
||||
{
|
||||
const double t0 = hwy::platform::Now();
|
||||
actual = SimpleDot(df, sfp.get(), 0, vec.get(), num);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed = HWY_MIN(elapsed, t1 - t0);
|
||||
}
|
||||
{
|
||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum3 = hn::Zero(df);
|
||||
const double t0 = hwy::platform::Now();
|
||||
SfpCodec::DotEO(df, sfp.get(), num, vec_eo.get(), sum0, sum1, sum2,
|
||||
sum3);
|
||||
const double t1 = hwy::platform::Now();
|
||||
elapsed_eo = HWY_MIN(elapsed_eo, t1 - t0);
|
||||
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
|
||||
actual_eo = hn::ReduceSum(df, sum0);
|
||||
}
|
||||
}
|
||||
|
||||
SfpCodec::Dec(d, sfp.get(), num, dec.get());
|
||||
fprintf(stderr, "Vec %zu Dot %zu-bit %.2f ; %.2f MB/s\n",
|
||||
Lanes(d) * sizeof(T), sizeof(T) * 8,
|
||||
num * sizeof(T) * 1E-6 / elapsed,
|
||||
num * sizeof(T) * 1E-6 / elapsed_eo);
|
||||
|
||||
// Exact and decompressed dot products for comparison.
|
||||
float exact = 0.0f; // using original input
|
||||
float expected = 0.0f; // using decoded SFP
|
||||
DistortionStats dec_stats;
|
||||
hwy::Stats ratios;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
const float in1 = hwy::ConvertScalarTo<float>(in[i]);
|
||||
const float dec1 = hwy::ConvertScalarTo<float>(dec[i]);
|
||||
const float vec1 = hwy::ConvertScalarTo<float>(vec[i]);
|
||||
dec_stats.Notify(in1, dec1);
|
||||
|
||||
exact += in1 * vec1;
|
||||
expected += dec1 * vec1;
|
||||
if (expected != 0.0f) {
|
||||
ratios.Notify(exact / expected);
|
||||
}
|
||||
}
|
||||
const bool isBF = sizeof(T) == 2;
|
||||
const double dec_snr = dec_stats.GeomeanValueDivL1();
|
||||
const double dec_wl1 = dec_stats.WeightedAverageL1();
|
||||
const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean());
|
||||
// exact and actual fluctuate due to the combination of SFP imprecision,
|
||||
// and whether vec[i] is negative or positive, so this is quite loose.
|
||||
const float final_ratio = HWY_MIN(exact / actual, actual / exact);
|
||||
if (HWY_ONCE) {
|
||||
fprintf(stderr, "ratios %s\n", ratios.ToString().c_str());
|
||||
fprintf(stderr,
|
||||
"exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f "
|
||||
"dot_snr %.2f dec_wl1 %.5f\n",
|
||||
exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1);
|
||||
}
|
||||
// Final values are not too far apart.
|
||||
HWY_ASSERT(gcpp::IsInside(0.87f, 1.0f, final_ratio));
|
||||
// Decompressed and uncompressed dot should match exactly.
|
||||
HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
|
||||
// Even/odd dot should also match
|
||||
HWY_ASSERT(gcpp::IsNear(actual, actual_eo, 1E-4f));
|
||||
// Geomean of ratios for each i should be very close to one.
|
||||
HWY_ASSERT(dot_snr >= (isBF ? 70.0 : 1000.0));
|
||||
|
||||
// dec[] is close to in[]. We also check that in TestEncDec, but for much
|
||||
// smaller input magnitudes.
|
||||
HWY_ASSERT(gcpp::IsNear(isBF ? 51.0 : 64.0, dec_snr, 1.0));
|
||||
HWY_ASSERT(gcpp::IsNear(isBF ? 0.013 : 0.012, dec_wl1, 0.001));
|
||||
HWY_ASSERT(gcpp::IsNear(isBF ? 6.2 : 6.3, dec_stats.SumL1(), 0.1));
|
||||
HWY_ASSERT_EQ(0, dec_stats.NumSignFlip());
|
||||
HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero());
|
||||
HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded());
|
||||
// Absolute decode errors are in [0, 5E-2], and somewhat right-tailed.
|
||||
HWY_ASSERT(gcpp::IsInside(0.0f, 2E-6f, dec_stats.L1().Min()));
|
||||
HWY_ASSERT(gcpp::IsInside(3E-2f, 5E-2f, dec_stats.L1().Max()));
|
||||
HWY_ASSERT(gcpp::IsInside(4E-3, 7E-3, dec_stats.L1().Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(1.8, 1.9, dec_stats.L1().Skewness()));
|
||||
HWY_ASSERT(gcpp::IsInside(6.0, 7.0, dec_stats.L1().Kurtosis()));
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllDotF32() {
|
||||
const hn::ForGEVectors<128, TestDot> test;
|
||||
test(float());
|
||||
}
|
||||
void TestAllDotBF16() {
|
||||
const hn::ForGEVectors<128, TestDot> test;
|
||||
test(hwy::bfloat16_t());
|
||||
}
|
||||
void TestAllEncDec() { hn::ForGEVectors<32, TestEncDec>()(BF16()); }
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
|
|
@ -562,7 +435,6 @@ void TestAllDotBF16() {
|
|||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(SfpTest);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, PrintTables);
|
||||
|
|
@ -570,13 +442,8 @@ HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique);
|
|||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllFastDecode);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllOrder);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotF32);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotBF16);
|
||||
#ifdef HWY_AFTER_TEST
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec);
|
||||
HWY_AFTER_TEST();
|
||||
#endif
|
||||
} // namespace gcpp
|
||||
|
||||
#endif
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -20,8 +20,12 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
#include <cstdio>
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // HWY_INLINE
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -35,25 +39,172 @@ using BF16 = hwy::bfloat16_t;
|
|||
// - 24-bit dynamic range, with max exponent 2^0.
|
||||
// - 3 bit mantissa for values >= 2^-7, otherwise 2.
|
||||
//
|
||||
// A pointer to this is the *start* of an SFP stream. Values are stored
|
||||
// in-order to enable vector-length agnostic seeking, because streams may be
|
||||
// written to disk for loading on other CPUs.
|
||||
// A pointer to this is the *start* of an SFP stream. Aligning the allocation
|
||||
// (see aligned_allocator.h) may speed up decoding but is not required.
|
||||
//
|
||||
// Layout: Values are stored in-order to enable vector-length agnostic seeking,
|
||||
// because streams may be written to disk for loading on other CPUs.
|
||||
//
|
||||
// This is faster to decode than a straightforward implementation of eXmY, in
|
||||
// part because SFP does not require subnormals. Unlike OCP MX, it also does not
|
||||
// require side information (shared exponents).
|
||||
//
|
||||
// Although the representation could probably be shrunk to 6-7 bits, more
|
||||
// savings can be had by non-uniform clustering - see nuq.h.
|
||||
// savings can be had by non-uniform clustering - see NuqStream.
|
||||
#pragma pack(push, 1)
|
||||
struct SfpStream {
|
||||
// Largest possible input magnitude: 1.111 * 2^0. This could be increased by
|
||||
// shifting the value range (exponent bias).
|
||||
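// (Illustrative note, derived from the format description above: at the
// maximum exponent 2^0 the 3-bit mantissa yields a quantization step of
// 2^-3 = 0.125, so the two largest representable magnitudes are 1.75 and
// 1.875.)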
static constexpr float kMax = 1.875f;
|
||||
|
||||
uint8_t byte;
|
||||
};
|
||||
#pragma pack(pop)
|
||||
|
||||
// Largest possible input magnitude: 1.111 * 2^0. This could be increased by
|
||||
// shifting the value range (exponent bias).
|
||||
constexpr float kMaxSFP = 1.875f;
|
||||
// Returns 1.0f if all magnitudes are <= SfpStream::kMax, otherwise scales them
|
||||
// such that the largest magnitude is SfpStream::kMax, and returns the
|
||||
// multiplier with which to restore the original values. This is only necessary
|
||||
// before compressing to SfpStream.
|
||||
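// (Illustrative example: if the largest magnitude is 3.0f, the values are
// multiplied by kMax / 3.0f before encoding and the returned scale is
// 3.0f / kMax = 1.6f, which callers apply to the decoded values to restore
// the original range.)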
// TODO: vectorize
|
||||
static inline float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
|
||||
float maxabs = 0.0;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
|
||||
}
|
||||
if (maxabs <= SfpStream::kMax) {
|
||||
return 1.0f;
|
||||
}
|
||||
const float scale = maxabs / SfpStream::kMax;
|
||||
const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
// Clamp because kMax may still be exceeded.
|
||||
const float magn =
|
||||
HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
|
||||
raw[i] = hwy::ScalarCopySign(magn, raw[i]);
|
||||
}
|
||||
return scale;
|
||||
}
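// Worked example (hypothetical values): for raw = {0.5f, -3.0f, 2.0f, 1.0f},
// ScaleWeights returns 3.0f / 1.875f = 1.6f and rescales in place so that
// raw[1] becomes -1.875f; multiplying decoded values by the returned scale
// (e.g. the factor stored via CompressedArray::scale()) restores the original
// range.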
|
||||
|
||||
// Non-uniform quantization: a compressed representation of f32 inputs that
// supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
// two vectors (for `Decompress2`), and decoding to bf16/f32.
//
// A pointer to this is the *start* of a NUQ stream. Aligning the allocation
// (see aligned_allocator.h) may speed up decoding but is not required.
//
// Layout: first one table of kClusters entries per group, in ascending order
// of group index, then two packed indices per byte. Indices are stored
// in-order to enable vector-length agnostic decode, because streams may be
// persisted to disk and used by other CPUs.
//
// To enable parallel encoding and decoding, Enc/Dec have `offset` parameters
// which refer to the stream, NOT the raw from/to pointers, which point directly
// to the source/destination. Offsets are in units of values, NOT compressed
// bytes within the stream.
#pragma pack(push, 1)
|
||||
struct NuqStream {
|
||||
// 4-bit indices are a sweet spot in terms of quality per size.
|
||||
static constexpr size_t kClusters = 16;
|
||||
|
||||
// Number of weights that share a table. Larger = slower encode, higher error,
|
||||
// smaller size (table amortized over more weights).
|
||||
static constexpr size_t kGroupSize = 256;
|
||||
|
||||
// Storage for dynamic programming. There are two matrices; we use separate
|
||||
// allocations to avoid type punning.
|
||||
template <class T>
|
||||
class AlignedMatrix {
|
||||
public:
|
||||
AlignedMatrix() : mem_(hwy::AllocateAligned<T>(kClusters * kGroupSize)) {}
|
||||
|
||||
HWY_INLINE const T& operator()(size_t row, size_t col) const {
|
||||
return mem_[row * kGroupSize + col];
|
||||
}
|
||||
|
||||
HWY_INLINE T& operator()(size_t row, size_t col) {
|
||||
return mem_[row * kGroupSize + col];
|
||||
}
|
||||
|
||||
private:
|
||||
hwy::AlignedFreeUniquePtr<T[]> mem_;
|
||||
};
|
||||
|
||||
// Reuse memory across calls to Enc to avoid per-call allocations.
|
||||
struct ClusterBuf {
|
||||
// Move-only (stored inside vector in CompressWorkingSet).
|
||||
ClusterBuf() = default;
|
||||
ClusterBuf(const ClusterBuf&) = delete;
|
||||
ClusterBuf& operator=(const ClusterBuf&) = delete;
|
||||
ClusterBuf(ClusterBuf&&) = default;
|
||||
ClusterBuf& operator=(ClusterBuf&&) = default;
|
||||
|
||||
void Resize(size_t new_num_groups) {
|
||||
if (new_num_groups < num_groups) return;
|
||||
|
||||
num_groups = new_num_groups;
|
||||
centers = hwy::AllocateAligned<float>(num_groups * kClusters);
|
||||
idx = hwy::AllocateAligned<uint16_t>(num_groups * kGroupSize);
|
||||
}
|
||||
|
||||
// Independent of num_groups.
|
||||
AlignedMatrix<float> costs;
|
||||
AlignedMatrix<int32_t> argmin;
|
||||
|
||||
size_t num_groups = 0;
|
||||
hwy::AlignedFreeUniquePtr<float[]> centers;
|
||||
hwy::AlignedFreeUniquePtr<uint16_t[]> idx;
|
||||
};
|
||||
|
||||
// Returns offset of packed indices from the start of the stream. This matches
|
||||
// the (padded) total table size because table entries are bytes.
|
||||
static constexpr size_t PackedStart(size_t capacity) {
|
||||
// Round up to avoid cache-line splits when loading indices. No effect on
|
||||
// size as long as capacity / kGroupSize is a multiple of 4.
|
||||
return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64);
|
||||
}
|
||||
|
||||
// Returns number of NuqStream to allocate for the stream, which matches its
|
||||
// size in bytes.
|
||||
static constexpr size_t PackedEnd(size_t capacity) {
|
||||
return PackedStart(capacity) + hwy::DivCeil(capacity, 2); // 2x 4-bit/byte
|
||||
}
|
||||
|
||||
uint8_t byte;
|
||||
};
|
||||
#pragma pack(pop)
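// Worked example (hypothetical capacity): 1024 values form DivCeil(1024, 256)
// = 4 groups, whose tables occupy 4 * NuqStream::kClusters = 64 bytes (already
// a multiple of 64), plus 1024 / 2 = 512 bytes of packed 4-bit indices:
static_assert(NuqStream::PackedEnd(1024) == 576, "1024 NUQ values = 576 bytes");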
|
||||
|
||||
template <typename PackedT>
|
||||
const char* TypeName() {
|
||||
using Packed = hwy::RemoveCvRef<PackedT>;
|
||||
if constexpr (hwy::IsSame<Packed, float>()) {
|
||||
return "f32";
|
||||
} else if constexpr (hwy::IsSame<Packed, BF16>()) {
|
||||
return "b16";
|
||||
} else if constexpr (hwy::IsSame<Packed, SfpStream>()) {
|
||||
return "sfp";
|
||||
} else if constexpr (hwy::IsSame<Packed, NuqStream>()) {
|
||||
return "nuq";
|
||||
} else {
|
||||
HWY_DASSERT(false);
|
||||
return "unknown";
|
||||
}
|
||||
}
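// For example, TypeName<const SfpStream>() and TypeName<SfpStream>() both
// return "sfp" because cv-qualifiers and references are removed first.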
|
||||
|
||||
template <typename Packed>
|
||||
constexpr bool IsCompressed() {
|
||||
return hwy::IsSameEither<hwy::RemoveCvRef<Packed>, SfpStream, NuqStream>();
|
||||
}
|
||||
|
||||
// Returns the number of `Packed` elements required to store `capacity` values,
|
||||
// which must not be zero.
|
||||
template <typename Packed>
|
||||
constexpr size_t CompressedArrayElements(size_t capacity) {
|
||||
if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
|
||||
return NuqStream::PackedEnd(capacity);
|
||||
} else {
|
||||
return capacity;
|
||||
}
|
||||
}
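// Illustrative: SFP stores one byte per value, whereas NUQ prepends per-group
// tables, so e.g. for a capacity of 1024 values:
static_assert(CompressedArrayElements<SfpStream>(1024) == 1024, "");
static_assert(CompressedArrayElements<NuqStream>(1024) == 576, "");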
|
||||
|
||||
// Non-owning view of packed elements. Shortens argument lists.
|
||||
//
|
||||
|
|
@ -63,13 +214,19 @@ constexpr float kMaxSFP = 1.875f;
|
|||
// reusing `hwy::Span`.
|
||||
template <typename Packed>
|
||||
struct PackedSpan {
|
||||
void BoundsCheck(size_t packed_ofs, size_t num) const {
|
||||
HWY_DASSERT(packed_ofs + num <= size);
|
||||
(void)size;
|
||||
// Ensures callers can read or write `num_accessible` elements starting at
|
||||
// `packed_ofs`.
|
||||
void BoundsCheck(size_t packed_ofs, size_t num_accessible) const {
|
||||
// For NUQ, there can be fewer Packed than the number of elements, hence
|
||||
// check the compressed count and ensure we have that many.
|
||||
const size_t required =
|
||||
CompressedArrayElements<Packed>(packed_ofs + num_accessible);
|
||||
HWY_DASSERT(num >= required);
|
||||
(void)required;
|
||||
}
|
||||
|
||||
Packed* HWY_RESTRICT ptr;
|
||||
size_t size; // for BoundsCheck and nuq-inl.h HWY_ASSERT.
|
||||
size_t num; // for BoundsCheck and nuq-inl.h HWY_ASSERT.
|
||||
};
|
||||
|
||||
// Avoids spelling out the template parameter in every call.
|
||||
|
|
@ -87,7 +244,7 @@ HWY_INLINE PackedSpan<const Packed> MakeConstSpan(Packed* ptr, size_t size) {
|
|||
// `RMSNormInplace` and compression tests.
|
||||
template <typename Packed>
|
||||
HWY_INLINE PackedSpan<const Packed> MakeConst(PackedSpan<Packed> packed) {
|
||||
return {packed.ptr, packed.size};
|
||||
return {packed.ptr, packed.num};
|
||||
}
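// Usage sketch (hypothetical buffer): wrap a raw pointer, then verify that a
// window of 64 values starting at offset 128 is fully covered before decoding.
//   HWY_ALIGN float buf[256] = {};
//   const PackedSpan<const float> packed = MakeConstSpan(buf, 256);
//   packed.BoundsCheck(/*packed_ofs=*/128, /*num_accessible=*/64);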
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Include guard for headers.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/compress.h"
|
||||
#include "compression/distortion.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESS_TEST_UTIL_TOGGLE) == \
|
||||
defined(HWY_TARGET_TOGGLE) // NOLINT
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESS_TEST_UTIL_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_COMPRESS_TEST_UTIL_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TEST_UTIL_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "hwy/tests/test_util-inl.h" // IWYU pragma: export
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// `Packed` is the type passed to `TestT`.
|
||||
template <typename Packed, template <class> class TestT>
|
||||
void ForeachRawType() {
|
||||
const hn::ForGEVectors<128, TestT<Packed>> test;
|
||||
// The argument selects the type to decode to: BF16 or float.
|
||||
test(BF16());
|
||||
test(float());
|
||||
}
|
||||
|
||||
template <template <class> class TestT>
|
||||
void ForeachPackedAndRawType() {
|
||||
ForeachRawType<BF16, TestT>();
|
||||
ForeachRawType<float, TestT>();
|
||||
ForeachRawType<SfpStream, TestT>();
|
||||
ForeachRawType<NuqStream, TestT>();
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // NOLINT
|
||||
|
|
@ -244,7 +244,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
|||
pools.Inner(0).NumWorkers(),
|
||||
hwy::TargetName(hwy::DispatchedTarget()), hwy::VectorBytes() * 8,
|
||||
CompiledConfig(), StringFromType(loader.Info().weight),
|
||||
TypeName(EmbedderInputT()));
|
||||
TypeName<EmbedderInputT>());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -560,18 +560,23 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
|||
const CompressedWeights<TConfig>& weights,
|
||||
RowVectorBatch<float>& x) {
|
||||
constexpr size_t kModelDim = TConfig::kModelDim;
|
||||
constexpr size_t kVocabSize = TConfig::kVocabSize;
|
||||
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
|
||||
EmbeddingScaling<TConfig>();
|
||||
|
||||
HWY_DASSERT(token >= 0);
|
||||
HWY_DASSERT(token < TConfig::kVocabSize);
|
||||
HWY_DASSERT(token < kVocabSize);
|
||||
|
||||
Decompress(weights.embedder_input_embedding, token * kModelDim,
|
||||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(
|
||||
df,
|
||||
MakeSpan(weights.embedder_input_embedding.data(), kVocabSize * kModelDim),
|
||||
token * kModelDim, x.Batch(batch_idx), kModelDim);
|
||||
MulByConst(kEmbScaling * weights.embedder_input_embedding.scale(),
|
||||
x.Batch(batch_idx), kModelDim);
|
||||
MulByConst(kEmbScaling, x.Batch(batch_idx), kModelDim);
|
||||
if constexpr (TConfig::kAbsolutePE) {
|
||||
AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), kModelDim, pos);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
template <class TConfig, typename T>
|
||||
|
|
|
|||
485
ops/dot-inl.h
485
ops/dot-inl.h
|
|
@ -15,12 +15,9 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm> // std::sort
|
||||
#include <array>
|
||||
#include <cstdlib> // std::abs
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/distortion.h" // TwoSum
|
||||
#include "hwy/base.h"
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
|
|
@ -34,339 +31,247 @@
|
|||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "ops/fp_arith-inl.h"
|
||||
#include "hwy/contrib/math/math-inl.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/profiler.h" // also uses SIMD
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// Returns dot product of `x` and `w`, both length `num`. Uses Decompress2 to
|
||||
// convert WeightT and VecT to float, then FMA.
|
||||
// TODO: improve precision?
|
||||
// TODO: use bf16 products?
|
||||
template <class DF, typename WeightT, typename VecT>
|
||||
HWY_INLINE float SimpleDot(DF df, const WeightT* HWY_RESTRICT w, size_t w_ofs,
|
||||
const VecT* HWY_RESTRICT x, size_t num) {
|
||||
PROFILER_FUNC;
|
||||
const size_t N = hn::Lanes(df);
|
||||
HWY_DASSERT(hn::IsAligned(df, x));
|
||||
using VF = hn::Vec<DF>;
|
||||
using TraitsW = CompressTraits<WeightT>;
|
||||
using TraitsV = CompressTraits<VecT>;
|
||||
|
||||
VF sum0 = hn::Zero(df);
|
||||
VF sum1 = hn::Zero(df);
|
||||
VF sum2 = hn::Zero(df);
|
||||
VF sum3 = hn::Zero(df);
|
||||
|
||||
VF w0, w1, w2, w3, v0, v1, v2, v3; // decompressed inputs
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= 4 * N) {
|
||||
for (; i <= num - 4 * N; i += 4 * N) {
|
||||
TraitsW::Decompress2(df, w, w_ofs + i, w0, w1);
|
||||
TraitsW::Decompress2(df, w, w_ofs + i + 2 * N, w2, w3);
|
||||
TraitsV::Decompress2(df, x, i, v0, v1);
|
||||
TraitsV::Decompress2(df, x, i + 2 * N, v2, v3);
|
||||
|
||||
sum0 = hn::MulAdd(w0, v0, sum0);
|
||||
sum1 = hn::MulAdd(w1, v1, sum1);
|
||||
sum2 = hn::MulAdd(w2, v2, sum2);
|
||||
sum3 = hn::MulAdd(w3, v3, sum3);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
HWY_ALIGN float padded_w[4 * hn::MaxLanes(df)] = {};
|
||||
HWY_ALIGN float padded_x[4 * hn::MaxLanes(df)] = {};
|
||||
// The actual capacity of w[] is unknown, so pass a lower bound.
|
||||
const size_t w_capacity = w_ofs + num;
|
||||
TraitsW::Decompress(df, w_capacity, w, w_ofs + i, padded_w, remaining);
|
||||
TraitsV::Decompress(df, num, x, i, padded_x, remaining);
|
||||
const size_t padding = 4 * N - remaining;
|
||||
hwy::ZeroBytes(padded_w + remaining, padding * sizeof(padded_w[0]));
|
||||
hwy::ZeroBytes(padded_x + remaining, padding * sizeof(padded_x[0]));
|
||||
for (; i < num; i += N) {
|
||||
const VF w0 = hn::Load(df, padded_w + i);
|
||||
const VF v0 = hn::Load(df, padded_x + i);
|
||||
sum0 = hn::MulAdd(w0, v0, sum0);
|
||||
}
|
||||
}
|
||||
|
||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||
sum0 = hn::Add(sum0, sum1);
|
||||
sum2 = hn::Add(sum2, sum3);
|
||||
sum0 = hn::Add(sum0, sum2);
|
||||
return hn::ReduceSum(df, sum0);
|
||||
}
|
||||
|
||||
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
||||
template <bool kVecEO, class DF, size_t kCapacity, typename VecT>
|
||||
HWY_INLINE float Dot(DF df, const std::array<float, kCapacity>& w, size_t ofs,
|
||||
const VecT* vec_aligned, size_t num) {
|
||||
PROFILER_ZONE("Dot array");
|
||||
HWY_DASSERT(ofs + num <= kCapacity);
|
||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||
return SimpleDot(df, w.data(), ofs, vec_aligned, num);
|
||||
}
|
||||
|
||||
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
||||
template <bool kVecEO, class DF, typename MatT, size_t kCapacity, typename VecT>
|
||||
HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
|
||||
size_t compressed_ofs, const VecT* vec_aligned,
|
||||
size_t num) {
|
||||
PROFILER_ZONE("Dot CompressedArray");
|
||||
HWY_DASSERT(compressed_ofs + num <= compressed.size());
|
||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||
using Traits = CompressTraits<MatT>;
|
||||
float dot_result;
|
||||
if constexpr (kVecEO) {
|
||||
dot_result =
|
||||
Traits::DotEO(df, compressed.data(), compressed_ofs, vec_aligned, num);
|
||||
} else {
|
||||
dot_result =
|
||||
SimpleDot(df, compressed.data(), compressed_ofs, vec_aligned, num);
|
||||
}
|
||||
return compressed.scale() * dot_result;
|
||||
}
|
||||
|
||||
// Returns result accurate to 1.5 ulp, assuming `num` < 2^(52-23), no overflow,
|
||||
// and round to nearest. See "Accurate and efficient floating point summation".
|
||||
HWY_INLINE float ExactDot(const float* HWY_RESTRICT a,
|
||||
const float* HWY_RESTRICT b, size_t num,
|
||||
double* HWY_RESTRICT buf) {
|
||||
PROFILER_FUNC;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
buf[i] = static_cast<double>(a[i]) * static_cast<double>(b[i]);
|
||||
}
|
||||
// Sort by decreasing magnitude (not supported by VQSort).
|
||||
std::sort(buf, buf + num,
|
||||
[](double a, double b) { return std::abs(a) > std::abs(b); });
|
||||
double sum = 0.0;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
sum += buf[i];
|
||||
}
|
||||
return static_cast<float>(sum);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
// Cascaded summation (twice working precision)

// Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
// despite floating-point rounding. `sum` is already the best estimate for the
// addition, so do not actually add `err` to it. `UpdateCascadedSums` instead
// accumulates multiple `err`, which are then later added to `sum`.
//
// Knuth98/Moller65. Unlike Fast2Sum [Dekker71], this does not require any
// relative ordering of the exponents of a and b.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
static HWY_INLINE VF TwoSums(DF /*df*/, VF a, VF b, VF& err) {
  const VF sum = hn::Add(a, b);
  const VF a2 = hn::Sub(sum, b);
  const VF b2 = hn::Sub(sum, a2);
  const VF err_a = hn::Sub(a, a2);
  const VF err_b = hn::Sub(b, b2);
  err = hn::Add(err_a, err_b);
  return sum;
}
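// Scalar illustration (f32, round to nearest): for a = 1e8f and b = 1.0f,
// sum = 1e8f because b is below half an ulp of a, and err = 1.0f, so
// sum + err still represents the exact mathematical result.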
|
||||
|
||||
// Adds vectors with about twice the precision of VF using 7 FLOPS.
|
||||
// Rump/Ogita/Oishi08, Algorithm 6.11 in Handbook of Floating-Point Arithmetic.
|
||||
// `sum` and `sum_err` must be initially zero.
|
||||
//
|
||||
// Each lane is an independent cascaded sum. To obtain a single result, use
|
||||
// `ReduceCascadedSum`. Vectors generally cannot be wrapped in a class, hence we
|
||||
// use free functions.
|
||||
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
|
||||
void UpdateCascadedSums(DF df, VF v, VF& sum, VF& sum_err) {
|
||||
VF err;
|
||||
sum = TwoSums(df, sum, v, err);
|
||||
sum_err += err;
|
||||
}
|
||||
|
||||
// Combines two cascaded sum vectors, typically from unrolling/parallelization.
|
||||
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
|
||||
void AssimilateCascadedSums(DF df, const VF& other_sum, const VF& other_sum_err,
|
||||
VF& sum, VF& sum_err) {
|
||||
UpdateCascadedSums(df, other_sum, sum, sum_err);
|
||||
sum_err += other_sum_err;
|
||||
}
|
||||
|
||||
// Reduces cascaded sums, to a single value. Slow, call outside of loops.
|
||||
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
|
||||
hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
|
||||
const size_t N = hn::Lanes(df);
|
||||
using TF = hn::TFromD<DF>;
|
||||
TF total = TF{0.0};
|
||||
TF total_err = TF{0.0};
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
TF err;
|
||||
total = TwoSum(total, hn::ExtractLane(sum, i), err);
|
||||
total_err += hn::ExtractLane(sum_err, i);
|
||||
total_err += err;
|
||||
}
|
||||
return total + total_err;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Returns 2 * sum(|f|) / |sum(f)|. This is large when there are many
|
||||
// similar-magnitude and opposite-sign elements in `f`. See
|
||||
// Returns 2 * sum(|w.*v|) / |sum(w.*v)|. This is large when there are many
|
||||
// similar-magnitude and opposite-sign elements. See
|
||||
// https://en.wikipedia.org/wiki/Condition_number.
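// For example, products {1e6f, -1e6f, 1.0f} sum to 1 but have absolute sum
// 2e6 + 1, giving a condition number of about 4e6: heavy cancellation means
// even small per-element rounding errors can dominate the result.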
|
||||
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
|
||||
static inline double ConditionNumber(DF df, const float* HWY_RESTRICT f,
|
||||
size_t num) {
|
||||
template <typename WeightT, typename VecT>
|
||||
HWY_MAYBE_UNUSED double ConditionNumber(const WeightT* HWY_RESTRICT w,
|
||||
const VecT* HWY_RESTRICT v,
|
||||
size_t num) {
|
||||
PROFILER_FUNC;
|
||||
const hn::ScalableTag<float> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N = hn::Lanes(df);
|
||||
|
||||
VF sum = hn::Zero(df);
|
||||
VF sum_err = hn::Zero(df);
|
||||
VF sum_abs = hn::Zero(df);
|
||||
VF sum_err_abs = hn::Zero(df);
|
||||
VF sum_abs_err = hn::Zero(df);
|
||||
|
||||
const auto packed_w = MakeSpan(w, num);
|
||||
const auto packed_v = MakeSpan(v, num);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= N) {
|
||||
for (; i <= num - N; i += N) {
|
||||
const VF v = hn::Load(df, f + i);
|
||||
UpdateCascadedSums(v, sum, sum_err);
|
||||
UpdateCascadedSums(hn::Abs(v), sum_abs, sum_err_abs);
|
||||
if (num >= 2 * N) {
|
||||
for (; i <= num - 2 * N; i += 2 * N) {
|
||||
VF w0, w1, v0, v1;
|
||||
Decompress2(df, packed_w, i, w0, w1);
|
||||
Decompress2(df, packed_v, i, v0, v1);
|
||||
const VF mul0 = hn::Mul(w0, v0);
|
||||
const VF mul1 = hn::Mul(w1, v1);
|
||||
UpdateCascadedSums(df, mul0, sum, sum_err);
|
||||
UpdateCascadedSums(df, mul1, sum, sum_err);
|
||||
UpdateCascadedSums(df, hn::Abs(mul0), sum_abs, sum_abs_err);
|
||||
UpdateCascadedSums(df, hn::Abs(mul1), sum_abs, sum_abs_err);
|
||||
}
|
||||
}
|
||||
const size_t remaining = num - i;
|
||||
if (remaining != 0) {
|
||||
const VF v = hn::LoadN(df, f + i, remaining);
|
||||
UpdateCascadedSums(v, sum, sum_err);
|
||||
UpdateCascadedSums(hn::Abs(v), sum_abs, sum_err_abs);
|
||||
}
|
||||
|
||||
const float div = std::abs(ReduceCascadedSums(df, sum, sum_err));
|
||||
if (div == 0.0f) return hwy::HighestValue<float>();
|
||||
const double cond = 2.0 * ReduceCascadedSums(df, sum_abs, sum_err_abs) /
|
||||
static_cast<double>(div);
|
||||
HWY_ASSERT(cond >= 0.0);
|
||||
return cond;
|
||||
}
|
||||
size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * N);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
HWY_ALIGN float padded_w[2 * hn::MaxLanes(df)];
|
||||
HWY_ALIGN float padded_v[2 * hn::MaxLanes(df)];
|
||||
DecompressAndZeroPad(df, packed_w, i, padded_w, remaining);
|
||||
DecompressAndZeroPad(df, packed_v, i, padded_v, remaining);
|
||||
|
||||
// Same, but for dot product of two arrays.
|
||||
// TODO: move into dot_test.
|
||||
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
|
||||
static inline double ConditionNumber(DF df, const float* HWY_RESTRICT a,
|
||||
const float* HWY_RESTRICT b, size_t num) {
|
||||
PROFILER_FUNC;
|
||||
const size_t N = hn::Lanes(df);
|
||||
|
||||
VF sum = hn::Zero(df);
|
||||
VF sum_err = hn::Zero(df);
|
||||
VF sum_abs = hn::Zero(df);
|
||||
VF sum_err_abs = hn::Zero(df);
|
||||
|
||||
size_t i = 0;
|
||||
if (num >= N) {
|
||||
for (; i <= num - N; i += N) {
|
||||
const VF va = hn::Load(df, a + i);
|
||||
const VF vb = hn::Load(df, b + i);
|
||||
const VF mul = hn::Mul(va, vb);
|
||||
// 1..2 whole vectors, possibly zero-padded.
|
||||
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
|
||||
const VF w0 = hn::Load(df, padded_w + padded_pos);
|
||||
const VF v0 = hn::Load(df, padded_v + padded_pos);
|
||||
const VF mul = hn::Mul(w0, v0);
|
||||
UpdateCascadedSums(df, mul, sum, sum_err);
|
||||
UpdateCascadedSums(df, hn::Abs(mul), sum_abs, sum_err_abs);
|
||||
UpdateCascadedSums(df, hn::Abs(mul), sum_abs, sum_abs_err);
|
||||
}
|
||||
}
|
||||
const size_t remaining = num - i;
|
||||
if (remaining != 0) {
|
||||
const VF va = hn::LoadN(df, a + i, remaining);
|
||||
const VF vb = hn::LoadN(df, b + i, remaining);
|
||||
const VF mul = hn::Mul(va, vb);
|
||||
UpdateCascadedSums(df, mul, sum, sum_err);
|
||||
UpdateCascadedSums(df, hn::Abs(mul), sum_abs, sum_err_abs);
|
||||
}
|
||||
|
||||
const float div = std::abs(ReduceCascadedSums(df, sum, sum_err));
|
||||
const float div = hwy::ScalarAbs(ReduceCascadedSums(df, sum, sum_err));
|
||||
if (div == 0.0f) return hn::GetLane(hn::Inf(df));
|
||||
const double cond = 2.0 * ReduceCascadedSums(df, sum_abs, sum_err_abs) /
|
||||
const double cond = 2.0 * ReduceCascadedSums(df, sum_abs, sum_abs_err) /
|
||||
static_cast<double>(div);
|
||||
HWY_ASSERT(cond >= 0.0);
|
||||
return cond;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Compensated dot product
|
||||
|
||||
#if !HWY_NATIVE_FMA
|
||||
|
||||
// Returns non-overlapping `x` and `y` such that `x + y` = `f` and |x| >= |y|.
|
||||
// Notation from Algorithm 3.1 in Handbook of Floating-Point Arithmetic. 4 ops.
|
||||
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
|
||||
static HWY_INLINE void VeltkampSplit(DF df, VF a, VF& x, VF& y) {
|
||||
using TF = hn::TFromD<DF>;
|
||||
constexpr int t = hwy::MantissaBits<TF>() + 1; // = -log2(epsilon)
|
||||
constexpr int s = hwy::DivCeil(t, 2);
|
||||
const VF factor = hn::Set(df, hwy::ConvertScalarTo<TF>((1ULL << s) + 1));
|
||||
const VF c = hn::Mul(factor, a);
|
||||
x = hn::Sub(c, hn::Sub(c, a));
|
||||
y = hn::Sub(a, x);
|
||||
}
|
||||
|
||||
#endif // !HWY_NATIVE_FMA
|
||||
|
||||
// Returns `prod` and `err` such that `prod + err` is exactly equal to `a * b`,
// despite floating-point rounding, assuming that `err` is not subnormal, i.e.,
// the sum of exponents >= min exponent + mantissa bits. 2..17 ops.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
static HWY_INLINE VF TwoProducts(DF df, VF a, VF b, VF& err) {
  const VF prod = hn::Mul(a, b);
#if HWY_NATIVE_FMA
  err = hn::MulSub(a, b, prod);
#else
  VF a1, a2, b1, b2;
  VeltkampSplit(df, a, a1, a2);
  VeltkampSplit(df, b, b1, b2);
  const VF m = hn::Sub(prod, hn::Mul(a1, b1));
  const VF n = hn::Sub(m, hn::Mul(a2, b1));
  const VF o = hn::Sub(n, hn::Mul(a1, b2));
  err = hn::Sub(hn::Mul(a2, b2), o);
#endif
  return prod;
}
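// Scalar illustration (f32): for a = b = 1 + 2^-23, the exact product
// 1 + 2^-22 + 2^-46 rounds to prod = 1 + 2^-22, and err = 2^-46, so
// prod + err again equals the exact product.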
|
||||
|
||||
// Algorithm 6.15 from Handbook of Floating-Point Arithmetic.
|
||||
template <class DF, typename WeightT, typename VecT>
|
||||
HWY_INLINE float CompensatedDot(DF df, const WeightT* HWY_RESTRICT w,
|
||||
size_t w_ofs, const VecT* HWY_RESTRICT x,
|
||||
size_t num) {
|
||||
// Same, but for a single vector - just skips the product.
|
||||
template <typename VecT>
|
||||
HWY_MAYBE_UNUSED double ConditionNumber(const VecT* HWY_RESTRICT v,
|
||||
size_t num) {
|
||||
PROFILER_FUNC;
|
||||
const hn::ScalableTag<float> df;
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
const size_t N = hn::Lanes(df);
|
||||
HWY_ASSERT((num % (2 * N)) == 0);
|
||||
HWY_DASSERT(hn::IsAligned(df, x));
|
||||
using VF = hn::Vec<DF>;
|
||||
using TraitsW = CompressTraits<WeightT>;
|
||||
using TraitsV = CompressTraits<VecT>;
|
||||
|
||||
VF sum0 = hn::Zero(df);
|
||||
VF sum1 = hn::Zero(df);
|
||||
VF sum_err0 = hn::Zero(df);
|
||||
VF sum_err1 = hn::Zero(df);
|
||||
VF sum = hn::Zero(df);
|
||||
VF sum_err = hn::Zero(df);
|
||||
VF sum_abs = hn::Zero(df);
|
||||
VF sum_abs_err = hn::Zero(df);
|
||||
|
||||
VF w0, w1, v0, v1; // decompressed inputs
|
||||
VF perr0, perr1, serr0, serr1; // output arg of TwoProducts/TwoSums
|
||||
const auto packed_v = MakeSpan(v, num);
|
||||
|
||||
for (size_t i = 0; i < num; i += 2 * N) {
|
||||
TraitsW::Decompress2(df, w, w_ofs + i, w0, w1);
|
||||
TraitsV::Decompress2(df, x, i, v0, v1);
|
||||
size_t i = 0;
|
||||
if (num >= 2 * N) {
|
||||
for (; i <= num - 2 * N; i += 2 * N) {
|
||||
VF v0, v1;
|
||||
Decompress2(df, packed_v, i, v0, v1);
|
||||
UpdateCascadedSums(df, v0, sum, sum_err);
|
||||
UpdateCascadedSums(df, v1, sum, sum_err);
|
||||
UpdateCascadedSums(df, hn::Abs(v0), sum_abs, sum_abs_err);
|
||||
UpdateCascadedSums(df, hn::Abs(v1), sum_abs, sum_abs_err);
|
||||
}
|
||||
}
|
||||
|
||||
size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * N);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
HWY_ALIGN float padded_v[2 * hn::MaxLanes(df)];
|
||||
DecompressAndZeroPad(df, packed_v, i, padded_v, remaining);
|
||||
|
||||
// 1..2 whole vectors, possibly zero-padded.
|
||||
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
|
||||
const VF v0 = hn::Load(df, padded_v + padded_pos);
|
||||
UpdateCascadedSums(df, v0, sum, sum_err);
|
||||
UpdateCascadedSums(df, hn::Abs(v0), sum_abs, sum_abs_err);
|
||||
}
|
||||
}
|
||||
|
||||
const float div = hwy::ScalarAbs(ReduceCascadedSums(df, sum, sum_err));
|
||||
if (div == 0.0f) return hn::GetLane(hn::Inf(df));
|
||||
const double cond = 2.0 * ReduceCascadedSums(df, sum_abs, sum_abs_err) /
|
||||
static_cast<double>(div);
|
||||
HWY_ASSERT(cond >= 0.0);
|
||||
return cond;
|
||||
}
|
||||
|
||||
// Algorithm 6.15 from Handbook of Floating-Point Arithmetic. 10 ops is too slow
|
||||
// for compute-limited Matmul but might be OK for attention.
|
||||
// Also supports bf16 inputs, used by matvec-inl.h.
|
||||
struct DotKernelCompensated {
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
VF perr0, perr1, perr2, perr3;
|
||||
const VF prod0 = TwoProducts(df, w0, v0, perr0);
|
||||
const VF prod1 = TwoProducts(df, w1, v1, perr1);
|
||||
const VF prod2 = TwoProducts(df, w2, v2, perr2);
|
||||
const VF prod3 = TwoProducts(df, w3, v3, perr3);
|
||||
|
||||
VF serr0, serr1, serr2, serr3;
|
||||
sum0 = TwoSums(df, prod0, sum0, serr0);
|
||||
sum1 = TwoSums(df, prod1, sum1, serr1);
|
||||
sum2 = TwoSums(df, prod2, sum2, serr2);
|
||||
sum3 = TwoSums(df, prod3, sum3, serr3);
|
||||
|
||||
sum_err0 += perr0 + serr0;
|
||||
sum_err1 += perr1 + serr1;
|
||||
comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
|
||||
comp1 = hn::Add(comp1, hn::Add(perr1, serr1));
|
||||
comp2 = hn::Add(comp2, hn::Add(perr2, serr2));
|
||||
comp3 = hn::Add(comp3, hn::Add(perr3, serr3));
|
||||
}
|
||||
|
||||
AssimilateCascadedSums(df, sum1, sum_err1, sum0, sum_err0);
|
||||
return ReduceCascadedSums(df, sum0, sum_err0);
|
||||
template <class DBF, class VBF = hn::Vec<DBF>, HWY_IF_BF16_D(DBF),
|
||||
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update4(DBF /*dbf*/, const VBF w0, const VBF w1, const VBF w2,
|
||||
const VBF w3, const VBF v0, const VBF v1,
|
||||
const VBF v2, const VBF v3, VF& sum0, VF& sum1,
|
||||
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
|
||||
VF& comp3) const {
|
||||
const DF df;
|
||||
const VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
|
||||
const VF prod1 = WidenMulPairwiseAdd(df, w1, v1);
|
||||
const VF prod2 = WidenMulPairwiseAdd(df, w2, v2);
|
||||
const VF prod3 = WidenMulPairwiseAdd(df, w3, v3);
|
||||
|
||||
VF serr0, serr1, serr2, serr3;
|
||||
sum0 = TwoSums(df, prod0, sum0, serr0);
|
||||
sum1 = TwoSums(df, prod1, sum1, serr1);
|
||||
sum2 = TwoSums(df, prod2, sum2, serr2);
|
||||
sum3 = TwoSums(df, prod3, sum3, serr3);
|
||||
|
||||
comp0 = hn::Add(comp0, serr0);
|
||||
comp1 = hn::Add(comp1, serr1);
|
||||
comp2 = hn::Add(comp2, serr2);
|
||||
comp3 = hn::Add(comp3, serr3);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
||||
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
||||
VF& comp0) const {
|
||||
VF perr0;
|
||||
const VF prod0 = TwoProducts(df, w0, v0, perr0);
|
||||
|
||||
VF serr0;
|
||||
sum0 = TwoSums(df, prod0, sum0, serr0);
|
||||
|
||||
comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
|
||||
}
|
||||
|
||||
template <class DBF, class VBF = hn::Vec<DBF>, HWY_IF_BF16_D(DBF),
|
||||
class DF = hn::Repartition<float, DBF>, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update1(DBF /*dbf*/, const VBF w0, const VBF v0, VF& sum0,
|
||||
VF& comp0) const {
|
||||
const DF df;
|
||||
const VF prod0 = WidenMulPairwiseAdd(df, w0, v0);
|
||||
|
||||
VF serr0;
|
||||
sum0 = TwoSums(df, prod0, sum0, serr0);
|
||||
|
||||
comp0 = hn::Add(comp0, serr0);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
|
||||
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
|
||||
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
|
||||
return ReduceCascadedSums(df, sum0, comp0);
|
||||
}
|
||||
};
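// For reference, a scalar sketch of the same compensated update (Dot2 from
// Ogita/Rump/Oishi): error-free products and sums whose rounding errors are
// accumulated separately and added back once at the end. Illustrative only
// (hypothetical helper, f32 inputs, no under/overflow); the vector kernel
// above is what Dot() actually uses.
static inline float ScalarCompensatedDotSketch(const float* HWY_RESTRICT w,
                                               const float* HWY_RESTRICT v,
                                               size_t num) {
  float sum = 0.0f, comp = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    const float prod = w[i] * v[i];
    // Rounding error of the product; exact because a double holds the full
    // product of two f32 significands.
    const float perr = static_cast<float>(
        static_cast<double>(w[i]) * static_cast<double>(v[i]) -
        static_cast<double>(prod));
    // TwoSums (Knuth): sum2 + serr is exactly prod + sum.
    const float sum2 = sum + prod;
    const float prod2 = sum2 - sum;
    const float serr = (prod - prod2) + (sum - (sum2 - prod2));
    sum = sum2;
    comp += perr + serr;
  }
  return sum + comp;
}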
|
||||
|
||||
// Default kernel
|
||||
template <class D, typename WeightT, typename VecT>
|
||||
HWY_INLINE float Dot(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned, size_t num) {
|
||||
return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
|
||||
DotKernelCompensated());
|
||||
}
|
||||
|
||||
// Adapter for a single pointer, no bounds checking.
|
||||
template <typename WeightT, typename VecT>
|
||||
HWY_INLINE float Dot(const WeightT* HWY_RESTRICT w, const VecT* vec_aligned,
|
||||
size_t num) {
|
||||
const hn::ScalableTag<VecT> d;
|
||||
return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec_aligned, num);
|
||||
}
|
||||
|
||||
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
||||
template <size_t kCapacity, typename VecT>
|
||||
HWY_INLINE float Dot(const std::array<float, kCapacity>& w, size_t w_ofs,
|
||||
const VecT* vec_aligned, size_t num) {
|
||||
const hn::ScalableTag<VecT> d;
|
||||
return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec_aligned, num);
|
||||
}
|
||||
|
||||
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
|
||||
template <typename MatT, size_t kCapacity, typename VecT>
|
||||
HWY_INLINE float Dot(const CompressedArray<MatT, kCapacity>& w, size_t w_ofs,
|
||||
const VecT* vec_aligned, size_t num) {
|
||||
const hn::ScalableTag<VecT> d;
|
||||
return w.scale() *
|
||||
Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec_aligned, num);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
797
ops/dot_test.cc
797
ops/dot_test.cc
|
|
@ -21,15 +21,18 @@
|
|||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::swap
|
||||
#include <algorithm> // std::swap, std::sort
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/test_util.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/stats.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
|
|
@ -40,17 +43,526 @@
|
|||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "compression/test_util-inl.h"
|
||||
#include "ops/dot-inl.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
#include "hwy/profiler.h" // also uses SIMD
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
using Array = hwy::AlignedFreeUniquePtr<float[]>;
|
||||
//------------------------------------------------------------------------------
|
||||
// Dot product variants
|
||||
|
||||
// All combinations of {*, TwoProducts} x {+, FastTwoSums, TwoSums}.
|
||||
|
||||
struct DotKernelNaive {
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& /*comp0*/, VF& /*comp1*/, VF& /*comp2*/,
|
||||
VF& /*comp3*/) const {
|
||||
sum0 = hn::MulAdd(w0, v0, sum0);
|
||||
sum1 = hn::MulAdd(w1, v1, sum1);
|
||||
sum2 = hn::MulAdd(w2, v2, sum2);
|
||||
sum3 = hn::MulAdd(w3, v3, sum3);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
||||
VF& /*comp0*/) const {
|
||||
sum0 = hn::MulAdd(w0, v0, sum0);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& /*comp0*/, VF& /*comp1*/, VF& /*comp2*/,
|
||||
VF& /*comp3*/) const {
|
||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||
sum0 = hn::Add(sum0, sum1);
|
||||
sum2 = hn::Add(sum2, sum3);
|
||||
sum0 = hn::Add(sum0, sum2);
|
||||
return hn::ReduceSum(df, sum0);
|
||||
}
|
||||
};
|
||||
template <class D, typename WeightT, typename VecT>
|
||||
HWY_INLINE float DotNaive(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned, size_t num) {
|
||||
return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelNaive());
|
||||
}
|
||||
|
||||
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum.
|
||||
struct DotKernelKahan {
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
// Add compensation from last iteration, which is an approximation of the
|
||||
// running error.
|
||||
const VF prod0 = hn::MulAdd(w0, v0, comp0);
|
||||
const VF prod1 = hn::MulAdd(w1, v1, comp1);
|
||||
const VF prod2 = hn::MulAdd(w2, v2, comp2);
|
||||
const VF prod3 = hn::MulAdd(w3, v3, comp3);
|
||||
|
||||
sum0 = FastTwoSums(df, sum0, prod0, comp0);
|
||||
sum1 = FastTwoSums(df, sum1, prod1, comp1);
|
||||
sum2 = FastTwoSums(df, sum2, prod2, comp2);
|
||||
sum3 = FastTwoSums(df, sum3, prod3, comp3);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
||||
VF& comp0) const {
|
||||
const VF prod0 = hn::MulAdd(w0, v0, comp0);
|
||||
sum0 = FastTwoSums(df, sum0, prod0, comp0);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||
comp0 = hn::Add(comp0, comp1);
|
||||
comp2 = hn::Add(comp2, comp3);
|
||||
VF sum_err = hn::Add(comp0, comp2);
|
||||
UpdateCascadedSums(df, sum1, sum0, sum_err);
|
||||
UpdateCascadedSums(df, sum3, sum2, sum_err);
|
||||
UpdateCascadedSums(df, sum2, sum0, sum_err);
|
||||
return ReduceCascadedSums(df, sum0, sum_err);
|
||||
}
|
||||
};
|
||||
template <class D, typename WeightT, typename VecT>
|
||||
HWY_INLINE float DotKahan(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned, size_t num) {
|
||||
return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelKahan());
|
||||
}
|
||||
|
||||
template <class D, typename WeightT, typename VecT>
|
||||
HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
|
||||
size_t w_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
size_t num) {
|
||||
return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
|
||||
DotKernelCompensated());
|
||||
}
|
||||
|
||||
// Like Compensated, but FastTwoSum instead of TwoSum.
|
||||
struct DotKernelTwoProdFast {
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
VF perr0, perr1, perr2, perr3;
|
||||
const VF prod0 = TwoProducts(df, w0, v0, perr0);
|
||||
const VF prod1 = TwoProducts(df, w1, v1, perr1);
|
||||
const VF prod2 = TwoProducts(df, w2, v2, perr2);
|
||||
const VF prod3 = TwoProducts(df, w3, v3, perr3);
|
||||
|
||||
VF serr0, serr1, serr2, serr3;
|
||||
sum0 = FastTwoSums(df, sum0, prod0, serr0);
|
||||
sum1 = FastTwoSums(df, sum1, prod1, serr1);
|
||||
sum2 = FastTwoSums(df, sum2, prod2, serr2);
|
||||
sum3 = FastTwoSums(df, sum3, prod3, serr3);
|
||||
|
||||
comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
|
||||
comp1 = hn::Add(comp1, hn::Add(perr1, serr1));
|
||||
comp2 = hn::Add(comp2, hn::Add(perr2, serr2));
|
||||
comp3 = hn::Add(comp3, hn::Add(perr3, serr3));
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
||||
VF& comp0) const {
|
||||
VF perr0;
|
||||
const VF prod0 = TwoProducts(df, w0, v0, perr0);
|
||||
|
||||
VF serr0;
|
||||
sum0 = FastTwoSums(df, sum0, prod0, serr0);
|
||||
|
||||
comp0 = hn::Add(comp0, hn::Add(perr0, serr0));
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
|
||||
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
|
||||
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
|
||||
return ReduceCascadedSums(df, sum0, comp0);
|
||||
}
|
||||
};
|
||||
template <class D, typename WeightT, typename VecT>
|
||||
HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
|
||||
size_t w_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
size_t num) {
|
||||
return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
|
||||
DotKernelTwoProdFast());
|
||||
}
|
||||
|
||||
// Like Compensated, but without TwoProducts. Vs Kahan, upgrades FastTwoSums
|
||||
// to TwoSums.
|
||||
struct DotKernelMulTwoSum {
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
const VF prod0 = hn::Mul(w0, v0);
|
||||
const VF prod1 = hn::Mul(w1, v1);
|
||||
const VF prod2 = hn::Mul(w2, v2);
|
||||
const VF prod3 = hn::Mul(w3, v3);
|
||||
|
||||
VF serr0, serr1, serr2, serr3;
|
||||
sum0 = TwoSums(df, prod0, sum0, serr0);
|
||||
sum1 = TwoSums(df, prod1, sum1, serr1);
|
||||
sum2 = TwoSums(df, prod2, sum2, serr2);
|
||||
sum3 = TwoSums(df, prod3, sum3, serr3);
|
||||
|
||||
comp0 = hn::Add(comp0, serr0);
|
||||
comp1 = hn::Add(comp1, serr1);
|
||||
comp2 = hn::Add(comp2, serr2);
|
||||
comp3 = hn::Add(comp3, serr3);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
||||
VF& comp0) const {
|
||||
const VF prod0 = hn::Mul(w0, v0);
|
||||
|
||||
VF serr0;
|
||||
sum0 = TwoSums(df, prod0, sum0, serr0);
|
||||
|
||||
comp0 = hn::Add(comp0, serr0);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
|
||||
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
|
||||
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
|
||||
return ReduceCascadedSums(df, sum0, comp0);
|
||||
}
|
||||
};
|
||||
template <class D, typename WeightT, typename VecT>
|
||||
HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
|
||||
size_t w_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
size_t num) {
|
||||
return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelMulTwoSum());
|
||||
}
|
||||
|
||||
// Like Compensated, but only TwoProducts, no [Fast]TwoSums. This is only 10%
|
||||
// better (mul) than naive.
|
||||
struct DotKernelTwoProdAdd {
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||
const VF w3, const VF v0, const VF v1, const VF v2,
|
||||
const VF v3, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
VF perr0, perr1, perr2, perr3;
|
||||
const VF prod0 = TwoProducts(df, w0, v0, perr0);
|
||||
const VF prod1 = TwoProducts(df, w1, v1, perr1);
|
||||
const VF prod2 = TwoProducts(df, w2, v2, perr2);
|
||||
const VF prod3 = TwoProducts(df, w3, v3, perr3);
|
||||
|
||||
sum0 = hn::Add(sum0, prod0);
|
||||
sum1 = hn::Add(sum1, prod1);
|
||||
sum2 = hn::Add(sum2, prod2);
|
||||
sum3 = hn::Add(sum3, prod3);
|
||||
|
||||
comp0 = hn::Add(comp0, perr0);
|
||||
comp1 = hn::Add(comp1, perr1);
|
||||
comp2 = hn::Add(comp2, perr2);
|
||||
comp3 = hn::Add(comp3, perr3);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
||||
VF& comp0) const {
|
||||
VF perr0;
|
||||
const VF prod0 = TwoProducts(df, w0, v0, perr0);
|
||||
|
||||
sum0 = hn::Add(sum0, prod0);
|
||||
|
||||
comp0 = hn::Add(comp0, perr0);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
|
||||
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
|
||||
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
|
||||
return ReduceCascadedSums(df, sum0, comp0);
|
||||
}
|
||||
};
|
||||
template <class D, typename WeightT, typename VecT>
|
||||
HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
|
||||
size_t w_ofs,
|
||||
const VecT* HWY_RESTRICT vec_aligned,
|
||||
size_t num) {
|
||||
return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
|
||||
DotKernelTwoProdAdd());
|
||||
}
|
||||
|
||||
enum { // alphabetical order
|
||||
kAddTwoProd,
|
||||
kAddTwoSum,
|
||||
kCompensated,
|
||||
kKahan,
|
||||
kNaive,
|
||||
kOnlyTwoProd,
|
||||
|
||||
kVariants
|
||||
};
|
||||
|
||||
const char* VariantName(size_t variant) {
|
||||
switch (variant) {
|
||||
case kAddTwoProd:
|
||||
return "add2prod";
|
||||
case kAddTwoSum:
|
||||
return "add2sum";
|
||||
case kCompensated:
|
||||
return "comp";
|
||||
case kKahan:
|
||||
return "kahan";
|
||||
case kNaive:
|
||||
return "naive";
|
||||
case kOnlyTwoProd:
|
||||
return "only2prod";
|
||||
default:
|
||||
HWY_ABORT("Unknown variant %zu", variant);
|
||||
return "?";
|
||||
}
|
||||
}
|
||||
|
||||
template <class D, typename WeightT, typename VecT>
|
||||
float CallDot(D d, size_t variant, const PackedSpan<const WeightT>& w,
|
||||
size_t w_ofs, const VecT* HWY_RESTRICT v, size_t num) {
|
||||
switch (variant) {
|
||||
case kAddTwoProd:
|
||||
return DotTwoProdFast(d, w, 0, v, num);
|
||||
case kAddTwoSum:
|
||||
return DotMulTwoSum(d, w, 0, v, num);
|
||||
case kCompensated:
|
||||
return DotCompensated(d, w, 0, v, num);
|
||||
case kKahan:
|
||||
return DotKahan(d, w, 0, v, num);
|
||||
case kNaive:
|
||||
return DotNaive(d, w, 0, v, num);
|
||||
case kOnlyTwoProd:
|
||||
return DotTwoProdAdd(d, w, 0, v, num);
|
||||
default:
|
||||
HWY_ABORT("Unknown variant %zu", variant);
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns result accurate to 1.5 ulp, assuming `num` < 2^(52-23), no overflow,
|
||||
// and round to nearest. See "Accurate and efficient floating point summation".
|
||||
// Much too slow to be useful. Kept separate from the above kernels because it
|
||||
// is used to compute their error.
|
||||
template <typename WeightT, typename VecT>
|
||||
float ExactDot(const WeightT* HWY_RESTRICT w, const VecT* HWY_RESTRICT v,
|
||||
size_t num, double* HWY_RESTRICT buf) {
|
||||
PROFILER_FUNC;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
buf[i] =
|
||||
hwy::ConvertScalarTo<double>(w[i]) * hwy::ConvertScalarTo<double>(v[i]);
|
||||
}
|
||||
// Sort by decreasing magnitude (not supported by VQSort).
|
||||
std::sort(buf, buf + num,
|
||||
[](double a, double b) { return std::abs(a) > std::abs(b); });
|
||||
double sum = 0.0;
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
sum += buf[i];
|
||||
}
|
||||
return static_cast<float>(sum);
|
||||
}
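// Informally: the double products are exact, and summing values sorted by
// decreasing magnitude in a format whose mantissa is at least log2(num) bits
// wider than the inputs' (hence num < 2^(52-23)) keeps the total within the
// 1.5 ulp bound stated above; only the final float conversion remains.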
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
class DotStats {
|
||||
static float Ratio(float a, float b) {
|
||||
// If 0, we would return infinity, which messes up the statistics.
|
||||
if (a == 0.0f || b == 0.0f) return 1.0f;
|
||||
// Absolute value because a sign change and 4x difference would
|
||||
// otherwise return the smaller ratio 0.25.
|
||||
return HWY_MAX(std::abs(a / b), std::abs(b / a));
|
||||
}
|
||||
|
||||
public:
|
||||
DotStats() {
|
||||
for (size_t i = 0; i < kVariants; ++i) {
|
||||
max_muls[i] = 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
static void PrintStats(const char* caption, size_t variant,
|
||||
const hwy::Stats& stats) {
|
||||
fprintf(stderr, "%s %9s %s\n", caption, VariantName(variant),
|
||||
stats.ToString(/*exclude=*/0).c_str());
|
||||
}
|
||||
|
||||
// Call once per rep.
|
||||
void NotifyRep(size_t num, double cond, float dot_exact,
|
||||
float dots[kVariants]) {
|
||||
s_cond.Notify(cond);
|
||||
const float mul_tol = cond > 1E8 ? 1.5f : cond > 1E7 ? 1.1f : 1.01f;
|
||||
|
||||
float muls[kVariants];
|
||||
float l1s[kVariants];
|
||||
uint32_t ulps[kVariants];
|
||||
for (size_t i = 0; i < kVariants; ++i) {
|
||||
muls[i] = Ratio(dots[i], dot_exact);
|
||||
max_muls[i] = HWY_MAX(max_muls[i], muls[i]);
|
||||
|
||||
l1s[i] = std::abs(dots[i] - dot_exact);
|
||||
s_l1s[i].Notify(l1s[i]);
|
||||
|
||||
ulps[i] = hwy::detail::ComputeUlpDelta(dots[i], dot_exact);
|
||||
s_ulps[i].Notify(ulps[i]);
|
||||
}
|
||||
|
||||
if (muls[kKahan] > mul_tol || l1s[kKahan] > 0.1f ||
|
||||
muls[kNaive] + 1E-3f < muls[kKahan] || ulps[kCompensated] > 10) {
|
||||
fprintf(stderr, "num %2zu cond %.1E exact %.8f\n", num, cond, dot_exact);
|
||||
for (size_t i = 0; i < kVariants; ++i) {
|
||||
fprintf(stderr, " %9s dot %11.8f mul %.8f\n", VariantName(i), dots[i],
|
||||
muls[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Call after all reps.
|
||||
void NotifyRatios() {
|
||||
for (size_t i = 0; i < kVariants; ++i) {
|
||||
s_muls[i].Notify(max_muls[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void NotifyTimes(double times[kVariants]) {
|
||||
for (size_t i = 0; i < kVariants; ++i) {
|
||||
s_times[i].Notify(times[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void Assimilate(const DotStats& other) {
|
||||
s_cond.Assimilate(other.s_cond);
|
||||
for (size_t i = 0; i < kVariants; ++i) {
|
||||
s_muls[i].Assimilate(other.s_muls[i]);
|
||||
s_l1s[i].Assimilate(other.s_l1s[i]);
|
||||
s_ulps[i].Assimilate(other.s_ulps[i]);
|
||||
s_times[i].Assimilate(other.s_times[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const {
|
||||
PrintStats("cond", 0, s_cond);
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
PrintStats("mul", variant, s_muls[variant]);
|
||||
}
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
PrintStats(" l1", variant, s_l1s[variant]);
|
||||
}
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
PrintStats("ulp", variant, s_ulps[variant]);
|
||||
}
|
||||
if (s_times[0].Count()) {
|
||||
for (size_t variant = 0; variant < kVariants; ++variant) {
|
||||
PrintStats("time", variant, s_times[variant]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Check() const {
|
||||
CheckMuls();
|
||||
CheckL1();
|
||||
CheckUlps();
|
||||
|
||||
// We do not check times because they can be noisy/nonportable, but
|
||||
// `kAddTwoProd` is only about 10% slower than `kKahan`, and about 1.5 times
|
||||
// as fast as `kCompensated`.
|
||||
}
|
||||
|
||||
private:
|
||||
// Factor by which the approximate result is off; larger is worse.
|
||||
void CheckMuls() const {
|
||||
// Compensated is very accurate.
|
||||
HWY_ASSERT(s_muls[kCompensated].Min() <= 1.0f + 2E-6f);
|
||||
HWY_ASSERT(s_muls[kCompensated].Max() <= 1.0f + 2E-5f);
|
||||
|
||||
// Naive and OnlyTwoProd are considerably worse. >10x is for narrower
|
||||
// vectors, compared to AVX-512. GeometricMean overflows, must use Mean.
|
||||
HWY_ASSERT(gcpp::IsInside(1.01, 16.0, s_muls[kNaive].Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(1.01, 13.0, s_muls[kOnlyTwoProd].Mean()));
|
||||
|
||||
// Kahan (FastTwoSum) is decent:
|
||||
HWY_ASSERT(gcpp::IsInside(1.001, 4.1, s_muls[kKahan].Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(1.001f, 14.1f, s_muls[kKahan].Max()));
|
||||
HWY_ASSERT(gcpp::IsInside(1.0, 1.6, s_muls[kKahan].GeometricMean()));
|
||||
|
||||
// But can be considerably improved via TwoProducts:
|
||||
HWY_ASSERT(gcpp::IsInside(1.0005, 1.5, s_muls[kAddTwoProd].Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(1.001f, 2.3f, s_muls[kAddTwoProd].Max()));
|
||||
HWY_ASSERT(gcpp::IsInside(1.0, 1.2, s_muls[kAddTwoProd].GeometricMean()));
|
||||
// Updating Kahan's FastTwoSums to TwoSums is not quite as helpful.
|
||||
HWY_ASSERT(gcpp::IsInside(1.0005, 2.2, s_muls[kAddTwoSum].Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(1.0, 1.3, s_muls[kAddTwoProd].GeometricMean()));
|
||||
}
|
||||
|
||||
// Absolute error; larger is worse.
|
||||
void CheckL1() const {
|
||||
// Compensated is very accurate.
|
||||
HWY_ASSERT(s_l1s[kCompensated].Min() == 0.0f);
|
||||
HWY_ASSERT(s_l1s[kCompensated].Max() <= 3E-7f);
|
||||
|
||||
// Naive and OnlyTwoProd are considerably higher, but not huge.
|
||||
HWY_ASSERT(gcpp::IsInside(1E-3, 2E-2, s_l1s[kNaive].Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(1E-3, 2E-2, s_l1s[kOnlyTwoProd].Mean()));
|
||||
|
||||
// Kahan (FastTwoSum) is decent:
|
||||
HWY_ASSERT(gcpp::IsInside(4.5E-4, 1E-3, s_l1s[kKahan].Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(1.1E-3f, 3.2E-3f, s_l1s[kKahan].Max()));
|
||||
|
||||
// But can be nearly halved via TwoProducts:
|
||||
HWY_ASSERT(gcpp::IsInside(2.5E-4, 8E-4, s_l1s[kAddTwoProd].Mean()));
|
||||
HWY_ASSERT(gcpp::IsInside(4E-4f, 2.0E-3f, s_l1s[kAddTwoProd].Max()));
|
||||
// Updating Kahan's FastTwoSums to TwoSums does help a bit.
|
||||
HWY_ASSERT(gcpp::IsInside(1.5E-4, 5.2E-4, s_l1s[kAddTwoSum].Mean()));
|
||||
}
|
||||
|
||||
// Units in the last place; larger is worse.
|
||||
void CheckUlps() const {
|
||||
HWY_ASSERT(s_ulps[kCompensated].Max() <= 250.0f);
|
||||
|
||||
HWY_ASSERT(s_ulps[kNaive].Max() <= 4E9f);
|
||||
HWY_ASSERT(s_ulps[kOnlyTwoProd].Max() <= 3E9f);
|
||||
|
||||
HWY_ASSERT(s_ulps[kKahan].Max() <= 4E7f);
|
||||
HWY_ASSERT(s_ulps[kAddTwoProd].Max() <= 1E7f);
|
||||
HWY_ASSERT(s_ulps[kAddTwoSum].Max() <= 2.5E7f);
|
||||
}
|
||||
|
||||
hwy::Stats s_cond;
|
||||
|
||||
// Relative error
|
||||
float max_muls[kVariants];
|
||||
hwy::Stats s_muls[kVariants];
|
||||
|
||||
hwy::Stats s_l1s[kVariants]; // Absolute error
|
||||
|
||||
hwy::Stats s_ulps[kVariants]; // Only relevant for small cond
|
||||
hwy::Stats s_times[kVariants];
|
||||
};
|
||||
|
||||
// Returns normalized value in [-1, 1).
float RandomFloat(std::mt19937& rng) {

@@ -64,51 +576,77 @@ float RandomFloat(std::mt19937& rng) {
return f;
}

// Based on Algorithm 6.1 from "Accurate Sum and Dot Product".
// `num` is the size of a, b[, and buf] and must be larger than 2 and even.
void GenerateIllConditionedInputs(double target_cond, size_t num,
float* HWY_RESTRICT a, float* HWY_RESTRICT b,
double* HWY_RESTRICT buf, std::mt19937& rng) {
// `raw` holds the decompressed values, so that the test measures only the
// error from the Dot algorithms, not the compression.
template <typename Packed>
void GenerateWellConditionedInputs(const size_t num, float* HWY_RESTRICT raw,
std::mt19937& rng,
const PackedSpan<Packed>& packed,
CompressWorkingSet& work) {
std::uniform_int_distribution<int> e_dist(0, 6);

for (size_t i = 0; i < num; ++i) {
raw[i] = RandomFloat(rng) * (1 << e_dist(rng));
}

if (IsCompressed<Packed>()) {
// Don't care about the original range.
(void)ScaleWeights(raw, num);
}

hwy::ThreadPool pool(0);  // num is too small for parallelization
const size_t packed_ofs = 0;
Compress(raw, num, work, packed, packed_ofs, pool);

const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeConst(packed), packed_ofs, raw, num);
}

// Returns the actual condition number. Based on Algorithm 6.1 from "Accurate
// Sum and Dot Product". `num` is the (arbitrary) size of w, v, and buf.
template <typename WeightT, typename VecT>
double GenerateIllConditionedInputs(const size_t num, WeightT* w,
VecT* HWY_RESTRICT v, std::mt19937& rng) {
PROFILER_FUNC;
HWY_ASSERT(target_cond >= 1.0);
HWY_ASSERT(num % 2 == 0);
const size_t half = num / 2;
const size_t half = HWY_MAX(1, num / 2);  // generate at least one random
HWY_DASSERT(half != 0);

const hn::ScalableTag<float> df;

const int max_exp = static_cast<int>(std::log2(target_cond) / 2.0);
const PackedSpan<WeightT> w_span(w, num);

// Regardless of WeightT and VecT, we will accumulate into float. Multiplying
// two maximal inputs and accumulating `num` times is enough for some loss of
// precision and condition numbers between 1E6-1E9, which is what we see for
// Attention Dot and `RMSNormMul`.
const int max_exp = 5;
std::uniform_int_distribution<int> e_dist(0, max_exp);

// First half: random exponents and mantissas
for (size_t i = 0; i < half; ++i) {
// Ensure the min and max exponents are used.
const int e = i == 0 ? 0 : i == 1 ? max_exp : e_dist(rng);
a[i] = RandomFloat(rng) * (1 << e);
b[i] = RandomFloat(rng) * (1 << e);
w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e));
v[i] = hwy::ConvertScalarTo<VecT>(RandomFloat(rng) * (1 << e));
}

// Zero-init second half for DotExact
for (size_t i = half; i < num; ++i) {
a[i] = 0.0f;
b[i] = 0.0f;
}

const float a_exp_step = max_exp / (half - 1);
const float a_exp_step =
num == half ? 0.0f : static_cast<float>(max_exp) / (num - half);
float a_exp = max_exp;  // max_exp downto 0
for (size_t i = half; i < num; ++i, a_exp -= a_exp_step) {
const int e = static_cast<int>(a_exp);
HWY_DASSERT(e >= 0);
a[i] = RandomFloat(rng) * (1 << e);
w[i] = hwy::ConvertScalarTo<WeightT>(RandomFloat(rng) * (1 << e));
const float r = RandomFloat(rng) * (1 << e);
if (a[i] == 0.0f) {
b[i] = 0.0f;
if (hwy::ConvertScalarTo<float>(w[i]) == 0.0f) {
v[i] = hwy::ConvertScalarTo<VecT>(0.0f);
} else {
// This is called >100K times. CompensatedDot is much faster than ExactDot
// and just about as accurate, but requires multiples of two vectors.
// const float exact = ExactDot(a, b, i, buf);
(void)buf;
const size_t padded = hwy::RoundUpTo(i, 2 * hn::Lanes(df));
const float exact = CompensatedDot(df, a, /*w_ofs=*/0, b, padded);
b[i] = r - exact / a[i];
// This is called >100K times. DotCompensated is much faster than ExactDot
// and just about as accurate.
const float exact =
DotCompensated(df, MakeConst(w_span), /*w_ofs=*/0, v, i);
v[i] = hwy::ConvertScalarTo<VecT>(
r - exact / hwy::ConvertScalarTo<float>(w[i]));
}
}

@@ -118,20 +656,106 @@ void GenerateIllConditionedInputs(double target_cond, size_t num,
std::uniform_int_distribution<size_t> dist(0, i);
const size_t j = dist(rng);

std::swap(a[i], a[j]);
std::swap(b[i], b[j]);
std::swap(w[i], w[j]);
std::swap(v[i], v[j]);
}

return ConditionNumber(w, v, num);
}
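Aside (assumption, for context): `ConditionNumber` is presumably the usual dot-product condition number from Ogita, Rump & Oishi, cond = 2 * sum(|w[i] * v[i]|) / |sum(w[i] * v[i])|. A scalar sketch in double precision, with purely illustrative names:

// Hypothetical reference, not the library's implementation.
double ConditionNumberSketch(const float* w, const float* v, size_t num) {
  double sum_abs = 0.0, sum = 0.0;
  for (size_t i = 0; i < num; ++i) {
    const double prod = static_cast<double>(w[i]) * static_cast<double>(v[i]);
    sum_abs += std::abs(prod);
    sum += prod;
  }
  return 2.0 * sum_abs / std::abs(sum);
}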
template <typename T, size_t kNum>
void PrintStats(const char* caption, const std::array<T, kNum>& values) {
hwy::Stats stats;
for (T t : values) {
stats.Notify(static_cast<float>(t));
// Runs all Dot algorithms for all short lengths and all Packed/raw types
// on well-conditioned inputs, and ensures the results are close to exact.
template <typename Packed>
struct TestShortDotsT {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
const size_t N = hn::Lanes(d);
const hn::ScalableTag<float> df;  // for CallDot

CompressWorkingSet work;
std::mt19937 rng;
rng.seed(12345);

hwy::Stats s_l1[kVariants];

for (size_t num = 1; num <= 5 * N; ++num) {
// GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`,
// hence they require padding to one vector.
const size_t padded_num = hwy::RoundUpTo(num, N);
const size_t packed_num = CompressedArrayElements<Packed>(num);
RowVectorBatch<float> raw_w(1, padded_num);
RowVectorBatch<float> raw_v(1, padded_num);
RowVectorBatch<Packed> weights(1, packed_num);
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
RowVectorBatch<T> vectors(1, num);
const PackedSpan<T> v(vectors.Batch(0), num);

RowVectorBatch<double> bufs(1, num);
double* HWY_RESTRICT buf = bufs.Batch(0);

for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
GenerateWellConditionedInputs(num, raw_w.All(), rng, w, work);
GenerateWellConditionedInputs(num, raw_v.All(), rng, v, work);

const float dot_exact = ExactDot(raw_w.All(), raw_v.All(), num, buf);
float dots[kVariants];
for (size_t variant = 0; variant < kVariants; ++variant) {
dots[variant] = CallDot(df, variant, MakeConst(w), 0, v.ptr, num);

const float l1 = hwy::ScalarAbs(dots[variant] - dot_exact);
s_l1[variant].Notify(l1);
}
}
}

// Avoid extra output for partial vectors.
if (hn::detail::IsFull(d)) {
for (size_t variant = 0; variant < kVariants; ++variant) {
DotStats::PrintStats("l1", variant, s_l1[variant]);
}
}

// Verify the dot products are plausible. This is only to verify
// correctness, not to differentiate between the variants.
double expected_l1[kVariants];
// Tolerances are much lower for compressed inputs: the more limited set of
// values seems to reduce roundoff.
constexpr bool kCompressed = IsCompressed<Packed>();
expected_l1[kAddTwoProd] = kCompressed ? 1.5E-6 : 5E-5;
expected_l1[kAddTwoSum] = kCompressed ? 1.5E-6 : 6E-5;
expected_l1[kCompensated] = kCompressed ? 1.5E-6 : 4E-5;
expected_l1[kKahan] = kCompressed ? 1.5E-6 : 7E-5;
expected_l1[kNaive] = kCompressed ? 4E-6 : 1.5E-4;
expected_l1[kOnlyTwoProd] = kCompressed ? 1.5E-6 : 6E-5;

for (size_t variant = 0; variant < kVariants; ++variant) {
HWY_ASSERT(s_l1[variant].Min() >= 0.0f);
HWY_ASSERT(s_l1[variant].Max() <= 1.5E-3f);
if (s_l1[variant].Mean() > expected_l1[variant]) {
HWY_ABORT("%s -> %s: %s mean l1 %.5E > %.5E\n", TypeName<Packed>(),
TypeName<T>(), VariantName(variant), s_l1[variant].Mean(),
expected_l1[variant]);
}
}
}
fprintf(stderr, "%s %s\n", caption, stats.ToString().c_str());
};

void TestAllShortDots() { ForeachPackedAndRawType<TestShortDotsT>(); }

// Excludes outliers; we might not have enough samples for a reliable mode.
double TrimmedMean(double* seconds, size_t num) {
std::sort(seconds, seconds + num);
double sum = 0;
int count = 0;
for (size_t i = num / 4; i < num / 2; ++i) {
sum += seconds[i];
count += 1;
}
return sum / count;
}
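Note (added for clarity): after sorting, TrimmedMean averages only the second quartile, i.e. indices [num / 4, num / 2). For example, with num = 8 sorted timings it averages seconds[2] and seconds[3], discarding the two fastest and the four slowest samples.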
// Tests W=float, V=float for one large size and many reps on ill-conditioned
// inputs. Also includes benchmarking.
void TestAllDot() {
// Skip EMU128 and old x86, include SSE4 because it tests the non-FMA path.
if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 ||

@@ -139,72 +763,64 @@ void TestAllDot() {
return;
}

hn::ScalableTag<float> df;
const hn::ScalableTag<float> df;

constexpr size_t kMaxThreads = 8;
std::mt19937 rngs[kMaxThreads];
for (size_t i = 0; i < kMaxThreads; ++i) {
constexpr size_t kMaxWorkers = 15;
std::mt19937 rngs[kMaxWorkers];
for (size_t i = 0; i < kMaxWorkers; ++i) {
rngs[i].seed(12345 + 65537 * i);
}

constexpr size_t kReps = hn::AdjustedReps(200);
constexpr size_t kReps = hn::AdjustedReps(40);
const size_t num = 24 * 1024;
PerClusterPools pools(/*max_clusters=*/1, kMaxThreads, /*pin=*/1);
RowVectorBatch<float> a(kMaxThreads, num);
RowVectorBatch<float> b(kMaxThreads, num);
RowVectorBatch<double> bufs(kMaxThreads, num);

const double target_cond = 1e12;
std::array<double, kReps> conds;
std::array<uint32_t, kReps> ulps_fast;
std::array<uint32_t, kReps> ulps_comp;
std::array<double, kReps> t_fast;
std::array<double, kReps> t_comp;

constexpr size_t kTimeReps = 3;
PerClusterPools pools(/*max_clusters=*/1, kMaxWorkers - 1, /*pin=*/1);
RowVectorBatch<float> a(kMaxWorkers, num);
RowVectorBatch<float> b(kMaxWorkers, num);
RowVectorBatch<double> bufs(kMaxWorkers, num);
std::array<DotStats, kMaxWorkers> all_stats;

pools.Inner(0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Batch(thread);
float* HWY_RESTRICT pb = b.Batch(thread);
double* HWY_RESTRICT buf = bufs.Batch(thread);
GenerateIllConditionedInputs(target_cond, num, pa, pb, buf, rngs[thread]);
conds[rep] = ConditionNumber(df, pa, pb, num);
const PackedSpan<const float> a_span(pa, num);
DotStats& stats = all_stats[thread];
const double cond = GenerateIllConditionedInputs(num, pa, pb, rngs[thread]);

const float dot_exact = ExactDot(pa, pb, num, buf);

float dot_fast = 0.0f;
float dot_comp = 0.0f;

double elapsed = hwy::HighestValue<double>();
for (int rep = 0; rep < kTimeReps; ++rep) {
const double start = hwy::platform::Now();
dot_fast += SimpleDot(df, pa, 0, pb, num);
elapsed = HWY_MIN(elapsed, hwy::platform::Now() - start);
float dots[kVariants] = {};
double times[kVariants] = {};
for (size_t variant = 0; variant < kVariants; ++variant) {
constexpr size_t kTimeReps = hn::AdjustedReps(10);
std::array<double, kTimeReps> elapsed;
for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
const double start = hwy::platform::Now();
dots[variant] += CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
hwy::PreventElision(*pa);
elapsed[time_rep] = hwy::platform::Now() - start;
}
dots[variant] /= kTimeReps;
times[variant] = TrimmedMean(elapsed.data(), kTimeReps);
}
dot_fast /= kTimeReps;
t_fast[rep] = elapsed;

elapsed = hwy::HighestValue<double>();
for (size_t r = 0; r < kTimeReps; ++r) {
const double start = hwy::platform::Now();
dot_comp += CompensatedDot(df, pa, /*w_ofs=*/0, pb, num);
elapsed = HWY_MIN(elapsed, hwy::platform::Now() - start);
}
dot_comp /= kTimeReps;
t_comp[rep] = elapsed;

ulps_fast[rep] = hwy::detail::ComputeUlpDelta(dot_fast, dot_exact);
ulps_comp[rep] = hwy::detail::ComputeUlpDelta(dot_comp, dot_exact);
fprintf(stderr, "cond %.1E: %15.7E %15.7E %15.7E ulp %5u %1u\n", conds[rep],
dot_exact, dot_fast, dot_comp, ulps_fast[rep], ulps_comp[rep]);
stats.NotifyTimes(times);
stats.NotifyRep(num, cond, dot_exact, dots);
stats.NotifyRatios();
});

DotStats& stats = all_stats[0];
for (size_t i = 1; i < kMaxWorkers; ++i) {
stats.Assimilate(all_stats[i]);
}
static bool once = true;
if (once) {
once = false;
stats.Print();
}
stats.Check();

PROFILER_PRINT_RESULTS();
PrintStats("cond", conds);
PrintStats("ulp fast", ulps_fast);
PrintStats("ulp comp", ulps_comp);
PrintStats("t fast", t_fast);
PrintStats("t comp", t_comp);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)

@@ -216,6 +832,7 @@ HWY_AFTER_NAMESPACE();

namespace gcpp {
HWY_BEFORE_TEST(DotTest);
HWY_EXPORT_AND_TEST_P(DotTest, TestAllShortDots);
HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot);
HWY_AFTER_TEST();
@@ -21,9 +21,8 @@
#include <stddef.h>
#include <stdio.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <algorithm>  // std::max
#include <cmath>  // std::abs
#include <memory>

#include "compression/compress.h"

@@ -37,58 +36,59 @@
// clang-format on
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"
#include "hwy/tests/test_util-inl.h"
// After highway.h
#include "ops/matvec-inl.h"
#include "ops/ops-inl.h"  // MulByConst
#include "hwy/tests/test_util-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

template <size_t kOuter, size_t kInner>
hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
const CompressedArray<float, kOuter * kInner>& mat,
const hwy::AlignedFreeUniquePtr<float[]>& vec,
const hwy::AlignedFreeUniquePtr<float[]>& add) {
hwy::AlignedFreeUniquePtr<float[]> uncompressed_mat =
hwy::AllocateAligned<float>(kOuter * kInner);
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(uncompressed_mat && out);
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;

template <size_t kOuter, size_t kInner, size_t kNum = kOuter * kInner>
FloatPtr SimpleMatVecAdd(const CompressedArray<float, kNum>& mat,
const FloatPtr& vec, const FloatPtr& add) {
FloatPtr raw_mat = hwy::AllocateAligned<float>(kNum);
FloatPtr out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(raw_mat && out);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeSpan(mat.data(), kNum), 0, raw_mat.get(), kNum);
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
out[idx_row] = add[idx_row];
out[idx_row] = 0.0f;
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
out[idx_row] +=
uncompressed_mat[kInner * idx_row + idx_col] * vec[idx_col];
out[idx_row] += raw_mat[kInner * idx_row + idx_col] * vec[idx_col];
}
out[idx_row] *= mat.scale();
out[idx_row] += add[idx_row];
}
return out;
}

template <typename MatT, size_t kOuter, size_t kInner>
CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
hwy::ThreadPool& pool) {
template <typename MatT, size_t kOuter, size_t kInner,
size_t kNum = kOuter * kInner,
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
CompressedArray<MatT, kOuter * kInner> mat;
std::array<float, kOuter * kInner> content;
MatPtr mat = std::make_unique<CompressedArray<MatT, kNum>>();
FloatPtr raw_mat = hwy::AllocateAligned<float>(kNum);
HWY_ASSERT(raw_mat);
const float scale = 1.0f / kInner;
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
for (size_t j = 0; j < kInner; j++) {
content[i * kInner + j] =
raw_mat[i * kInner + j] =
static_cast<float>((i * kInner + j + offset) * scale);
}
});

Compress(content, ws, mat, pool);
mat.set_scale(1.9f);  // Arbitrary value, different from 1.
CompressScaled(raw_mat.get(), kNum, ws, *mat, pool);
mat->set_scale(1.9f);  // Arbitrary value, different from 1.
return mat;
}

template <size_t length>
hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
hwy::AlignedFreeUniquePtr<float[]> vec = hwy::AllocateAligned<float>(length);
FloatPtr GenerateVec(size_t offset) {
FloatPtr vec = hwy::AllocateAligned<float>(length);
HWY_ASSERT(vec);
for (size_t idx = 0; idx < length; idx++) {
vec[idx] = static_cast<float>(idx + offset);

@@ -97,8 +97,7 @@ hwy::AlignedFreeUniquePtr<float[]> GenerateVec(size_t offset) {
}

template <size_t length>
void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
const hwy::AlignedFreeUniquePtr<float[]>& b) {
void AssertClose(const FloatPtr& a, const FloatPtr& b) {
for (size_t idx = 0; idx < length; idx++) {
const float rel_abs_delta = std::abs(a[idx] - b[idx]) /
std::max(std::abs(a[idx]), std::abs(b[idx]));

@@ -111,16 +110,13 @@ void TestMatVecAdd() {
hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
constexpr size_t kOuter = 128 * 3;
constexpr size_t kInner = 128 * 5;
CompressedArray<float, kOuter * kInner> mat =
GenerateMat<float, kOuter, kInner>(0, pool);
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
hwy::AlignedFreeUniquePtr<float[]> expected_out =
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add);
hwy::AlignedFreeUniquePtr<float[]> actual_out =
hwy::AllocateAligned<float>(kOuter);
auto mat = GenerateMat<float, kOuter, kInner>(0, pool);
FloatPtr vec = GenerateVec<kInner>(0);
FloatPtr add = GenerateVec<kOuter>(0);
FloatPtr expected_out = SimpleMatVecAdd<kOuter, kInner>(*mat, vec, add);
FloatPtr actual_out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add && expected_out && actual_out);
MatVecAdd<kOuter, kInner>(mat, 0, vec.get(), add.get(), actual_out.get(),
MatVecAdd<kOuter, kInner>(*mat, 0, vec.get(), add.get(), actual_out.get(),
pool);
AssertClose<kOuter>(actual_out, expected_out);
}

@@ -129,25 +125,20 @@ void TestTwoMatVecAdd() {
hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
constexpr size_t kOuter = 128 * 3;
constexpr size_t kInner = 128 * 5;
CompressedArray<float, kOuter * kInner> mat0 =
GenerateMat<float, kOuter, kInner>(0, pool);
CompressedArray<float, kOuter * kInner> mat1 =
GenerateMat<float, kOuter, kInner>(1, pool);
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> add0 = GenerateVec<kOuter>(0);
hwy::AlignedFreeUniquePtr<float[]> add1 = GenerateVec<kOuter>(1);
hwy::AlignedFreeUniquePtr<float[]> expected_out0 =
SimpleMatVecAdd<kOuter, kInner>(mat0, vec, add0);
hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
SimpleMatVecAdd<kOuter, kInner>(mat1, vec, add1);
hwy::AlignedFreeUniquePtr<float[]> actual_out0 =
hwy::AllocateAligned<float>(kOuter);
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
hwy::AllocateAligned<float>(kOuter);
auto mat0 = GenerateMat<float, kOuter, kInner>(0, pool);
auto mat1 = GenerateMat<float, kOuter, kInner>(1, pool);
FloatPtr vec = GenerateVec<kInner>(0);
FloatPtr add0 = GenerateVec<kOuter>(0);
FloatPtr add1 = GenerateVec<kOuter>(1);
FloatPtr expected_out0 = SimpleMatVecAdd<kOuter, kInner>(*mat0, vec, add0);
FloatPtr expected_out1 = SimpleMatVecAdd<kOuter, kInner>(*mat1, vec, add1);
FloatPtr actual_out0 = hwy::AllocateAligned<float>(kOuter);
FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1);
TwoMatVecAdd<kOuter, kInner>(mat0, mat1, 0, vec.get(), add0.get(), add1.get(),
actual_out0.get(), actual_out1.get(), pool);
TwoMatVecAdd<kOuter, kInner>(*mat0, *mat1, 0, vec.get(), add0.get(),
add1.get(), actual_out0.get(), actual_out1.get(),
pool);
AssertClose<kOuter>(actual_out0, expected_out0);
AssertClose<kOuter>(actual_out1, expected_out1);
}

@@ -156,22 +147,17 @@ void TestTwoOfsMatVecAddLoop() {
hwy::ThreadPool pool(hwy::ThreadPool::MaxThreads());
constexpr size_t kOuter = 128 * 3;
constexpr size_t kInner = 128 * 5;
CompressedArray<float, kOuter * kInner> mat =
GenerateMat<float, kOuter, kInner>(0, pool);
hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
hwy::AlignedFreeUniquePtr<float[]> add0 = GenerateVec<kOuter>(0);
hwy::AlignedFreeUniquePtr<float[]> add1 = GenerateVec<kOuter>(1);
hwy::AlignedFreeUniquePtr<float[]> expected_out0 =
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add0);
hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
SimpleMatVecAdd<kOuter, kInner>(mat, vec, add1);
hwy::AlignedFreeUniquePtr<float[]> actual_out0 =
hwy::AllocateAligned<float>(kOuter);
hwy::AlignedFreeUniquePtr<float[]> actual_out1 =
hwy::AllocateAligned<float>(kOuter);
auto mat = GenerateMat<float, kOuter, kInner>(0, pool);
FloatPtr vec = GenerateVec<kInner>(0);
FloatPtr add0 = GenerateVec<kOuter>(0);
FloatPtr add1 = GenerateVec<kOuter>(1);
FloatPtr expected_out0 = SimpleMatVecAdd<kOuter, kInner>(*mat, vec, add0);
FloatPtr expected_out1 = SimpleMatVecAdd<kOuter, kInner>(*mat, vec, add1);
FloatPtr actual_out0 = hwy::AllocateAligned<float>(kOuter);
FloatPtr actual_out1 = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(vec && add0 && add1 && expected_out0 && actual_out0 &&
expected_out1 && actual_out1);
TwoOfsMatVecAddLoop<kOuter, kInner>(mat, 0, 0, vec.get(), add0.get(),
TwoOfsMatVecAddLoop<kOuter, kInner>(*mat, 0, 0, vec.get(), add0.get(),
add1.get(), actual_out0.get(),
actual_out1.get());
AssertClose<kOuter>(actual_out0, expected_out0);
@@ -14,6 +14,7 @@
// limitations under the License.

#include <stddef.h>
#include <stdint.h>

#include "compression/compress.h"  // IWYU pragma: keep, b/conditionally used
#include "ops/matmul.h"  // IWYU pragma: export

@@ -57,9 +58,9 @@ constexpr size_t kRegRows = kRegCols;
// at a time. Any combination of A and B can be bf16: activations may already be
// bf16, and weights can be decompressed to bf16.
//
// The corresponding op is `ReordenWidenMulAccumulate`, and it is always
// The corresponding op is `ReorderWidenMulAccumulate`, and it is always
// supported, but only useful if it returns a single vector of pairwise sums
// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReordenWidenMulAccumulate`
// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReorderWidenMulAccumulate`
// insteads return `a[1] * b[1]` in its `sum1` output. We cannot afford to keep
// a `sum1` for each of the `kRegRows * kRegCols` C vectors, and it would be
// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
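Aside (scalar sketch, not part of the patch) of the pairwise-sum form referenced above: with native bf16 dot support, each f32 output lane accumulates the sum of two adjacent products, so no separate sum1 accumulator is needed.

// Illustrative scalar model of one output lane.
float PairwiseMulAcc(float a0, float a1, float b0, float b1, float sum) {
  return sum + a0 * b0 + a1 * b1;  // a[0]*b[0] + a[1]*b[1] in one step
}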
@@ -73,20 +74,22 @@ using MulT = hwy::If<HWY_NATIVE_DOT_BF16, BF16, float>;
template <size_t kRow, typename MatTB>
class BRow {
static_assert(kRow < kRegRows);  // which unrolled instance we are
using TraitsB = CompressTraits<MatTB>;

public:
BRow(const Mat<const MatTB>& B, size_t row_b)
: B_(B.ptr), B_ofs_(B.Row(row_b + kRow)) {}
BRow(const Mat<const MatTB>& B, size_t row_b, size_t cols_c)
// B.cols * C.cols is the total number of elements, required for
// PackedSpan::BoundsCheck.
: B_(MakeSpan(B.ptr, B.ofs + B.cols * cols_c)),
B_ofs_(B.Row(row_b + kRow)) {}

template <class DM, class VM = hn::Vec<DM>>
HWY_INLINE void Load2(DM d, size_t col_ab, VM& b0, VM& b1) const {
static_assert(hwy::IsSame<hn::TFromD<DM>, MulT>());
TraitsB::Decompress2(d, B_, B_ofs_ + col_ab, b0, b1);
Decompress2(d, B_, B_ofs_ + col_ab, b0, b1);
}

private:
const MatTB* HWY_RESTRICT B_;
PackedSpan<const MatTB> B_;
const size_t B_ofs_;
};

@@ -101,7 +104,7 @@ class BRow {
// `AddHorizontalSums`. Most MatMul instead broadcast one element from A and
// multiply with one element from N columns in B to obtain N columns of C.
// This is a poor fit for our setting:
// - `CompressTraits` decompresses two vectors at a time;
// - `Decompress2` decompresses two vectors at a time;
// - B is column-major, so unit-stride SIMD loads return a column, not values
// from different columns, i.e. a row.
// Both could be fixed in a packing stage, which is not implemented yet, and

@@ -113,11 +116,13 @@ class BRow {
template <size_t kRow, typename MatTA>
class ALoadAccumulate {
static_assert(kRow < kRegRows);  // which unrolled instance we are
using TraitsA = CompressTraits<MatTA>;

public:
ALoadAccumulate(const Mat<const MatTA>& A, size_t row_ac)
: A_(A.ptr), A_ofs_(A.Row(row_ac + kRow)) {}
ALoadAccumulate(const Mat<const MatTA>& A, size_t row_ac, size_t batch_size)
// A.cols * batch_size is the total number of elements, required for
// PackedSpan::BoundsCheck.
: A_(MakeSpan(A.ptr, A.ofs + A.cols * batch_size)),
A_ofs_(A.Row(row_ac + kRow)) {}

// First iteration, col_ab = 0: initialize C0..3 instead of updating them.
template <size_t kNumRows, class DM, class VM = hn::Vec<DM>, HWY_IF_F32_D(DM)>

@@ -128,7 +133,7 @@ class ALoadAccumulate {
static_assert(kNumRows <= kRegRows);  // How many rows actually present
if constexpr (kRow < kNumRows) {
VM a0, a1;
TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1);
Decompress2(dm, A_, A_ofs_, a0, a1);

static_assert(kRegCols == 4);
C0 = hn::Mul(a0, b00);

@@ -153,7 +158,7 @@ class ALoadAccumulate {
static_assert(kNumRows <= kRegRows);  // How many rows actually present
if constexpr (kRow < kNumRows) {
VM a0, a1;
TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1);
Decompress2(dm, A_, A_ofs_, a0, a1);

const DF df;
VF unused_sum1 = hn::Zero(df);

@@ -183,7 +188,7 @@ class ALoadAccumulate {
HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm));  // Should not be first iteration.
if constexpr (kRow < kNumRows) {
VM a0, a1;
TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);
Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);

static_assert(kRegCols == 4);
C0 = hn::MulAdd(a0, b00, C0);

@@ -209,7 +214,7 @@ class ALoadAccumulate {
HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm));  // Should not be first iteration.
if constexpr (kRow < kNumRows) {
VM a0, a1;
TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);
Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);

const DF df;
hn::Vec<DF> unused_sum1 = hn::Zero(df);

@@ -230,7 +235,7 @@ class ALoadAccumulate {
}

private:
const MatTA* HWY_RESTRICT A_;
PackedSpan<const MatTA> A_;
const size_t A_ofs_;
};  // ALoadAccumulate

@@ -352,9 +357,10 @@ class AddHorizontalSums {
// *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c).
// TODO: loop over sections instead of full rows and accumulate into `tile_c`.
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
HWY_INLINE void MatMulTile(const Mat<const MatTA>& A, const Mat<const MatTB>& B,
const size_t row_ac, const size_t row_b_col_c,
const float scale, const float* HWY_RESTRICT add,
HWY_INLINE void MatMulTile(const size_t batch_size, const Mat<const MatTA>& A,
const Mat<const MatTB>& B, const size_t row_ac,
const size_t row_b_col_c, const float scale,
const float* HWY_RESTRICT add,
float* HWY_RESTRICT buf, const Mat<float>& C) {
// For 'decompressing' A and B into BF16 or float.
const hn::ScalableTag<MulT> dm;

@@ -362,15 +368,15 @@ HWY_INLINE void MatMulTile(const Mat<const MatTA>& A, const Mat<const MatTB>& B,
const size_t NM = hn::Lanes(dm);

static_assert(kRegRows == 4);
const BRow<0, MatTB> b_row0(B, row_b_col_c);
const BRow<1, MatTB> b_row1(B, row_b_col_c);
const BRow<2, MatTB> b_row2(B, row_b_col_c);
const BRow<3, MatTB> b_row3(B, row_b_col_c);
const BRow<0, MatTB> b_row0(B, row_b_col_c, C.cols);
const BRow<1, MatTB> b_row1(B, row_b_col_c, C.cols);
const BRow<2, MatTB> b_row2(B, row_b_col_c, C.cols);
const BRow<3, MatTB> b_row3(B, row_b_col_c, C.cols);

const ALoadAccumulate<0, MatTA> a_row0(A, row_ac);
const ALoadAccumulate<1, MatTA> a_row1(A, row_ac);
const ALoadAccumulate<2, MatTA> a_row2(A, row_ac);
const ALoadAccumulate<3, MatTA> a_row3(A, row_ac);
const ALoadAccumulate<0, MatTA> a_row0(A, row_ac, batch_size);
const ALoadAccumulate<1, MatTA> a_row1(A, row_ac, batch_size);
const ALoadAccumulate<2, MatTA> a_row2(A, row_ac, batch_size);
const ALoadAccumulate<3, MatTA> a_row3(A, row_ac, batch_size);

const hn::Repartition<float, decltype(dm)> df;
using VF = hn::Vec<decltype(df)>;

@@ -475,16 +481,20 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
HWY_DASSERT(num_rows != 0);
switch (num_rows) {
case 1:
MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
MatMulTile<1, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
add, buf, C);
break;
case 2:
MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
MatMulTile<2, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
add, buf, C);
break;
case 3:
MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
MatMulTile<3, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
add, buf, C);
break;
default:
MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
MatMulTile<4, kAdd>(batch_size, A, B, row_ac, row_b_col_c, scale,
add, buf, C);
}
});
}
@@ -45,16 +45,17 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

namespace hn = hwy::HWY_NAMESPACE;
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;

// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT, size_t kRows, size_t kCols>
std::unique_ptr<CompressedArray<MatT, kRows * kCols>> GenerateMatHeap(
size_t offset, hwy::ThreadPool& pool) {
template <typename MatT, size_t kRows, size_t kCols,
size_t kNum = kRows * kCols,
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
MatPtr GenerateMatHeap(size_t offset, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
hwy::AlignedFreeUniquePtr<float[]> content =
hwy::AllocateAligned<float>(kRows * kCols);
const float scale = 1.875f / (kCols * kRows + offset);
FloatPtr content = hwy::AllocateAligned<float>(kNum);
HWY_ASSERT(content);
const float scale = SfpStream::kMax / (kNum + offset);
pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) {
for (size_t j = 0; j < kCols; j++) {
content[i * kCols + j] =

@@ -62,21 +63,19 @@ std::unique_ptr<CompressedArray<MatT, kRows * kCols>> GenerateMatHeap(
}
});

std::unique_ptr<CompressedArray<MatT, kRows * kCols>> mat =
std::make_unique<CompressedArray<MatT, kRows * kCols>>();
Compress(content.get(), kRows * kCols, ws, kRows * kCols, mat->data(), 0,
pool);
MatPtr mat = std::make_unique<CompressedArray<MatT, kNum>>();
CompressScaled(content.get(), kNum, ws, *mat, pool);
mat->set_scale(0.6f);  // Arbitrary value, different from 1.
return mat;
}

template <typename MatT, size_t kRows, size_t kCols>
std::unique_ptr<CompressedArray<MatT, kRows * kCols>> GenerateTransposeMatHeap(
size_t offset, hwy::ThreadPool& pool) {
template <typename MatT, size_t kRows, size_t kCols,
size_t kNum = kRows * kCols,
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
MatPtr GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
hwy::AlignedFreeUniquePtr<float[]> content =
hwy::AllocateAligned<float>(kRows * kCols);
const float scale = 1.875f / (kCols * kRows + offset);
FloatPtr content = hwy::AllocateAligned<float>(kNum);
const float scale = SfpStream::kMax / (kNum + offset);
pool.Run(0, kRows, [&](const size_t i, size_t /*thread*/) {
for (size_t j = 0; j < kCols; j++) {
content[j * kRows + i] =

@@ -84,42 +83,31 @@ std::unique_ptr<CompressedArray<MatT, kRows * kCols>> GenerateTransposeMatHeap(
}
});

std::unique_ptr<CompressedArray<MatT, kRows * kCols>> mat =
std::make_unique<CompressedArray<MatT, kRows * kCols>>();
Compress(content.get(), kRows * kCols, ws, kRows * kCols, mat->data(), 0,
pool);
MatPtr mat = std::make_unique<CompressedArray<MatT, kNum>>();
CompressScaled(content.get(), kNum, ws, *mat, pool);
// Arbitrary value, different from 1, must match GenerateMatHeap.
mat->set_scale(0.6f);
return mat;
}

template <typename MatT, size_t kRows, size_t kCols>
std::unique_ptr<CompressedArray<MatT, kRows * kCols>> GenerateZeroMatHeap(
hwy::ThreadPool& pool) {
template <typename MatT, size_t kRows, size_t kCols,
size_t kNum = kRows * kCols,
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
MatPtr GenerateZeroMatHeap(hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
hwy::AlignedFreeUniquePtr<float[]> content =
hwy::AllocateAligned<float>(kRows * kCols);
FloatPtr content = hwy::AllocateAligned<float>(kNum);
HWY_ASSERT(content);

pool.Run(0, kRows, [&](const size_t i, size_t thread) {
hwy::ZeroBytes(&content[i * kCols], kCols * sizeof(content[0]));
});

std::unique_ptr<CompressedArray<MatT, kRows * kCols>> mat =
std::make_unique<CompressedArray<MatT, kRows * kCols>>();
Compress(content.get(), kRows * kCols, ws, kRows * kCols, mat->data(), 0,
pool);
MatPtr mat = std::make_unique<CompressedArray<MatT, kNum>>();
CompressScaled(content.get(), kNum, ws, *mat, pool);
mat->set_scale(1.2f);  // Arbitrary value, different from 1.
return mat;
}

template <typename MatT>
void Decompress(const MatT* compressed, size_t num, float* out) {
const hn::ScalableTag<float> d;
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(num);
CompressTraits<MatT>::Decompress(d, /*in_capacity=*/0, compressed, 0, out,
num);
}

// Returns 1-norm, used for estimating tolerable numerical differences.
double MaxColAbsSum(const float* HWY_RESTRICT a, size_t rows, size_t cols) {
double max_col_abs_sum = 0.0;

@@ -135,18 +123,21 @@ double MaxColAbsSum(const float* HWY_RESTRICT a, size_t rows, size_t cols) {

template <typename MatTA, typename MatTB>
void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
const MatTA* HWY_RESTRICT a_compr,
const MatTB* HWY_RESTRICT b_trans_compr,
const MatTA* HWY_RESTRICT pa,
const MatTB* HWY_RESTRICT pb_trans,
const float* HWY_RESTRICT expected_c,
const float* HWY_RESTRICT actual_c) {
const hn::ScalableTag<float> df;
const size_t num_a = rows_ac * cols_ab;
const size_t num_b = cols_c_rows_b * cols_ab;
HWY_ASSERT(num_a % hn::Lanes(df) == 0);  // for DecompressAndZeroPad
HWY_ASSERT(num_b % hn::Lanes(df) == 0);  // for DecompressAndZeroPad
const size_t num_c = rows_ac * cols_c_rows_b;
hwy::AlignedFreeUniquePtr<float[]> a = hwy::AllocateAligned<float>(num_a);
hwy::AlignedFreeUniquePtr<float[]> b_trans =
hwy::AllocateAligned<float>(num_b);
Decompress(a_compr, num_a, a.get());
Decompress(b_trans_compr, num_b, b_trans.get());
FloatPtr a = hwy::AllocateAligned<float>(num_a);
FloatPtr b_trans = hwy::AllocateAligned<float>(num_b);
HWY_ASSERT(a && b_trans);
DecompressAndZeroPad(df, MakeSpan(pa, num_a), 0, a.get(), num_a);
DecompressAndZeroPad(df, MakeSpan(pb_trans, num_b), 0, b_trans.get(), num_b);

const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);

@@ -196,38 +187,37 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b_compr, const float scale,
const float* add, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> d;
hwy::AlignedFreeUniquePtr<float[]> b =
hwy::AllocateAligned<float>(cols_a_rows_b * cols_bc);
CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
cols_a_rows_b * cols_bc);
const size_t num_b = cols_a_rows_b * cols_bc;
FloatPtr b = hwy::AllocateAligned<float>(num_b);
HWY_ASSERT(b);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeSpan(b_compr, num_b), 0, b.get(), num_b);
MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a, b.get(), scale, add, out);
}

void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
size_t cols_bc, double elapsed) {
const size_t num_b = cols_a_rows_b * cols_bc;
// 2x because of FMA.
fprintf(stderr, "  %10s: %f seconds, %.1f GFLOPS.\n", algo,
elapsed, 2 * 1E-9 * rows_ac * cols_a_rows_b * cols_bc / elapsed);
elapsed, 2 * 1E-9 * rows_ac * num_b / elapsed);
}
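Note (worked example, assuming an FMA counts as two FLOPs as the comment above states): for rows_ac = 4, cols_a_rows_b = 2048, cols_bc = 4096 and elapsed = 0.01 s, the formula gives 2 * 1E-9 * 4 * 2048 * 4096 / 0.01 ≈ 6.7 GFLOPS.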

template <size_t kRowsAC, size_t kColsARowsB, size_t kColsBC, bool kAdd,
typename MatTA, typename MatTB = MatTA>
void TestMatMul(MatMulEnv& env) {
hwy::ThreadPool& pool = env.Pool();
using TraitsA = CompressTraits<MatTA>;
using TraitsB = CompressTraits<MatTB>;
const bool want_bench = kColsBC > 2000;  // avoid spam for small matrices
fprintf(stderr, "TestMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
kRowsAC, kColsARowsB, kColsBC, kAdd, TraitsA::Name(),
TraitsB::Name());
kRowsAC, kColsARowsB, kColsBC, kAdd, TypeName<MatTA>(),
TypeName<MatTB>());

std::unique_ptr<CompressedArray<MatTA, kRowsAC * kColsARowsB>> a =
GenerateMatHeap<MatTA, kRowsAC, kColsARowsB>(0, pool);
std::unique_ptr<CompressedArray<MatTB, kColsARowsB * kColsBC>> b_trans =
GenerateTransposeMatHeap<MatTB, kColsARowsB, kColsBC>(0, pool);
hwy::AlignedFreeUniquePtr<float[]> c =
hwy::AllocateAligned<float>(kRowsAC * kColsBC);
FloatPtr c = hwy::AllocateAligned<float>(kRowsAC * kColsBC);
HWY_ASSERT(c);

const float scale = a->scale() * b_trans->scale();
std::unique_ptr<CompressedArray<float, kColsBC>> add;
@@ -37,7 +37,6 @@

#include "compression/compress-inl.h"
#include "ops/dot-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
#include "hwy/contrib/math/math-inl.h"
#include "hwy/contrib/matvec/matvec-inl.h"

@@ -58,15 +57,14 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
float* HWY_RESTRICT out0,
float* HWY_RESTRICT out1) {
PROFILER_ZONE("TwoOfsMatVecAddLoop");
const hn::ScalableTag<float> df;

for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) {
const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner;
const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner;
out0[idx_row] = hwy::ConvertScalarTo<float>(add0[idx_row]) +
Dot<false>(df, mat, row_ofs0, vec_aligned, kInner);
Dot(mat, row_ofs0, vec_aligned, kInner);
out1[idx_row] = hwy::ConvertScalarTo<float>(add1[idx_row]) +
Dot<false>(df, mat, row_ofs1, vec_aligned, kInner);
Dot(mat, row_ofs1, vec_aligned, kInner);
}
}

@@ -98,8 +96,7 @@ HWY_INLINE void AccumulatePartialDotProducts(
const VecT* HWY_RESTRICT vec_aligned, float* HWY_RESTRICT out) {
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
out[idx_row] +=
Dot<false>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
out[idx_row] += Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}

@@ -117,12 +114,10 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) {
const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride;
if constexpr (kInit) {
out[idx_row] =
hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
Dot<false>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
out[idx_row] = hwy::ConvertScalarTo<float>(init[idx_row + r0]) +
Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols);
} else {
out[idx_row] =
Dot<false>(df, mat, row_ofs + c0, vec_aligned + c0, num_cols);
out[idx_row] = Dot(mat, row_ofs + c0, vec_aligned + c0, num_cols);
}
}
}
@@ -26,6 +26,7 @@
#include <random>
#include <type_traits>  // std::enable_if_t

#include "compression/compress.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"

@@ -41,8 +42,8 @@
#endif

#include "compression/compress-inl.h"
#include "ops/dot-inl.h"
#include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/dot/dot-inl.h"
#include "hwy/contrib/math/math-inl.h"

HWY_BEFORE_NAMESPACE();

@@ -54,11 +55,12 @@ template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
StaticCast(From from) noexcept {
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>) {
return static_cast<To>(
static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
else
} else {
return static_cast<To>(from);
}
}

template <class D, HWY_IF_F32_D(D)>

@@ -136,48 +138,13 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Sigmoid(float* HWY_RESTRICT x,
[](D d, hn::Vec<D> v) HWY_ATTR { return Sigmoid(d, v); });
}

static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a,
const float* HWY_RESTRICT b,
size_t size) {
PROFILER_ZONE("ops.Dot");
const hn::ScalableTag<float> d;
HWY_DASSERT(size >= hn::Lanes(d));
HWY_DASSERT(size % hn::Lanes(d) == 0);
constexpr int kAssumptions =
hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
return hn::Dot::Compute<kAssumptions>(d, a, b, size);
}

namespace detail {

// = Dot(a, a, size), but that is not allowed due to HWY_RESTRICT.
template <typename VecT>
float SquaredL2(const VecT* HWY_RESTRICT a, size_t size) {
using TraitsV = CompressTraits<VecT>;

const hn::ScalableTag<float> d;
using V = hn::Vec<decltype(d)>;
const size_t N = hn::Lanes(d);
HWY_DASSERT(size >= 2 * N);
HWY_DASSERT(size % (2 * N) == 0);

// TODO: use more accurate Dot
V sum0 = hn::Zero(d);
V sum1 = hn::Zero(d);
for (size_t i = 0; i <= size - 2 * N; i += 2 * N) {
V a0, a1;
TraitsV::Decompress2(d, a, i, a0, a1);
sum0 = hn::MulAdd(a0, a0, sum0);
sum1 = hn::MulAdd(a1, a1, sum1);
}

return hn::ReduceSum(d, hn::Add(sum0, sum1));
}

// Shared by RMSNorm and RMSNormInplace.
template <typename VecT>
float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) {
const float l2 = SquaredL2(x, size);
const hn::ScalableTag<float> df;
const float l2 = DecompressAndCall(df, x, size, DotKernelCompensated());
constexpr float kEps = 1e-6f;  // avoid divide by zero
return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
}
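Aside (scalar sketch, not part of the patch) of what RMSNormMul computes; the RMSNorm loops below then apply out = (1 + weight) * mul * x:

// mul = 1 / sqrt(mean(x^2) + eps); assumes <cmath> for std::sqrt.
float RMSNormMulScalar(const float* x, size_t size) {
  double sum_sq = 0.0;
  for (size_t i = 0; i < size; ++i) {
    sum_sq += static_cast<double>(x[i]) * static_cast<double>(x[i]);
  }
  constexpr float kEps = 1e-6f;
  return 1.0f / std::sqrt(static_cast<float>(sum_sq / size) + kEps);
}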
@@ -191,9 +158,6 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
const size_t size) {
PROFILER_FUNC;

using TraitsV = CompressTraits<VecT>;
using TraitsW = CompressTraits<WeightT>;

namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;

@@ -201,17 +165,21 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,

const VF mul = hn::Set(df, detail::RMSNormMul(x, size));

const auto packed_w = MakeSpan(weight, size);
const auto packed_v = MakeSpan(x, size);
const auto packed_out = MakeSpan(out, size);

HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
for (size_t i = 0; i < size; i += 2 * NF) {
VF v0, v1, w0, w1;
TraitsV::Decompress2(df, x, i, v0, v1);
TraitsW::Decompress2(df, weight, i, w0, w1);
Decompress2(df, packed_v, i, v0, v1);
Decompress2(df, packed_w, i, w0, w1);
const VF m0 = hn::Mul(mul, v0);
const VF m1 = hn::Mul(mul, v1);
// (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1);
detail::Store2(df, out0, out1, out + i);
Compress2(df, out0, out1, packed_out, i);
}
}

@@ -222,9 +190,6 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const size_t size) {
PROFILER_FUNC;

using TraitsV = CompressTraits<VecT>;
using TraitsW = CompressTraits<WeightT>;

namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;

@@ -232,17 +197,20 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(

const VF mul = hn::Set(df, detail::RMSNormMul(inout, size));

const auto packed_w = MakeSpan(weight, size);
const auto packed_v = MakeSpan(inout, size);

HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
for (size_t i = 0; i < size; i += 2 * NF) {
VF v0, v1, w0, w1;
TraitsV::Decompress2(df, inout, i, v0, v1);
TraitsW::Decompress2(df, weight, i, w0, w1);
const VF m0 = hn::Mul(mul, hn::LoadU(df, inout + i));
const VF m1 = hn::Mul(mul, hn::LoadU(df, inout + i + NF));
Decompress2(df, MakeConst(packed_v), i, v0, v1);
Decompress2(df, packed_w, i, w0, w1);
const VF m0 = hn::Mul(mul, v0);
const VF m1 = hn::Mul(mul, v1);
// (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1);
detail::Store2(df, out0, out1, inout + i);
Compress2(df, out0, out1, packed_v, i);
}
}

@@ -486,9 +454,9 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const V vmin = hn::Set(d, hwy::LowestValue<float>());
V vmax = vmin;
V* pmax = &vmax;  // workaround for SVE: cannot capture &vector directly
Foreach(d, x, mask_pos, vmin, [pmax](const auto d, const V value) HWY_ATTR {
*pmax = hn::Max(*pmax, value);
});
hn::Foreach(d, x, mask_pos, vmin,
[pmax](const auto d, const V value)
HWY_ATTR { *pmax = hn::Max(*pmax, value); });
vmax = hn::MaxOfLanes(d, vmax);

// Subtract max (avoid precision loss for large exponents) and exponentiate.

@@ -504,9 +472,9 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,

V sum = hn::Zero(d);
V* psum = &sum;
Foreach(d, x, mask_pos, sum, [psum](const auto d, const V value) HWY_ATTR {
*psum = hn::Add(*psum, value);
});
hn::Foreach(d, x, mask_pos, sum,
[psum](const auto d, const V value)
HWY_ATTR { *psum = hn::Add(*psum, value); });

// Normalize to probability distribution
const float mul = 1.0f / hn::ReduceSum(d, sum);
@@ -480,8 +480,8 @@ void TestRMSNorm(hwy::RandomState& rng) {
const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(VecT()),
TypeName(WeightT()), TypeName(OutT()), i, e, a);
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(),
TypeName<WeightT>(), TypeName<OutT>(), i, e, a);
}
}
}

@@ -139,6 +139,13 @@ class PerClusterPools {
}

public:
// Move-only.
PerClusterPools() = delete;
PerClusterPools(const PerClusterPools&) = delete;
PerClusterPools& operator=(const PerClusterPools&) = delete;
PerClusterPools(PerClusterPools&&) = default;
PerClusterPools& operator=(PerClusterPools&&) = default;

// PerClusterPools supports spin waits (see StartSpinning below). To prevent
// drastic slowdowns caused by excessive user-specified thread counts, which
// result in threads not running on their own core, we only allow for