diff --git a/BUILD.bazel b/BUILD.bazel index 2fc9e60..862555c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -48,15 +48,6 @@ cc_library( ], ) -# Avoids circular dependency: fp_arith-inl -> compress-inl -> ops-inl -cc_library( - name = "fp_arith", - textual_hdrs = ["ops/fp_arith-inl.h"], - deps = [ - "@hwy//:hwy", - ], -) - cc_library( name = "ops", hdrs = [ @@ -64,13 +55,13 @@ cc_library( ], textual_hdrs = [ "ops/dot-inl.h", + "ops/fp_arith-inl.h", "ops/matmul-inl.h", "ops/matvec-inl.h", "ops/ops-inl.h", ], deps = [ ":allocator", - ":fp_arith", ":threading", "//compression:compress", "//compression:sfp", @@ -97,14 +88,17 @@ cc_test( ":common", ":gemma_lib", ":ops", + ":test_util", ":threading", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", + "//compression:test_util", "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:nanobenchmark", #buildcleaner: keep "@hwy//:profiler", "@hwy//:stats", + "@hwy//:thread_pool", ], ) @@ -468,6 +462,7 @@ cc_library( ":ops", ":prompt", ":weights", + "@hwy//:dot", "@hwy//:hwy", # base.h "@hwy//:thread_pool", ], diff --git a/CMakeLists.txt b/CMakeLists.txt index 84832ff..c49bdf9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,10 +44,10 @@ set(SOURCES compression/io_win.cc compression/io.cc compression/io.h - compression/nuq.h compression/nuq-inl.h compression/sfp-inl.h compression/shared.h + compression/test_util-inl.h compression/weights_raw.h backprop/activations.h backprop/backward.cc diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 76bd87e..df7c6fb 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -46,6 +46,7 @@ // After highway.h #include "ops/matmul-inl.h" #include "ops/ops-inl.h" +#include "hwy/contrib/dot/dot-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index 838b042..c799cf4 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -52,11 +52,14 @@ namespace HWY_NAMESPACE { template void InputEmbedding(const ArrayT& weights, const std::vector& prompt, const float scaling, float* HWY_RESTRICT output, - size_t model_dim) { + size_t model_dim, size_t vocab_size) { + const hn::ScalableTag df; HWY_ASSERT(!prompt.empty()); for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { int token = prompt[pos]; - Decompress(weights, token * model_dim, output + pos * model_dim, model_dim); + DecompressAndZeroPad(df, MakeSpan(weights.data(), model_dim * vocab_size), + token * model_dim, output + pos * model_dim, + model_dim); MulByConst(scaling, output + pos * model_dim, model_dim); } } @@ -245,7 +248,7 @@ float CrossEntropyLossForwardPass(const std::vector& prompt, const size_t num_tokens = prompt.size() - 1; InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling, - forward.layers[0].input.data(), kModelDim); + forward.layers[0].input.data(), kModelDim, kVocabSize); for (size_t layer = 0; layer < kLayers; ++layer) { auto type = TConfig::kLayerConfig[layer]; diff --git a/compression/BUILD b/compression/BUILD index a437a17..f3767bf 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -50,7 +50,10 @@ cc_library( cc_library( name = "distortion", - hdrs = ["distortion.h"], + hdrs = [ + "distortion.h", + "shared.h", + ], deps = [ "@hwy//:hwy", "@hwy//:stats", @@ -64,29 +67,43 @@ cc_test( srcs = ["distortion_test.cc"], deps = [ ":distortion", - ":shared", - "@googletest//:gtest_main", + "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", - "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:nanobenchmark", # 
Unpredictable1 ], ) cc_library( - name = "shared", + name = "sfp", hdrs = ["shared.h"], + textual_hdrs = ["sfp-inl.h"], deps = [ "@hwy//:hwy", ], ) cc_library( - name = "sfp", - textual_hdrs = ["sfp-inl.h"], + name = "nuq", + hdrs = ["shared.h"], + textual_hdrs = ["nuq-inl.h"], deps = [ - ":shared", + ":sfp", "@hwy//:hwy", + "@hwy//hwy/contrib/sort:vqsort", + ], +) + +cc_library( + name = "test_util", + textual_hdrs = [ + "test_util-inl.h", + ], + deps = [ + ":compress", + ":distortion", + "@hwy//:hwy", + "@hwy//:hwy_test_util", ], ) @@ -102,9 +119,7 @@ cc_test( deps = [ ":distortion", ":sfp", - ":shared", - "@googletest//:gtest_main", - "//:ops", + "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", "@hwy//:hwy", "@hwy//:hwy_test_util", @@ -112,18 +127,6 @@ cc_test( ], ) -cc_library( - name = "nuq", - hdrs = ["nuq.h"], - textual_hdrs = ["nuq-inl.h"], - deps = [ - ":sfp", - ":shared", - "@hwy//:hwy", - "@hwy//hwy/contrib/sort:vqsort", - ], -) - cc_test( name = "nuq_test", size = "small", @@ -138,8 +141,7 @@ cc_test( ":distortion", ":nuq", ":sfp", - ":shared", - "@googletest//:gtest_main", + "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", "@hwy//:hwy", "@hwy//:hwy_test_util", @@ -149,19 +151,20 @@ cc_test( cc_library( name = "compress", - hdrs = ["compress.h"], - textual_hdrs = [ - "compress-inl.h", + hdrs = [ + "compress.h", + "shared.h", ], + textual_hdrs = ["compress-inl.h"], deps = [ ":blob_store", ":distortion", ":io", ":nuq", ":sfp", - ":shared", - "//:fp_arith", "@hwy//:hwy", + "@hwy//:nanobenchmark", + "@hwy//:profiler", "@hwy//:stats", "@hwy//:thread_pool", ], @@ -170,6 +173,7 @@ cc_library( cc_test( name = "compress_test", size = "small", + timeout = "long", srcs = ["compress_test.cc"], features = ["fully_static_link"], linkstatic = True, @@ -179,11 +183,11 @@ cc_test( deps = [ ":compress", ":distortion", - "@googletest//:gtest_main", + ":test_util", + "@googletest//:gtest_main", # buildcleaner: keep "//:test_util", "@hwy//:hwy", "@hwy//:hwy_test_util", - "@hwy//:nanobenchmark", "@hwy//:thread_pool", ], ) @@ -193,11 +197,9 @@ cc_library( name = "analyze", textual_hdrs = ["analyze.h"], deps = [ - ":distortion", ":nuq", ":sfp", "@hwy//:hwy", - "@hwy//:nanobenchmark", # timer "@hwy//:stats", "@hwy//:thread_pool", "@hwy//hwy/contrib/sort:vqsort", @@ -210,7 +212,6 @@ cc_library( deps = [ "//:allocator", "//:common", - "//compression:compress", "@hwy//:hwy", "@hwy//:thread_pool", ], @@ -221,15 +222,13 @@ cc_binary( srcs = ["compress_weights.cc"], deps = [ ":compress", - ":shared", + ":io", ":weights_raw", - # Placeholder for internal dep, do not remove., + "//:allocator", "//:args", "//:common", - "//:gemma_lib", "//:weights", "@hwy//:hwy", - "@hwy//:nanobenchmark", "@hwy//:profiler", "@hwy//:thread_pool", ], diff --git a/compression/analyze.h b/compression/analyze.h index 342f2f2..38537db 100644 --- a/compression/analyze.h +++ b/compression/analyze.h @@ -26,12 +26,10 @@ #include // std::abs #include -#include "compression/distortion.h" -#include "compression/nuq.h" +#include "compression/shared.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/stats.h" -#include "hwy/timer.h" #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ @@ -55,6 +53,7 @@ namespace HWY_NAMESPACE { class PerThread { public: void NotifyGroup(const float* group) { + constexpr size_t kGroupSize = NuqStream::kGroupSize; hwy::Stats s_group; for (size_t i = 0; i < kGroupSize; ++i) { // Skip zero so we can see the lowest actual magnitude @@ -158,7 
+157,7 @@ class PerThread { class PerLayer { public: void NotifyGroup(const float* group) { - for (size_t i = 0; i < kGroupSize; ++i) { + for (size_t i = 0; i < NuqStream::kGroupSize; ++i) { s_layer_.Notify(group[i]); } } @@ -197,8 +196,8 @@ static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers, const float* layer = &mat[idx_layer * weights_per_layer]; // For each whole group in the layer for (size_t group_start = 0; - group_start + kGroupSize <= weights_per_layer; - group_start += kGroupSize) { + group_start + NuqStream::kGroupSize <= weights_per_layer; + group_start += NuqStream::kGroupSize) { const float* group = layer + group_start; per_layer[idx_layer].NotifyGroup(group); self.NotifyGroup(group); @@ -210,7 +209,7 @@ static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers, const int skip = hwy::Stats::kNoGeomean; fprintf(stderr, "\n------------%s\n", caption); - for (size_t i = 1; i < pool.NumThreads(); ++i) { + for (size_t i = 1; i < pool.NumWorkers(); ++i) { tls[0].Assimilate(tls[i]); } tls[0].PrintAll(); diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 9579e18..e4ea1a1 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -21,7 +21,6 @@ #include #include -#include #include // lroundf, only if COMPRESS_STATS #include "compression/blob_store.h" @@ -42,133 +41,146 @@ #define THIRD_PARTY_GEMMA_CPP_COMPRESS_TOGGLE #endif +#include "hwy/highway.h" +// After highway.h #include "compression/nuq-inl.h" #include "compression/sfp-inl.h" -#include "hwy/highway.h" +#include "hwy/profiler.h" // also uses SIMD HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -namespace detail { - -// Adapters to store two f32 vectors to f32 or bf16; avoids duplicating -// RMSNorm and RMSNormInplace for the two output types. -template -void Store2(DF df, hn::Vec v0, hn::Vec v1, float* HWY_RESTRICT out) { - const size_t NF = hn::Lanes(df); - hn::StoreU(v0, df, out); - hn::StoreU(v1, df, out + NF); -} - -template -void Store2(DF df, hn::Vec v0, hn::Vec v1, BF16* HWY_RESTRICT out) { - const hn::Repartition dbf; - hn::StoreU(hn::OrderedDemote2To(dbf, v0, v1), dbf, out); -} - -} // namespace detail - // Enables generic code independent of compression type. template // primary, must specialize struct CompressTraits {}; -// Useful for backprop/, where weights are currently f32. +// Used by backprop/, where weights are currently f32; also MatMul for f32 +// weights or activations, if native `ReorderWidenMulAccumulate` is available. 
template <> struct CompressTraits { - using MatT = float; - static const char* Name() { return "f32"; } - static constexpr bool kSupportsEvenOdd = false; // unnecessary + using Packed = float; - template - static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, - size_t num, CompressPerThread& tls, - size_t /*out_capacity*/, - MatT* HWY_RESTRICT out, size_t out_ofs) { + template > + static HWY_INLINE void Compress(DF /*df*/, const float* HWY_RESTRICT raw, + size_t num, CompressPerThread& /*tls*/, + const PackedSpan& packed, + const size_t packed_ofs) { + hwy::CopyBytes(raw, packed.ptr + packed_ofs, num * sizeof(raw[0])); + } + + template > + static void Store2(DF df, VF raw0, VF raw1, const PackedSpan& packed, + const size_t packed_ofs) { + const size_t NF = hn::Lanes(df); + hn::StoreU(raw0, df, packed.ptr + packed_ofs); + hn::StoreU(raw1, df, packed.ptr + packed_ofs + NF); + } + + template > + static HWY_INLINE void Load2(DBF16 dbf16, + const PackedSpan& packed, + const size_t packed_ofs, VBF16& raw0, + VBF16& raw1) { + const hn::Repartition df; using VF = hn::Vec; - const size_t N = hn::Lanes(df); - HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0); + const size_t NF = hn::Lanes(df); + const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + 0 * NF); + const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + 1 * NF); + const VF f2 = hn::LoadU(df, packed.ptr + packed_ofs + 2 * NF); + const VF f3 = hn::LoadU(df, packed.ptr + packed_ofs + 3 * NF); + raw0 = hn::OrderedDemote2To(dbf16, f0, f1); + raw1 = hn::OrderedDemote2To(dbf16, f2, f3); + } - for (size_t i = 0; i < num; i += 2 * N) { - const VF in0 = hn::LoadU(df, in + i); - const VF in1 = hn::LoadU(df, in + i + N); - hn::StoreU(in0, df, out + out_ofs + i); - hn::StoreU(in1, df, out + out_ofs + i + N); + template > + static HWY_INLINE void Load2(DF df, const PackedSpan& packed, + const size_t packed_ofs, VF& raw0, VF& raw1) { + const size_t N = hn::Lanes(df); + raw0 = hn::LoadU(df, packed.ptr + packed_ofs); + raw1 = hn::LoadU(df, packed.ptr + packed_ofs + N); + } + + template + static HWY_INLINE void DecompressAndZeroPad( + DBF dbf, const PackedSpan& packed, const size_t packed_ofs, + BF16* HWY_RESTRICT raw, size_t num) { + const hn::Repartition df; + using VF = hn::Vec; + using VBF = hn::Vec; + const size_t NF = hn::Lanes(df); + + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + const VF f0 = hn::LoadU(df, packed.ptr + packed_ofs + i); + const VF f1 = hn::LoadU(df, packed.ptr + packed_ofs + i + NF); + hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i); + } + } + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + const size_t remaining2 = remaining - HWY_MIN(remaining, NF); + const VF f0 = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining); + const VF f1 = hn::LoadN(df, packed.ptr + packed_ofs + i + NF, remaining2); + hn::StoreU(hn::OrderedDemote2To(dbf, f0, f1), dbf, raw + i); } } template - static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in, - size_t in_ofs, hn::Vec& f0, - hn::Vec& f1) { - const size_t N = hn::Lanes(df); - f0 = hn::LoadU(df, in + in_ofs); - f1 = hn::LoadU(df, in + in_ofs + N); - } - - // Called by MatMul for f32 weights or activations if native - // `ReorderWidenMulAccumulate` is available. 
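// The DecompressAndZeroPad members introduced in this hunk share one contract:
// write `num` decoded values to `raw`, then zero-fill up to the next multiple
// of the vector length so callers may always load whole vectors. A plain-C++
// scalar analogue of that contract (hypothetical helper, `lanes` standing in
// for hn::Lanes(d); not the SIMD code in this patch):

#include <algorithm>
#include <cstddef>

template <typename T>
void DecompressAndZeroPadScalar(const T* packed, size_t packed_ofs,
                                T* raw, size_t num, size_t lanes) {
  std::copy(packed + packed_ofs, packed + packed_ofs + num, raw);
  const size_t padded = (num + lanes - 1) / lanes * lanes;  // round up
  std::fill(raw + num, raw + padded, T{0});                 // zero padding
}
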
- template > - static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in, - size_t in_ofs, VBF16& v0, VBF16& v1) { - const hn::Repartition df; + static HWY_INLINE void DecompressAndZeroPad( + DF df, const PackedSpan& packed, const size_t packed_ofs, + float* HWY_RESTRICT raw, size_t num) { using VF = hn::Vec; const size_t NF = hn::Lanes(df); - const VF f0 = hn::LoadU(df, in + in_ofs + 0 * NF); - const VF f1 = hn::LoadU(df, in + in_ofs + 1 * NF); - const VF f2 = hn::LoadU(df, in + in_ofs + 2 * NF); - const VF f3 = hn::LoadU(df, in + in_ofs + 3 * NF); - v0 = hn::OrderedDemote2To(dbf16, f0, f1); - v1 = hn::OrderedDemote2To(dbf16, f2, f3); - } - template - static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/, - const MatT* HWY_RESTRICT in, size_t in_ofs, - float* HWY_RESTRICT out, size_t num) { - using VF = hn::Vec; - const size_t N = hn::Lanes(df); - - for (size_t i = 0; i < num; i += N) { - const VF v = hn::LoadU(df, in + in_ofs + i); - hn::StoreU(v, df, out + i); + size_t i = 0; + if (num >= NF) { + for (; i <= num - NF; i += NF) { + const VF vf = hn::LoadU(df, packed.ptr + packed_ofs + i); + hn::StoreU(vf, df, raw + i); + } + } + const size_t remaining = num - i; + HWY_DASSERT(remaining < NF); + if (HWY_UNLIKELY(remaining != 0)) { + const VF vf = hn::LoadN(df, packed.ptr + packed_ofs + i, remaining); + hn::StoreU(vf, df, raw + i); // adds zero padding } } }; template <> -struct CompressTraits { - using MatT = hwy::bfloat16_t; - static const char* Name() { return "bf16"; } - static constexpr bool kSupportsEvenOdd = true; +struct CompressTraits { + using Packed = BF16; - template - static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, + // Note: it is fine for the lower 16 mantissa bits of `raw` to be nonzero + // because we round rather than truncate. + template > + static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw, size_t num, CompressPerThread& tls, - size_t /*out_capacity*/, - MatT* HWY_RESTRICT out, size_t out_ofs) { + const PackedSpan& packed, + const size_t packed_ofs) { const hn::RebindToUnsigned du; - const hn::Repartition dbf; - using VF = hn::Vec; - const size_t N = hn::Lanes(df); - - hn::Vec or_sum = hn::Zero(du); + const hn::Repartition dbf; + const size_t NF = hn::Lanes(df); size_t i = 0; - if (num >= 2 * N) { - for (; i <= num - 2 * N; i += 2 * N) { - const VF in0 = hn::LoadU(df, in + i); - const VF in1 = hn::LoadU(df, in + i + N); + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + const VF raw0 = hn::LoadU(df, raw + i); + const VF raw1 = hn::LoadU(df, raw + i + NF); - // Sticky bits so we can warn if any lower bits were set. 
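// The new comment in this hunk relies on OrderedDemote2To rounding rather than
// truncating, which is why the old sticky-bit warning is being removed. As a
// scalar illustration of round-to-nearest-even from f32 to BF16 (hypothetical
// helper; ignores NaN/Inf handling and is not the Highway implementation):

#include <cstdint>
#include <cstring>

static inline uint16_t F32ToBF16RoundScalar(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t lsb = (bits >> 16) & 1u;  // ties round to even
  bits += 0x7FFFu + lsb;                   // round at the BF16 boundary
  return static_cast<uint16_t>(bits >> 16);
}
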
- or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1)); - hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i); + hn::StoreU(hn::OrderedDemote2To(dbf, raw0, raw1), dbf, + packed.ptr + packed_ofs + i); if (COMPRESS_STATS) { DistortionStats stats; - for (size_t j = 0; j < 2 * N; ++j) { - stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j])); + for (size_t j = 0; j < 2 * NF; ++j) { + stats.Notify(raw[i + j], + hwy::F32FromBF16(packed.ptr[packed_ofs + i + j])); } tls.stats.Notify(stats); } @@ -176,270 +188,248 @@ struct CompressTraits { } const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); if (remaining != 0) { - const VF in0 = hn::LoadN(df, in + i, remaining); - const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2); - const VF in1 = hn::LoadN(df, in + i + N, remaining1); + const VF raw0 = hn::LoadN(df, raw + i, remaining); + const size_t remaining1 = remaining - HWY_MIN(remaining, NF); + const VF raw1 = hn::LoadN(df, raw + i + NF, remaining1); - // Sticky bits so we can warn if any lower bits were set. - or_sum = hn::Or3(or_sum, hn::BitCast(du, in0), hn::BitCast(du, in1)); - hn::StoreU(hn::OrderedDemote2To(dbf, in0, in1), dbf, out + out_ofs + i); + hn::StoreN(hn::OrderedDemote2To(dbf, raw0, raw1), dbf, + packed.ptr + packed_ofs + i, remaining); if (COMPRESS_STATS) { DistortionStats stats; for (size_t j = 0; j < remaining; ++j) { - stats.Notify(in[i + j], hwy::F32FromBF16(out[out_ofs + i + j])); + stats.Notify(raw[i + j], + hwy::F32FromBF16(packed.ptr[packed_ofs + i + j])); } tls.stats.Notify(stats); } } - - // If the lower 16 bits are not zero, we should implement rounding. - or_sum = hn::And(or_sum, hn::Set(du, 0xFFFF)); - if (!hn::AllTrue(du, hn::Eq(or_sum, hn::Zero(du)))) { - // fprintf(stderr, "Warning: Lossy truncation."); - } } - template - static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in, - size_t in_ofs, hn::Vec& f0, - hn::Vec& f1) { - const hn::Repartition dbf; - using VBF = hn::Vec; - const VBF in16 = hn::LoadU(dbf, in + in_ofs); - f0 = hn::PromoteLowerTo(df, in16); - f1 = hn::PromoteUpperTo(df, in16); + template > + static void Store2(DF df, VF raw0, VF raw1, const PackedSpan& packed, + const size_t packed_ofs) { + const hn::Repartition dbf; + hn::StoreU(hn::OrderedDemote2To(dbf, raw0, raw1), dbf, + packed.ptr + packed_ofs); } template - static HWY_INLINE void Decompress2(DBF16 dbf16, const MatT* HWY_RESTRICT in, - size_t in_ofs, hn::Vec& v0, - hn::Vec& v1) { - v0 = hn::LoadU(dbf16, in + in_ofs); - v1 = hn::LoadU(dbf16, in + in_ofs + hn::Lanes(dbf16)); + static HWY_INLINE void Load2(DBF16 dbf16, + const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const size_t N16 = hn::Lanes(dbf16); + raw0 = hn::LoadU(dbf16, packed.ptr + packed_ofs); + raw1 = hn::LoadU(dbf16, packed.ptr + packed_ofs + N16); } template - static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/, - const MatT* HWY_RESTRICT in, size_t in_ofs, - float* HWY_RESTRICT out, size_t num) { - const hn::Repartition dbf; + static HWY_INLINE void Load2(DF df, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Repartition dbf; + using VBF = hn::Vec; + const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs); + raw0 = hn::PromoteLowerTo(df, packed0); + raw1 = hn::PromoteUpperTo(df, packed0); + } + + template + static HWY_INLINE void DecompressAndZeroPad( + DBF dbf, const PackedSpan& packed, const size_t packed_ofs, + BF16* HWY_RESTRICT raw, 
size_t num) { using VBF = hn::Vec; - using VF = hn::Vec; const size_t N16 = hn::Lanes(dbf); size_t i = 0; if (num >= N16) { for (i = 0; i <= num - N16; i += N16) { - VF in0, in1; - Decompress2(df, in, in_ofs + i, in0, in1); - hn::StoreU(in0, df, out + i); - hn::StoreU(in1, df, out + i + N16 / 2); + const VBF packed0 = hn::LoadU(dbf, packed.ptr + packed_ofs + i); + hn::StoreU(packed0, dbf, raw + i); } } const size_t remaining = num - i; - if (remaining != 0) { - const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining); - const VF in0 = hn::PromoteLowerTo(df, in16); - const VF in1 = hn::PromoteUpperTo(df, in16); - hn::StoreN(in0, df, out + i, remaining); - // Avoid wraparound, potentially store nothing. - const size_t remaining1 = remaining - HWY_MIN(remaining, N16 / 2); - hn::StoreN(in1, df, out + i + N16 / 2, remaining1); + HWY_DASSERT(remaining < N16); + if (HWY_UNLIKELY(remaining != 0)) { + const VBF packed0 = + hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining); + hn::StoreU(packed0, dbf, raw + i); } } - // Computes the dot product of an even-odd deinterleaved, f32 `vec_aligned` - // and a column- major matrix `in`. `vec_aligned` should be aligned and - // alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed - // `hn::Lanes(df32)` elements. template - static HWY_INLINE float DotEO( - const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs, - const float* HWY_RESTRICT vec_aligned, size_t num) { - HWY_DASSERT(num >= (hn::Lanes(df32) * 2) && - (num % (hn::Lanes(df32) * 2)) == 0); - HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0); - HWY_DASSERT(hn::IsAligned(df32, vec_aligned)); + static HWY_INLINE void DecompressAndZeroPad( + DF df, const PackedSpan& packed, const size_t packed_ofs, + float* HWY_RESTRICT raw, size_t num) { + const hn::Repartition dbf; + using VF = hn::Vec; + using VBF = hn::Vec; + const size_t NF = hn::Lanes(df); - const hn::Repartition dbf16; - using VF32 = decltype(Zero(df32)); - const size_t N = Lanes(dbf16); - - VF32 sum0 = Zero(df32); - VF32 sum1 = Zero(df32); - VF32 sum2 = Zero(df32); - VF32 sum3 = Zero(df32); - - for (size_t i = 0; i < num; /* i += 2 * N */) { - const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i); - const VF32 ae0 = Load(df32, vec_aligned + i); - const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2)); - sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0); - sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1); - i += N; - - const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i); - const VF32 ae1 = Load(df32, vec_aligned + i); - const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2)); - sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2); - sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3); - i += N; + size_t i = 0; + if (num >= 2 * NF) { + for (i = 0; i <= num - 2 * NF; i += 2 * NF) { + VF raw0, raw1; + Load2(df, packed, packed_ofs + i, raw0, raw1); + hn::StoreU(raw0, df, raw + i); + hn::StoreU(raw1, df, raw + i + NF); + } } - sum0 = hn::Add(sum0, sum1); - sum2 = hn::Add(sum2, sum3); - sum0 = hn::Add(sum0, sum2); - return hn::ReduceSum(df32, sum0); + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + const VBF packed0 = + hn::LoadN(dbf, packed.ptr + packed_ofs + i, remaining); + const VF raw0 = hn::PromoteLowerTo(df, packed0); + const VF raw1 = hn::PromoteUpperTo(df, packed0); + // If at most one vector, the first store adds zero padding. 
Check before + // storing the second, because callers only pad to one vector. + hn::StoreU(raw0, df, raw + i); + if (remaining >= NF) hn::StoreU(raw1, df, raw + i + NF); + } } }; // Switching floating point: 8-bit, 2..3 mantissa bits. template <> struct CompressTraits { - using MatT = SfpStream; - static const char* Name() { return "sfp"; } - static constexpr bool kSupportsEvenOdd = true; + using Packed = SfpStream; - // Callers are responsible for scaling `in` such that its magnitudes do not - // exceed 1.875. See CompressedArray::scale(). + // Callers are responsible for scaling `raw` such that its magnitudes do not + // exceed `SfpStream::kMax`. See CompressedArray::scale(). template - static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in, + static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw, size_t num, CompressPerThread& tls, - size_t /*out_capacity*/, - MatT* HWY_RESTRICT out, size_t out_ofs) { - SfpCodec::Enc(df, in, num, out + out_ofs); + const PackedSpan& packed, + const size_t packed_ofs) { + SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs); if (COMPRESS_STATS) { - const hn::Repartition dbf; - auto distorted = hwy::AllocateAligned(num); - SfpCodec::Dec(dbf, out + out_ofs, num, distorted.get()); + const hn::Repartition dbf; + auto distorted = + hwy::AllocateAligned(hwy::RoundUpTo(num, hn::Lanes(dbf))); + SfpCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs, + distorted.get(), num); DistortionStats stats; for (size_t i = 0; i < num; ++i) { - stats.Notify(in[i], hwy::F32FromBF16(distorted[i])); + stats.Notify(raw[i], hwy::F32FromBF16(distorted[i])); } tls.stats.Notify(stats); } } - template // f32 or bf16 - static HWY_INLINE void Decompress2(D d, const MatT* HWY_RESTRICT in, - size_t in_ofs, hn::Vec& v0, - hn::Vec& v1) { + template // Caller checks this is f32 or bf16 + static HWY_INLINE void Load2(D d, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { const hn::Twice> d8; using V8 = hn::Vec; - const V8 packed = hn::LoadU(d8, &in->byte + in_ofs); - SfpCodec::Dec2(d, packed, v0, v1); + const V8 v8 = hn::LoadU(d8, &packed.ptr->byte + packed_ofs); + SfpCodec::Dec2(d, v8, raw0, raw1); } - template - static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/, - const MatT* HWY_RESTRICT in, size_t in_ofs, - OutT* HWY_RESTRICT out, size_t num) { - SfpCodec::Dec(d, in + in_ofs, num, out); - } + // Store2 is not yet implemented. - // Computes the dot product of an even-odd deinterleaved, f32 or bf16 - // `vec_aligned` and a column-major matrix `in`. `vec_aligned` should be - // aligned and alternate even-indexed `hn::Lanes(df)` elements followed by - // odd-indexed `hn::Lanes(df)` elements. 
- template - static HWY_INLINE float DotEO(const DF df, const MatT* HWY_RESTRICT in, - size_t in_ofs, - const VecT* HWY_RESTRICT vec_aligned, - size_t num) { - HWY_DASSERT(num >= (hn::Lanes(df) * 2) && (num % (hn::Lanes(df) * 2)) == 0); - HWY_DASSERT((in_ofs % (hn::Lanes(df) * 2)) == 0); - HWY_DASSERT(hn::IsAligned(df, vec_aligned)); - - using VF = hn::Vec; - VF sum0 = hn::Zero(df); - VF sum1 = hn::Zero(df); - VF sum2 = hn::Zero(df); - VF sum3 = hn::Zero(df); - - SfpCodec::DotEO(df, in + in_ofs, num, vec_aligned, sum0, sum1, sum2, sum3); - - // Reduction tree: sum of all accumulators, then their lanes - sum0 = hn::Add(sum0, sum1); - sum2 = hn::Add(sum2, sum3); - sum0 = hn::Add(sum0, sum2); - return hn::ReduceSum(df, sum0); + template + static HWY_INLINE void DecompressAndZeroPad( + D d, const PackedSpan& packed, const size_t packed_ofs, + Raw* HWY_RESTRICT raw, const size_t num) { + SfpCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num); } }; // Nonuniform quantization, 4.5 bits per element, two separate streams. template <> struct CompressTraits { - using MatT = NuqStream; - static const char* Name() { return "nuq"; } - static constexpr bool kSupportsEvenOdd = false; + using Packed = NuqStream; template - static HWY_INLINE void Compress(DF df, const float* in, size_t num, - CompressPerThread& tls, size_t out_capacity, - MatT* out, size_t out_ofs) { - NuqCodec::Enc(df, in, num, tls.buf, out_capacity, out, out_ofs); + static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw, + size_t num, CompressPerThread& tls, + const PackedSpan& packed, + const size_t packed_ofs) { + NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs); if (COMPRESS_STATS) { for (size_t i = 0; i < num; ++i) { - tls.stats.NotifyIn(static_cast(lroundf(in[i] * 100.0f + 500.0f))); + tls.stats.NotifyIn(static_cast(lroundf(raw[i] * 100.0f + 500.0f))); } - const hn::Repartition dbf; - auto distorted = hwy::AllocateAligned(num); - NuqCodec::Dec(dbf, out_capacity, out, out_ofs, distorted.get(), num); + const hn::Repartition dbf; + const size_t N16 = hn::Lanes(dbf); + auto distorted = hwy::AllocateAligned(hwy::RoundUpTo(num, N16)); + NuqCodec::DecompressAndZeroPad(dbf, MakeConst(packed), packed_ofs, + distorted.get(), num); DistortionStats stats; for (size_t i = 0; i < num; ++i) { - stats.Notify(in[i], hwy::F32FromBF16(distorted[i])); + stats.Notify(raw[i], hwy::F32FromBF16(distorted[i])); } tls.stats.Notify(stats); } } - template - static HWY_INLINE void Decompress(D d, size_t in_capacity, const MatT* in, - size_t in_ofs, OutT* out, size_t num) { - NuqCodec::Dec(d, in_capacity, in, in_ofs, out, num); + template // Caller checks this is f32 or bf16 + static HWY_INLINE void Load2(D d, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Twice> d8; + using V8 = hn::Vec; + NuqCodec::Dec2(d, packed, packed_ofs, raw0, raw1); + } + + // Store2 is not yet implemented. + + template + static HWY_INLINE void DecompressAndZeroPad( + D d, const PackedSpan& packed, const size_t packed_ofs, + Raw* raw, const size_t num) { + NuqCodec::DecompressAndZeroPad(d, packed, packed_ofs, raw, num); } }; -// Compresses `num` inputs to `out` starting at `out_ofs`. This can be used for -// compressing sub-regions of an array. 
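// For orientation before the new signatures below: a hedged sketch of the
// PackedSpan-based round trip, mirroring compress_test.cc later in this patch.
// It assumes it lives in a SIMD-dispatched translation unit (so hn:: and the
// per-target includes are available); names not taken from the patch are
// illustrative only.

void ExampleRoundTrip() {
  const hn::ScalableTag<float> df;
  CompressWorkingSet work;
  hwy::ThreadPool pool(0);
  const size_t num = 512;
  auto raw = hwy::AllocateAligned<float>(num);
  // ... fill raw[0, num) with magnitudes <= SfpStream::kMax ...
  auto packed = hwy::AllocateAligned<SfpStream>(num);
  auto dec = hwy::AllocateAligned<float>(hwy::RoundUpTo(num, hn::Lanes(df)));
  const auto packed_span = MakeSpan(packed.get(), num);
  Compress(raw.get(), num, work, packed_span, /*packed_ofs=*/0, pool);
  DecompressAndZeroPad(df, MakeConst(packed_span), /*packed_ofs=*/0,
                       dec.get(), num);
}
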
-template -HWY_NOINLINE void Compress(const float* in, size_t num, - CompressWorkingSet& work, size_t out_capacity, - MatT* out, size_t out_ofs, hwy::ThreadPool& pool) { - HWY_DASSERT(out_ofs + num <= out_capacity); - work.tls.resize(pool.NumThreads()); +// Compresses `num` elements of `raw` to `packed` starting at `packed_ofs`, +// which is useful for compressing sub-regions of an array. +template +HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num, + CompressWorkingSet& work, + const PackedSpan& packed, + const size_t packed_ofs, hwy::ThreadPool& pool) { + packed.BoundsCheck(packed_ofs, num); + work.tls.resize(pool.NumWorkers()); if (COMPRESS_STATS) { for (auto& tls : work.tls) { tls.stats.Reset(); } } - const double t0 = hwy::platform::Now(); + const bool want_bench = num > 1024 * 1024 || COMPRESS_STATS; + const double t0 = want_bench ? hwy::platform::Now() : 0.0; - using Traits = CompressTraits; + using Traits = CompressTraits; constexpr size_t kBatch = 8192; const size_t num_batches = hwy::DivCeil(num, kBatch); pool.Run(0, num_batches, [&](const uint32_t idx_batch, size_t thread) HWY_ATTR { const hn::ScalableTag df; - const size_t in_ofs = idx_batch * kBatch; + const size_t my_pos = idx_batch * kBatch; const size_t my_num = - idx_batch == num_batches - 1 ? (num - in_ofs) : kBatch; - Traits::Compress(df, in + in_ofs, my_num, work.tls[thread], - out_capacity, out, out_ofs + in_ofs); + idx_batch == num_batches - 1 ? (num - my_pos) : kBatch; + Traits::Compress(df, raw + my_pos, my_num, work.tls[thread], + packed, packed_ofs + my_pos); }); - const double t1 = hwy::platform::Now(); - const double mb = static_cast(num) * sizeof(in[0]) * 1E-6; - const double mbps = mb / (t1 - t0); - fprintf(stderr, "Compress %.1f MB/s\n", mbps); + if (want_bench) { // Avoids log spam in tests + const double t1 = hwy::platform::Now(); + const double mb = static_cast(num) * sizeof(raw[0]) * 1E-6; + const double mbps = mb / (t1 - t0); + fprintf(stderr, "Compress %.1f MB/s\n", mbps); + } if (COMPRESS_STATS) { for (size_t i = 1; i < work.tls.size(); ++i) { @@ -449,53 +439,182 @@ HWY_NOINLINE void Compress(const float* in, size_t num, } } -// Compresses an entire std::array into `out`, which is assumed to have exactly -// that much capacity. -template -HWY_INLINE void Compress(const std::array& in, - CompressWorkingSet& work, - CompressedArray& compressed, - hwy::ThreadPool& pool) { - Compress(in.data(), kCapacity, work, kCapacity, compressed.data(), 0, pool); +// Adapter that compresses into `CompressedArray`. `raw` must already be scaled +// to fit the value range, if `Packed` is `SfpStream`. +template +HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num, + CompressWorkingSet& work, + CompressedArray& compressed, + hwy::ThreadPool& pool) { + Compress(raw, num, work, MakeSpan(compressed.data(), kCapacity), + /*packed_ofs=*/0, pool); } -// Decompresses `num` values from `compressed` starting at `compressed_ofs`. -template -HWY_NOINLINE void Decompress(const ArrayT& compressed, size_t compressed_ofs, - OutT* out, size_t num) { - HWY_DASSERT(compressed_ofs + num <= compressed.size()); - const hn::ScalableTag d; - using Traits = CompressTraits; - Traits::Decompress(d, compressed.size(), compressed.data(), compressed_ofs, - out, num); +// Stores two f32 vectors to f32 or bf16; avoids duplicating RMSNorm and +// RMSNormInplace for the two output types. 
+template > +void Compress2(DF df, VF raw0, VF raw1, const PackedSpan& packed, + const size_t packed_ofs) { + static_assert(hwy::IsSameEither()); + packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df)); + using Traits = CompressTraits; + Traits::Store2(df, raw0, raw1, packed, packed_ofs); } -// As above, but with threading and benchmarking. -template -HWY_INLINE void Decompress(const CompressedArray& compressed, - size_t compressed_ofs, OutT* out, size_t num, - hwy::ThreadPool& pool) { - HWY_DASSERT(compressed_ofs + num <= compressed.size()); - const double t0 = hwy::platform::Now(); +// Decompresses from any type of `packed`, to two float or BF16 vectors. +template > +HWY_INLINE void Decompress2(DRaw d, const PackedSpan& packed, + const size_t packed_ofs, VRaw& raw0, VRaw& raw1) { + using TRaw = hn::TFromD; + static_assert(hwy::IsSameEither()); + packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(d)); + using Traits = CompressTraits>; + Traits::Load2(d, MakeConst(packed), packed_ofs, raw0, raw1); +} - using Traits = CompressTraits; - constexpr size_t kBatch = 8192; - const size_t num_batches = hwy::DivCeil(num, kBatch); - pool.Run( - 0, num_batches, [&](const uint32_t idx_batch, size_t thread) HWY_ATTR { - const hn::ScalableTag d; +// Decompresses from any type of `packed`, starting at (any) `packed_ofs`, to +// (any) `num` elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as +// required to round `num` up to one vector, if it is not already. The caller is +// responsible for scaling `raw` to the original range because `EmbedToken` +// also wants to scale the decompressed elements. +template > +HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan& packed, + const size_t packed_ofs, TRaw* raw, + size_t num) { + static_assert(hwy::IsSameEither()); + using Traits = CompressTraits>; + packed.BoundsCheck(packed_ofs, num); + Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num); +} - const size_t ofs = idx_batch * kBatch; - const size_t batch = - idx_batch == num_batches - 1 ? (num - ofs) : kBatch; - Traits::Decompress(d, compressed.size(), compressed.data(), - compressed_ofs + ofs, out + ofs, batch); - }); +// Decompresses to the type specified by `D` from each of two arrays in groups +// of four vectors, passes them to `kernel.Update4`, zero-pads to a vector +// multiple, then calls `kernel.Update1` for the remaining vectors. Returns +// `kernel.Reduce`. +// +// This is useful for implementing dot products, and similar to +// `hwy/contrib/unroller`, but also supports compressed types with simpler +// remainder handling thanks to `DecompressAndZeroPad`. +// +// `w` can be any packed type, including NUQ, which requires a separate `w_ofs` +// rather than pointer arithmetic. `vec_aligned` can also be any type, but +// typically float or BF16. We omit a `v_ofs` because it is 0 in our use cases. +// `num`, the number of elements to process, need not be a vector multiple. +// +// `kernel` is const& so we can pass an rvalue argument, but can contain +// mutable state, though not vectors (see highway.h). We pass in the four +// loaded vectors plus eight *f32* state vectors, independent of `D`. 
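// A scalar analogue of the kernel contract described above: Update4 for the
// unrolled blocks, Update1 for the zero-padded tail, Reduce at the end. The
// comp* state is shown here as Kahan-Babuska (Neumaier) compensation, which is
// one natural reading of it; this is a hypothetical reference, not the kernels
// used by the dot products in ops/.

#include <cmath>

struct ScalarDotKernel {
  // One compensated multiply-accumulate, the scalar counterpart of Update1.
  void Update1(float w, float v, float& sum, float& comp) const {
    const float x = w * v;
    const float t = sum + x;
    comp += (std::fabs(sum) >= std::fabs(x)) ? (sum - t) + x : (x - t) + sum;
    sum = t;
  }
  // Four independent accumulators, the scalar counterpart of Update4.
  void Update4(const float* w, const float* v, float* sum, float* comp) const {
    for (int k = 0; k < 4; ++k) Update1(w[k], v[k], sum[k], comp[k]);
  }
  // Folds the compensation back in, the scalar counterpart of Reduce.
  float Reduce(const float* sum, const float* comp) const {
    float total = 0.0f;
    for (int k = 0; k < 4; ++k) total += sum[k] + comp[k];
    return total;
  }
};
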
+template +HWY_INLINE float DecompressAndCall(D d, const PackedSpan& w, + const size_t w_ofs, + const VecT* HWY_RESTRICT vec_aligned, + const size_t num, const Kernel& kernel) { + PROFILER_FUNC; - const double t1 = hwy::platform::Now(); - const double mb = num * sizeof(MatT) * 1E-6; - const double mbps = mb / (t1 - t0); - fprintf(stderr, "Decompress %.1f MB/s\n", mbps); + HWY_DASSERT(hn::IsAligned(hn::Repartition(), vec_aligned)); + const auto v_span = MakeSpan(vec_aligned, num); + + // Decompressed inputs + using V = hn::Vec; + V w0, w1, w2, w3, v0, v1, v2, v3; + + // State for Kernel + const hn::Repartition df; + using VF = hn::Vec; + VF sum0 = hn::Zero(df); + VF sum1 = hn::Zero(df); + VF sum2 = hn::Zero(df); + VF sum3 = hn::Zero(df); + VF comp0 = hn::Zero(df); + VF comp1 = hn::Zero(df); + VF comp2 = hn::Zero(df); + VF comp3 = hn::Zero(df); + + const size_t N = hn::Lanes(d); + size_t i = 0; + if (num >= 4 * N) { + for (; i <= num - 4 * N; i += 4 * N) { + Decompress2(d, w, w_ofs + i + 0 * N, w0, w1); + Decompress2(d, w, w_ofs + i + 2 * N, w2, w3); + Decompress2(d, v_span, i + 0 * N, v0, v1); + Decompress2(d, v_span, i + 2 * N, v2, v3); + + kernel.Update4(d, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2, sum3, + comp0, comp1, comp2, comp3); + } + } + + size_t remaining = num - i; + HWY_DASSERT(remaining < 4 * N); + if (HWY_UNLIKELY(remaining != 0)) { + using T = hn::TFromD; + HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)]; + HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)]; + DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining); + DecompressAndZeroPad(d, v_span, i, padded_v, remaining); + + // 1..4 whole vectors, possibly zero-padded. + for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) { + const V w0 = hn::Load(d, padded_w + padded_pos); + const V v0 = hn::Load(d, padded_v + padded_pos); + kernel.Update1(d, w0, v0, sum0, comp0); + } + } + + return kernel.Reduce(df, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3); +} + +// Same as above, but single input array. Used by RMSNorm. +template +HWY_INLINE float DecompressAndCall(D d, const VecT* HWY_RESTRICT vec_aligned, + const size_t num, const Kernel& kernel) { + PROFILER_FUNC; + + HWY_DASSERT(hn::IsAligned(hn::Repartition(), vec_aligned)); + const auto v_span = MakeSpan(vec_aligned, num); + + // Decompressed inputs + using V = hn::Vec; + V v0, v1, v2, v3; + + // State for Kernel + const hn::Repartition df; + using VF = hn::Vec; + VF sum0 = hn::Zero(d); + VF sum1 = hn::Zero(d); + VF sum2 = hn::Zero(d); + VF sum3 = hn::Zero(d); + VF comp0 = hn::Zero(d); + VF comp1 = hn::Zero(d); + VF comp2 = hn::Zero(d); + VF comp3 = hn::Zero(d); + + const size_t N = hn::Lanes(d); + size_t i = 0; + if (num >= 4 * N) { + for (; i <= num - 4 * N; i += 4 * N) { + Decompress2(d, v_span, i + 0 * N, v0, v1); + Decompress2(d, v_span, i + 2 * N, v2, v3); + + kernel.Update4(d, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2, sum3, + comp0, comp1, comp2, comp3); + } + } + + size_t remaining = num - i; + HWY_DASSERT(remaining < 4 * N); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float padded_v[4 * hn::MaxLanes(d)]; + DecompressAndZeroPad(d, v_span, i, padded_v, remaining); + + // 1..4 whole vectors, possibly zero-padded. 
+ for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) { + const VF v0 = hn::Load(d, padded_v + padded_pos); + kernel.Update1(d, v0, v0, sum0, comp0); + } + } + + return kernel.Reduce(d, sum0, sum1, sum2, sum3, comp0, comp1, comp2, comp3); } // Functor called for each tensor, which compresses and stores them along with @@ -504,21 +623,22 @@ class Compressor { public: explicit Compressor(hwy::ThreadPool& pool) : pool_(pool) {} - template + template void operator()(const char* name, const float* weights, - CompressedArray& compressed) { + CompressedArray& compressed) { Insert(name, weights, kCapacity, work_, compressed.CompressedSize(), compressed.data(), 0, pool_); } - template + template void Insert(const char* name, const float* weights, size_t weights_count, - CompressWorkingSet& work, size_t out_capacity, MatT* out, - size_t out_ofs, hwy::ThreadPool& pool) { + CompressWorkingSet& work, size_t out_capacity, Packed* packed, + size_t packed_ofs, hwy::ThreadPool& pool) { fprintf(stderr, "Regenerating %s (%zuM), please wait\n", name, weights_count / (1000 * 1000)); - Compress(weights, weights_count, work_, weights_count, out, 0, pool_); - writer_.Add(CacheKey(name), out, out_capacity); + Compress(weights, weights_count, work_, + PackedSpan{packed, weights_count}, 0, pool_); + writer_.Add(CacheKey(name), packed, out_capacity); } void AddScales(const float* scales, size_t len) { diff --git a/compression/compress.h b/compression/compress.h index 2dd77dd..cfd512f 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -29,7 +29,6 @@ // IWYU pragma: begin_exports #include "compression/blob_store.h" #include "compression/io.h" -#include "compression/nuq.h" #include "compression/shared.h" // IWYU pragma: end_exports #include "compression/distortion.h" @@ -41,22 +40,6 @@ namespace gcpp { -static inline const char* TypeName(float) { return "f32"; } -static inline const char* TypeName(BF16) { return "b16"; } -static inline const char* TypeName(SfpStream) { return "sfp"; } -static inline const char* TypeName(NuqStream) { return "nuq"; } - -// Returns the number of `MatT` elements required to store `capacity` values, -// which must not be zero. -template -constexpr size_t CompressedArrayElements(size_t capacity) { - if constexpr (hwy::IsSame, NuqStream>()) { - return NuqStream::PackedEnd(capacity); - } else { - return capacity; - } -} - // Compressed representation of floating-point elements. The array length may // differ from the number of elements. Associated operations such as Dot are // implemented in SIMD code and are thus non-member functions. @@ -152,8 +135,8 @@ struct CompressStats { #endif // COMPRESS_STATS struct CompressPerThread { + NuqStream::ClusterBuf buf; CompressStats stats; - ClusterBuf buf; }; struct CompressWorkingSet { diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 00e678c..52883d4 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -12,3 +12,198 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +// SFP uses ConcatEven/Odd which are not supported; skip SVE for faster tests. 
+#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE) +#endif + +#include "compression/compress.h" + +#include +#include + +#include "compression/distortion.h" +#include "util/test_util.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/tests/hwy_gtest.h" +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "compression/compress_test.cc" // NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +// After highway.h +#include "compression/test_util-inl.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +// Calls Compress and Decompress2 and verifies the distortion/error. +template +struct TestDecompress2T { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const size_t N = hn::Lanes(d); + CompressWorkingSet work; + hwy::ThreadPool pool(0); + hwy::RandomState rng; + + const size_t num = 2 * N; + const size_t packed_num = CompressedArrayElements(num); + auto raw = hwy::AllocateAligned(num); // Compress requires f32 + auto packed = hwy::AllocateAligned(packed_num); + auto dec = hwy::AllocateAligned(num); + HWY_ASSERT(raw && packed && dec); + const auto packed_span = MakeSpan(packed.get(), packed_num); + + hwy::Stats in_stats; + for (size_t i = 0; i < num; ++i) { + raw[i] = static_cast(RandomGaussian(rng)); + in_stats.Notify(raw[i]); + } + // Short inputs fail VerifyGaussian. + + const size_t packed_ofs = 0; + Compress(raw.get(), num, work, packed_span, packed_ofs, pool); + hn::Vec raw0, raw1; + Decompress2(d, MakeConst(packed_span), packed_ofs, raw0, raw1); + hn::Store(raw0, d, dec.get()); + hn::Store(raw1, d, dec.get() + N); + + DistortionStats stats; + for (size_t i = 0; i < num; ++i) { + stats.Notify(raw[i], hwy::ConvertScalarTo(dec[i])); + } + + if constexpr (false) { + fprintf(stderr, "%s %s: %zu: %f %f %f %f\n", TypeName(), + TypeName(), num, stats.SumL1(), stats.GeomeanValueDivL1(), + stats.WeightedAverageL1(), stats.L1().Max()); + } + + constexpr bool kFromFloat = hwy::IsSame(); + constexpr bool kToFloat = hwy::IsSame(); + if constexpr (kFromFloat && kToFloat) { // Lossless + HWY_ASSERT(stats.NumExact() == num); + HWY_ASSERT(stats.SumL1() == 0.0f); + HWY_ASSERT(stats.L1().Max() == 0.0f); + } else if constexpr (hwy::IsSame() || + (kFromFloat && hwy::IsSame())) { + // Small roundoff error. BF16 to float is not lossless because the + // comparison is with float `raw`, prior to the Compress to BF16. 
+ HWY_ASSERT(stats.L1().Max() <= 2E-3f); + HWY_ASSERT(IsInside(3E-4, 2E-3, stats.WeightedAverageL1())); + HWY_ASSERT(IsInside(600.0, 900.0, stats.GeomeanValueDivL1())); + } else if constexpr (hwy::IsSame()) { + HWY_ASSERT(stats.SumL1() <= 0.4f); + HWY_ASSERT(stats.L1().Max() <= 0.04f); + HWY_ASSERT(IsInside(0.01, 0.03, stats.WeightedAverageL1())); + HWY_ASSERT(IsInside(48.0, 72.0, stats.GeomeanValueDivL1())); + } else if constexpr (hwy::IsSame()) { + static_assert(NuqStream::kGroupSize == 256, "Update expected"); + HWY_ASSERT(stats.SumL1() <= 1.2f); + HWY_ASSERT(stats.L1().Max() <= 0.08f); + HWY_ASSERT(IsInside(0.02, 0.05, stats.WeightedAverageL1())); + HWY_ASSERT(IsInside(18.0, 62.0, stats.GeomeanValueDivL1())); + } else { + HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType"); + } + } +}; + +void TestAllDecompress2() { ForeachPackedAndRawType(); } + +// Calls Compress and DecompressAndZeroPad for all short lengths and verifies +// the distortion/error. +template +struct TestShortLengthsT { + template + HWY_INLINE void operator()(T /*unused*/, D d) { + const size_t N = hn::Lanes(d); + CompressWorkingSet work; + hwy::ThreadPool pool(0); + hwy::RandomState rng; + + for (size_t num = 1; num < 5 * hn::Lanes(d); ++num) { + const size_t packed_num = CompressedArrayElements(num); + + auto raw = hwy::AllocateAligned(num); // Compress requires f32 + auto packed = hwy::AllocateAligned(packed_num); + auto dec = hwy::AllocateAligned(hwy::RoundUpTo(num, N)); + HWY_ASSERT(raw && packed && dec); + const auto packed_span = MakeSpan(packed.get(), packed_num); + + hwy::Stats in_stats; + for (size_t i = 0; i < num; ++i) { + raw[i] = static_cast(RandomGaussian(rng)); + in_stats.Notify(raw[i]); + } + // Short inputs fail VerifyGaussian. + + const size_t packed_ofs = 0; + Compress(raw.get(), num, work, packed_span, packed_ofs, pool); + DecompressAndZeroPad(d, MakeConst(packed_span), packed_ofs, dec.get(), + num); + + DistortionStats stats; + for (size_t i = 0; i < num; ++i) { + stats.Notify(raw[i], hwy::ConvertScalarTo(dec[i])); + } + + if constexpr (false) { + fprintf(stderr, "%s %s: %zu: %f %f %f %f\n", TypeName(), + TypeName(), num, stats.SumL1(), stats.GeomeanValueDivL1(), + stats.WeightedAverageL1(), stats.L1().Max()); + } + + constexpr bool kFromFloat = hwy::IsSame(); + constexpr bool kToFloat = hwy::IsSame(); + if constexpr (kFromFloat && kToFloat) { // Lossless + HWY_ASSERT(stats.NumExact() == num); + HWY_ASSERT(stats.SumL1() == 0.0f); + HWY_ASSERT(stats.L1().Max() == 0.0f); + } else if (hwy::IsSame() || + (kFromFloat && hwy::IsSame())) { + // Small roundoff error. BF16 to float is not lossless because the + // comparison is with float `raw`, prior to the Compress to BF16. 
+ HWY_ASSERT(stats.L1().Max() <= 4E-3f); + HWY_ASSERT(IsInside(1E-5, 3E-3, stats.WeightedAverageL1())); + HWY_ASSERT(IsInside(300.0, 2200.0, stats.GeomeanValueDivL1())); + } else if (hwy::IsSame()) { + HWY_ASSERT(stats.SumL1() <= 1.3f); + HWY_ASSERT(stats.L1().Max() <= 0.08f); + HWY_ASSERT(IsInside(7E-5, 0.05, stats.WeightedAverageL1())); + HWY_ASSERT(IsInside(28.0, 200.0, stats.GeomeanValueDivL1())); + } else if (hwy::IsSame()) { + static_assert(NuqStream::kGroupSize == 256, "Update expected"); + HWY_ASSERT(stats.SumL1() <= 4.6f); + HWY_ASSERT(stats.L1().Max() <= 0.14f); + HWY_ASSERT(IsInside(7E-5, 0.06, stats.WeightedAverageL1())); + HWY_ASSERT(IsInside(11.0, 180.0, stats.GeomeanValueDivL1())); + } else { + HWY_ABORT("Unhandled type requested by ForeachPackedAndRawType"); + } + } + } +}; + +void TestAllShortLengths() { ForeachPackedAndRawType(); } + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { +HWY_BEFORE_TEST(CompressTest); +HWY_EXPORT_AND_TEST_P(CompressTest, TestAllDecompress2); +HWY_EXPORT_AND_TEST_P(CompressTest, TestAllShortLengths); +HWY_AFTER_TEST(); +} // namespace gcpp +#endif // HWY_ONCE diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc index 3106db9..a36b35c 100644 --- a/compression/compress_weights.cc +++ b/compression/compress_weights.cc @@ -54,21 +54,6 @@ namespace gcpp { constexpr bool kDryRunFread = false; namespace { -float ScaleWeights(float* data, size_t len) { - float maxabs = 0.0; - for (size_t i = 0; i < len; ++i) { - maxabs = std::max(maxabs, std::abs(data[i])); - } - if (maxabs <= kMaxSFP) { - return 1.0f; - } - const float scale = maxabs / kMaxSFP; - const float inv_scale = 1.0f / scale; - for (size_t i = 0; i < len; ++i) { - data[i] *= inv_scale; - } - return scale; -} #define READ_WEIGHTS(name) \ do { \ diff --git a/compression/distortion_test.cc b/compression/distortion_test.cc index 00e026a..9350b5b 100644 --- a/compression/distortion_test.cc +++ b/compression/distortion_test.cc @@ -17,7 +17,7 @@ #include -#include "compression/shared.h" +#include "compression/shared.h" // SfpStream::kMax #include "util/test_util.h" #include "hwy/nanobenchmark.h" #include "hwy/tests/hwy_gtest.h" @@ -75,13 +75,15 @@ TEST(DistortionTest, TestDilution) { HWY_ASSERT(IsNear(0.001, stats.WeightedAverageL1())); // Now add a large difference: - stats.Notify(kMaxSFP - 0.0625f, kMaxSFP); // max magnitude, 3-bit mantissa + stats.Notify(SfpStream::kMax - 0.0625f, + SfpStream::kMax); // max magnitude, 3-bit mantissa // .. WeightedAverageL1 is closer to it. HWY_ASSERT(IsInside(0.020, 0.025, stats.WeightedAverageL1())); // Add a small and large difference: stats.Notify((1.75f - 0.125f) / 1024, 1.75f / 1024); // small, 2-bit mantissa - stats.Notify(-kMaxSFP + 0.0625f, -kMaxSFP); // larger negative + stats.Notify(-SfpStream::kMax + 0.0625f, + -SfpStream::kMax); // larger negative // .. SNR is still barely affected. HWY_ASSERT(IsInside(890.0, 900.0, stats.GeomeanValueDivL1())); // .. WeightedAverageL1 is higher after another large error. diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h index 11e9204..f8fa467 100644 --- a/compression/nuq-inl.h +++ b/compression/nuq-inl.h @@ -19,11 +19,15 @@ #include #include +#include -#include "compression/nuq.h" #include "compression/shared.h" #include "hwy/base.h" +#if HWY_IS_MSAN +#include +#endif + #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_ // Actual per-target include guard. 
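// Where the "4.5 bits per element" noted for NuqStream earlier in this patch
// comes from, under the layout this file implies: per 256-element group, a
// table of cluster centers encoded as 1-byte SFP values plus a 4-bit index per
// value. The constants below only restate that arithmetic and are illustrative
// (the kSketch* names are not part of the patch).

#include <cstddef>

constexpr size_t kSketchClusters = 16;                      // 4-bit indices
constexpr size_t kSketchGroupSize = 256;                    // NuqStream::kGroupSize
constexpr size_t kSketchTableBytes = kSketchClusters;       // 1-byte SFP centers
constexpr size_t kSketchIndexBytes = kSketchGroupSize / 2;  // two nibbles per byte
constexpr double kSketchBitsPerValue =
    (kSketchTableBytes + kSketchIndexBytes) * 8.0 / kSketchGroupSize;  // 4.5
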
@@ -40,17 +44,24 @@ #include "compression/sfp-inl.h" #include "hwy/contrib/sort/vqsort-inl.h" -#ifndef HWY_IF_CONSTEXPR -#define HWY_IF_CONSTEXPR if -#endif - HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +static inline void MaybeCheckInitialized(const void* ptr, size_t size) { +#if HWY_IS_MSAN + __msan_check_mem_is_initialized(ptr, size); +#else + (void)ptr; + (void)size; +#endif +} + // For internal use by NuqCodec. class NuqClustering { + static constexpr size_t kGroupSize = NuqStream::kGroupSize; + // To go from sorted order back to the original order in O(1), we store the // original index in the lower bits of the float32 mantissa, which means they // are sorted alongside the value. @@ -88,11 +99,13 @@ class NuqClustering { explicit ClusterCost(const float* HWY_RESTRICT sorted) { double cumsum = 0.0; double cumsum2 = 0.0; - cumsum_[0] = cumsum2_[0] = 0.0; + dcumsum_[0] = 0.0; + cumsum_[0] = cumsum2_[0] = 0.0f; for (size_t i = 0; i < kGroupSize; ++i) { const float x = FloatPayload::Clear(sorted[i]); cumsum += x; cumsum2 += static_cast(x) * x; + dcumsum_[1 + i] = cumsum; cumsum_[1 + i] = static_cast(cumsum); cumsum2_[1 + i] = static_cast(cumsum2); } @@ -132,8 +145,10 @@ class NuqClustering { } // Returns cost (L2 norm) for a single cluster, used for backtracking. - float SumOfSorted(size_t first, size_t last) const { - return cumsum_[last + 1] - cumsum_[first]; + double SumOfSorted(size_t first, size_t last) const { + HWY_DASSERT(first < kGroupSize); + HWY_DASSERT(last < kGroupSize); + return dcumsum_[last + 1] - dcumsum_[first]; } // Returns vector of costs of clustering first..last + i with their means. @@ -199,6 +214,8 @@ class NuqClustering { float cumsum2_[kGroupSize + 1 + kMaxLanes]; float len_[kMaxLanes + kGroupSize + 1 + kMaxLanes]; // = vlen[i] float inv_len_[kMaxLanes + kGroupSize + 1 + kMaxLanes]; // = 1 / vlen[i] + + double dcumsum_[kGroupSize + 1]; // for SumOfSorted }; // Dynamic programming step: returns costs of clustering 0..last+i, where the @@ -206,18 +223,17 @@ class NuqClustering { // `first`, and `last`; vectorized across `last`. `first` may be greater than // `last`. `valid[i]` is `first <= last + i`. template , class MF = hn::Mask> - static HWY_INLINE VF ClusterDynProg(DF df, const AlignedMatrix& D, - const ClusterCost& cc, - const size_t idx_cluster, - const size_t first, const size_t last, - const MF valid) { + static HWY_INLINE VF + ClusterDynProg(DF df, const NuqStream::AlignedMatrix& costs, + const ClusterCost& cc, const size_t idx_cluster, + const size_t first, const size_t last, const MF valid) { HWY_DASSERT(idx_cluster != 0); HWY_DASSERT(0 != first && first < kGroupSize); HWY_DASSERT(last < kGroupSize); HWY_DASSERT(last % hn::Lanes(df) == 0); // Called in steps of N // Cost of clustering 0..first-1 with one fewer cluster than now. - const VF prev = hn::Set(df, D(idx_cluster - 1, first - 1)); + const VF prev = hn::Set(df, costs(idx_cluster - 1, first - 1)); // Eq2: add to that the cost of another cluster from first..last. return hn::Add(prev, cc.SumCosts(df, first, last, valid)); } @@ -237,7 +253,8 @@ class NuqClustering { // as implemented in FAISS, for our kGroupSize of 256. 
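// A plain scalar rendering of the exact 1-D clustering that ClusterExactL2
// vectorizes below: prefix sums give each candidate cluster's L2 cost in O(1),
// and the dynamic program chooses the best start of the rightmost cluster for
// every (cluster count, end) pair, then backtracks. Hypothetical reference
// code assuming already-sorted input; not the SIMD implementation.

#include <cstddef>
#include <vector>

static void ClusterExactL2Scalar(const std::vector<float>& sorted,
                                 size_t num_clusters,
                                 std::vector<float>& centers) {
  const size_t n = sorted.size();
  std::vector<double> sum(n + 1, 0.0), sum2(n + 1, 0.0);
  for (size_t i = 0; i < n; ++i) {
    sum[i + 1] = sum[i] + sorted[i];
    sum2[i + 1] = sum2[i] + double(sorted[i]) * sorted[i];
  }
  // L2 cost of one cluster covering sorted[first..last], via prefix sums.
  const auto cost = [&](size_t first, size_t last) {
    const double s = sum[last + 1] - sum[first];
    const double s2 = sum2[last + 1] - sum2[first];
    return s2 - s * s / double(last - first + 1);
  };
  // costs[c][m]: best cost of clustering sorted[0..m] into c+1 clusters.
  // argmin[c][m]: start index of the rightmost cluster in that solution.
  std::vector<std::vector<double>> costs(num_clusters, std::vector<double>(n));
  std::vector<std::vector<size_t>> argmin(num_clusters,
                                          std::vector<size_t>(n, 0));
  for (size_t m = 0; m < n; ++m) costs[0][m] = cost(0, m);
  for (size_t c = 1; c < num_clusters; ++c) {
    for (size_t m = 0; m < n; ++m) {
      costs[c][m] = costs[c - 1][m];  // allow leaving clusters unused
      argmin[c][m] = argmin[c - 1][m];
      for (size_t first = 1; first <= m; ++first) {
        const double d = costs[c - 1][first - 1] + cost(first, m);
        if (d < costs[c][m]) {
          costs[c][m] = d;
          argmin[c][m] = first;
        }
      }
    }
  }
  // Backtrack: each cluster is [argmin[c][last], last]; its center is the mean.
  centers.assign(num_clusters, 0.0f);
  size_t last = n - 1;
  for (size_t c = num_clusters; c-- > 0;) {
    const size_t first = argmin[c][last];
    centers[c] =
        float((sum[last + 1] - sum[first]) / double(last - first + 1));
    if (first == 0) break;  // remaining (lower) clusters are unused
    last = first - 1;
  }
}
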
template static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* HWY_RESTRICT x, - size_t num, ClusterBuf& buf, + size_t num, + NuqStream::ClusterBuf& buf, float* HWY_RESTRICT centers, uint16_t* HWY_RESTRICT indices) { HWY_DASSERT(num <= kGroupSize); @@ -268,31 +285,34 @@ class NuqClustering { ClusterCost cc(sorted_and_i); // ignores payload bits. // Reference: https://arxiv.org/abs/1701.07204 - // D[k-1][m] is the lowest cost of clustering x1..m into k clusters. - AlignedMatrix& D = buf.d; - // T[k][m] is the starting index within sorted_and_i[] of the k-th cluster. - AlignedMatrix& T = buf.t; + // costs[k-1][m] is the lowest cost of clustering x1..m into k clusters. + NuqStream::AlignedMatrix& costs = buf.costs; + // argmin[k][m] is the starting index within sorted_and_i[] of the k-th + // cluster. + NuqStream::AlignedMatrix& argmin = buf.argmin; - // Fill first row of `D` and `T`: single cluster, iterate over all `last`. + // Fill first row of `costs` and `argmin`: single cluster, iterate over all + // `last`. { const size_t cluster_idx = 0; const size_t first = 0; const VI vfirst = hn::Set(di, static_cast(first)); const MF all_valid = hn::FirstN(df, N); // first <= last is always true for (size_t last = 0; last < kGroupSize; last += N) { - const VF costs = cc.SumCosts(df, first, last, all_valid); - hn::Store(costs, df, &D(cluster_idx, last)); - hn::Store(vfirst, di, &T(cluster_idx, last)); + const VF vcosts = cc.SumCosts(df, first, last, all_valid); + hn::Store(vcosts, df, &costs(cluster_idx, last)); + hn::Store(vfirst, di, &argmin(cluster_idx, last)); } } + constexpr size_t kClusters = NuqStream::kClusters; for (size_t cluster_idx = 1; cluster_idx < kClusters; ++cluster_idx) { // For vectors of `last + i` with `i < N`: for (size_t last = 0; last < kGroupSize; last += N) { const VI vlast = hn::Iota(di, static_cast(last)); - const VF prev_cost = hn::LoadU(df, &D(cluster_idx - 1, last)); + const VF prev_cost = hn::LoadU(df, &costs(cluster_idx - 1, last)); VF min = prev_cost; - VI arg = hn::LoadU(di, &T(cluster_idx - 1, last)); + VI arg = hn::LoadU(di, &argmin(cluster_idx - 1, last)); // For each `first` (j), which is the start of the rightmost of at least // two clusters, hence never zero. `first` also continues past `last` // because the last `vlast` lane is `last + N - 1`. @@ -300,7 +320,7 @@ class NuqClustering { const VI vfirst = hn::Set(di, static_cast(first)); const MF valid = hn::RebindMask(df, hn::Le(vfirst, vlast)); const VF c = - ClusterDynProg(df, D, cc, cluster_idx, first, last, valid); + ClusterDynProg(df, costs, cc, cluster_idx, first, last, valid); // Retain the min cost and the `first` that caused it. const MF less = hn::And(valid, hn::Lt(c, min)); @@ -309,21 +329,21 @@ class NuqClustering { } HWY_DASSERT(hn::AllTrue(df, hn::Le(min, prev_cost))); - hn::Store(min, df, &D(cluster_idx, last)); - hn::Store(arg, di, &T(cluster_idx, last)); + hn::Store(min, df, &costs(cluster_idx, last)); + hn::Store(arg, di, &argmin(cluster_idx, last)); } } - // Backtrack to find centers. Clusters are [T(k, last), last]. + // Backtrack to find centers. Clusters are [argmin(k, last), last]. size_t last = kGroupSize - 1; size_t unused_clusters = 0; for (size_t k = kClusters - 1; k < kClusters; --k) { - const size_t start = static_cast(T(k, last)); + const size_t start = static_cast(argmin(k, last)); // Center = mean, O(1) thanks to cumulative sums. 
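// The switch to a double SumOfSorted (backed by dcumsum_, earlier in this
// hunk) feeds the center computation just below. Presumably it matters because
// the mean is a difference of prefix sums: once a float prefix sum grows
// large, small addends are rounded away entirely. A tiny standalone
// illustration with hypothetical values, not taken from this patch:

#include <cstdio>

int main() {
  float fsum = 16777216.0f;   // 2^24: the float spacing here is 2.0
  double dsum = 16777216.0;
  fsum += 0.25f;              // rounded away; fsum is still 2^24
  dsum += 0.25;               // retained
  std::printf("float keeps %g, double keeps %g\n",
              fsum - 16777216.0f, dsum - 16777216.0);  // 0 vs 0.25
  return 0;
}
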
- const float sum = cc.SumOfSorted(start, last); + const double sum = cc.SumOfSorted(start, last); const int size = static_cast(last) - static_cast(start) + 1; HWY_DASSERT(0 < size && size <= static_cast(kGroupSize)); - centers[k] = sum / static_cast(size); + centers[k] = static_cast(sum / size); // We know the range inside sorted_and_i[]; translate to original indices, // which are stored inside each of the sorted_and_i mantissas. @@ -347,15 +367,34 @@ class NuqClustering { } if (HWY_IS_DEBUG_BUILD) { - // Centers are in ascending order. + // If centers are not in ascending order, print them. for (size_t i = unused_clusters + 1; i < kClusters; ++i) { - HWY_DASSERT(centers[i] >= centers[i - 1]); + if (centers[i] < centers[i - 1]) { + for (size_t i = 0; i < kClusters; ++i) { + fprintf(stderr, "%2zu: %.8f\n", i, centers[i]); + } + for (size_t i = 0; i < kGroupSize; ++i) { + fprintf(stderr, "%3zu: %.8f\n", i, + FloatPayload::Clear(sorted_and_i[i])); + } + for (size_t i = 0; i < num; ++i) { + fprintf(stderr, "%3zu: %.8f\n", i, x[i]); + } + HWY_ABORT("Centers not in ascending order at %zu; unused %zu\n", i, + unused_clusters); + } } } + + MaybeCheckInitialized(centers, kClusters * sizeof(centers[0])); return unused_clusters; } }; // NuqClustering +// Half-vector of u8 from u16/bf16. +template +using D8HFromD16 = hn::Half>; + // Bit-packing 4-bit values is trivial if we have 2 or 4 independent vectors: // simply shift+OR them together into a full vector of 8 or 16-bit lanes. // However, the order then depends on the vector length, which is unacceptable @@ -371,15 +410,15 @@ class NuqClustering { // operations which benefit from special-casing for target and vector length. class NibbleCodec { public: - // Packs four u16 vectors' lanes to nibbles within one vector, in order, and - // stores that vector to `out`. - template > - static HWY_INLINE void OrderedPackU16(D16 d16, V16 in0, V16 in1, V16 in2, - V16 in3, uint8_t* HWY_RESTRICT out) { - const hn::Repartition d8; + // Returns a byte vector whose nibbles are the lanes of four u16 vectors, in + // the same order. + template , + class D8 = hn::Repartition, class V8 = hn::Vec> + static HWY_INLINE V8 OrderedPackU16(D16 d16, V16 in0, V16 in1, V16 in2, + V16 in3) { + const D8 d8; const hn::Repartition d32; const hn::Repartition d64; - using V8 = hn::Vec; // Pairwise compaction of a single vector so nibbles are packed in-order. // v16 lanes hold a 4-bit value; OR together adjacent pairs into the lower @@ -393,14 +432,13 @@ class NibbleCodec { const V16 u8_1 = combine_u16_pair_to_8(in1); const V16 u8_2 = combine_u16_pair_to_8(in2); const V16 u8_3 = combine_u16_pair_to_8(in3); - V8 packed; if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) { // 8-bit ConcatEven is efficient. Let digits denote eight u8 lanes // of u8_1/0: ?d?3 ?c?2 / ?b?1 ?a?0. 8-bit ConcatEven = d3c2 b1a0, and // again with the second x2_1 gives 7654 3210. const V8 x2_0 = hn::ConcatEven(d8, BitCast(d8, u8_1), BitCast(d8, u8_0)); const V8 x2_1 = hn::ConcatEven(d8, BitCast(d8, u8_3), BitCast(d8, u8_2)); - packed = hn::ConcatEven(d8, x2_1, x2_0); + return hn::ConcatEven(d8, x2_1, x2_0); } else { // To avoid expensive 8-bit ConcatEven, compact pairs of u32 into the // lower 16 bits in each u64, with other bits undefined. @@ -416,70 +454,23 @@ class NibbleCodec { // u16 of every u64. This is the same as above but with 16-bit Concat. 
const V16 x2_0 = hn::ConcatEven(d16, u16_1, u16_0); const V16 x2_1 = hn::ConcatEven(d16, u16_3, u16_2); - packed = hn::BitCast(d8, hn::ConcatEven(d16, x2_1, x2_0)); + return hn::BitCast(d8, hn::ConcatEven(d16, x2_1, x2_0)); } - hn::StoreU(packed, d8, out); } - // Unpacks `Lanes(d16)` nibbles to u16 lanes. The first comes from the low - // nibble of packed[0], then its high nibble, then the next low nibble, etc. - template > - static HWY_INLINE V16 OrderedUnpackU16(D16 d16, const uint8_t* packed) { - const hn::Repartition d8; + // Unpacks nibbles from the `kHalf` (0 or 1) half of a half-vector of bytes. + // Thus we use a quarter of a vector of bytes and expand nibbles 4x into u16, + // which fills a whole vector. Its first lane comes from the low nibble of the + // first byte, the second from its high nibble, then the next low nibble, etc. + template , + class D8H = D8HFromD16, class V8H = hn::Vec> + static HWY_INLINE V16 OrderedUnpackU16(D16 d16, const V8H packed) { + const hn::Twice d8; // full vector using V8 = hn::Vec; - const hn::CappedTag d_load; - // We replicate each byte 4x, so that its two nibbles propagate to both - // u16 lanes that they will initialize. The only performance-portable op to - // replicate bytes is TableLookupBytes, which shuffles 128-bit blocks - // independently. Thus each block receives 4 packed bytes, replicates them - // 4x, shifts/masks, and casts to 8 u16 lanes. - // - // Loading 16 bytes via LoadDup128 only works on AVX3; for smaller vectors, - // it may trigger asan errors from overrunning the end. We thus special-case - // vector lengths, handling any non-constexpr, and constexpr <= 512 bit. - V8 rep4; - if constexpr (HWY_HAVE_SCALABLE) { - // Non constexpr length: 4 per whole block equals size/4. - const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4); - const V8 bytes = hn::LoadN(d8, packed, num_bytes); - // Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc. - const V8 idx = hn::ShiftRight<2>(hn::Iota(d8, 0)); - rep4 = hn::TableLookupLanes(bytes, hn::IndicesFromVec(d8, idx)); - } else if (hn::MaxLanes(d16) <= 8) { // <= 128-bit - const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed)); - alignas(16) static constexpr uint8_t kRep4[16] = { - HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3)}; - rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4)); - } else if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) { - // Plain load, can do 256..512-bit permute across blocks. - const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed)); - alignas(64) static constexpr uint8_t kRep4[64] = { - HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3), - HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7), - HWY_REP4(8), HWY_REP4(9), HWY_REP4(10), HWY_REP4(11), - HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)}; - rep4 = hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4)); - } else if (hn::MaxLanes(d16) == 16) { // 256-bit - const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed)); - // First copy to upper block for TableLookupBytes. This is slightly - // faster than 64-bit BroadcastLane. 
- const V8 bcast = hn::ConcatLowerLower(d8, bytes, bytes); - alignas(32) static constexpr uint8_t kRep4[32] = { - HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3), - HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7)}; - rep4 = hn::TableLookupBytes(bcast, hn::Load(d8, kRep4)); - } else if (hn::MaxLanes(d16) == 32) { // 512-bit - const V8 bytes = hn::LoadDup128(d8, packed); - alignas(64) static constexpr uint8_t kRep4[64] = { - HWY_REP4(0), HWY_REP4(1), HWY_REP4(2), HWY_REP4(3), - HWY_REP4(4), HWY_REP4(5), HWY_REP4(6), HWY_REP4(7), - HWY_REP4(8), HWY_REP4(9), HWY_REP4(10), HWY_REP4(11), - HWY_REP4(12), HWY_REP4(13), HWY_REP4(14), HWY_REP4(15)}; - rep4 = hn::TableLookupBytes(bytes, hn::Load(d8, kRep4)); - } else { - HWY_DASSERT(false); - } + // Replicate each byte 4x, so that its two nibbles propagate to both u16 + // lanes that they will initialize. + const V8 rep4 = Replicate4x(d8, hn::ResizeBitCast(d8, packed)); const V16 mask4 = hn::Set(d16, 0xF); const V16 u16 = BitCast(d16, rep4); @@ -490,10 +481,60 @@ class NibbleCodec { // zz z3 zz z2 | zz z1 zz z0 And (unpacked result) return hn::And(mask4, hn::OddEven(hn::ShiftRight<4>(u16), u16)); } + + private: + // Returns `bytes[0 + kHalf * N/2]` in lanes 0..3, `bytes[1 + kHalf * N/2]` in + // lanes 4..7, etc. We fuse `kHalf` into the tables, which avoids the caller + // having to pass in `UpperHalf(bytes)`. + template > + static HWY_INLINE V8 Replicate4x(D8 d8, V8 bytes) { + static_assert(kHalf <= 1); + const size_t N = hn::Lanes(d8); + constexpr size_t kMaxN = hn::MaxLanes(d8); + // For kHalf=1 and 512-bit vectors, kAdd would be 16, which is out of + // bounds for TableLookupBytes. We instead BroadcastBlock<1> there. + constexpr uint8_t kAdd = kMaxN < 64 ? kHalf * kMaxN / 4 : 0; + // The only performance-portable op to replicate bytes is TableLookupBytes, + // but this only works if vectors are 128-bit or we first BroadcastBlock, + // which only works for <= 512-bit vectors. For scalable vectors, we + // instead synthesize this table via Iota+ShiftRight. + alignas(64) static constexpr uint8_t kRep4[64] = { + HWY_REP4(kAdd + 0), HWY_REP4(kAdd + 1), HWY_REP4(kAdd + 2), + HWY_REP4(kAdd + 3), HWY_REP4(kAdd + 4), HWY_REP4(kAdd + 5), + HWY_REP4(kAdd + 6), HWY_REP4(kAdd + 7), HWY_REP4(kAdd + 8), + HWY_REP4(kAdd + 9), HWY_REP4(kAdd + 10), HWY_REP4(kAdd + 11), + HWY_REP4(kAdd + 12), HWY_REP4(kAdd + 13), HWY_REP4(kAdd + 14), + HWY_REP4(kAdd + 15)}; + + if constexpr (HWY_HAVE_SCALABLE) { + // Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc. This works for up to + // 1024-bit vectors: Iota is [128, 256), and [32, 64) after shifting. + // For larger vectors, this would overflow and we should instead add kAdd. + HWY_DASSERT(N <= 128); + const V8 iota = hn::Iota(d8, static_cast(kHalf * N)); + const V8 idx = hn::ShiftRight<2>(iota); + return hn::TableLookupLanes(bytes, hn::IndicesFromVec(d8, idx)); + } else if constexpr (kMaxN <= 16) { // <= 128-bit + // No BroadcastBlock, we anyway only have one block. + return hn::TableLookupBytes(bytes, hn::Load(d8, kRep4)); + } else if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) { + // No BroadcastBlock, can directly permute across blocks. + return hn::TableLookupLanes(bytes, hn::SetTableIndices(d8, kRep4)); + } else { // 256..512-bit, no efficient TableLookupLanes + static_assert(kMaxN <= 64); // Else BroadcastBlock does not work. + // See kAdd comment above. + constexpr size_t kBlock = (kMaxN == 64 && kHalf == 1) ? 
1 : 0; + bytes = hn::BroadcastBlock(bytes); + return hn::TableLookupBytes(bytes, hn::Load(d8, kRep4)); + } + } }; // Encode/decode functions. class NuqCodec { + static constexpr size_t kClusters = NuqStream::kClusters; + static constexpr size_t kGroupSize = NuqStream::kGroupSize; + // 256-bit vectors can hold 16 bf16, otherwise we require 2x128-bit. template static constexpr size_t NumTables(DU du) { @@ -508,308 +549,465 @@ class NuqCodec { hn::Vec* HWY_RESTRICT tbl1) { // Cap to the table size (kClusters) for decoding SFP - sufficient, and may // be faster than a large vector. - const hn::CappedTag d_table; + const hn::CappedTag d_table; // We ResizeCast tables to DU: if DU is bigger, table lookups will only // access lanes < kClusters. If DU is smaller (128-bit), we have 2 tables. HWY_DASSERT(hn::Lanes(du) >= hn::Lanes(d_table) || NumTables(du) == 2); - HWY_ALIGN hwy::bfloat16_t table[kClusters]; - SfpCodec::Dec(d_table, reinterpret_cast(centers), - kClusters, table); + HWY_ALIGN BF16 table[kClusters]; + SfpCodec::DecompressAndZeroPad( + d_table, + MakeSpan(reinterpret_cast(centers), kClusters), 0, + table, kClusters); // If we assume >= 128-bit vectors, we can use [Two]TableLookupLanes // instead of TableLookupBytes, which requires extra interleaving of lo/hi. HWY_DASSERT(hn::Lanes(du) >= 8); - HWY_IF_CONSTEXPR(NumTables(du) == 2) { + if constexpr (NumTables(du) == 2) { // Reduce cap for second half to avoid loading past the end of the table. - const hn::CappedTag d_table2; + const hn::CappedTag d_table2; *tbl1 = hn::ResizeBitCast(du, hn::LoadU(d_table2, table + kClusters / 2)); } return hn::ResizeBitCast(du, hn::Load(d_table, table)); } - // Unpacks per-weight indices and sets c0/c1 to the corresponding centers. - template - static HWY_INLINE void TableLookups(DU du, hn::Vec tbl0, hn::Vec tbl1, - const uint8_t* packed, hn::Vec& c0, - hn::Vec& c1) { - using V16 = hn::Vec; - const size_t N16 = hn::Lanes(du); - - const V16 idx0 = NibbleCodec::OrderedUnpackU16(du, packed); - const V16 idx1 = NibbleCodec::OrderedUnpackU16(du, packed + N16 / 2); + // Unpacks a half-vector of nibbles into two vectors of u16 indices and sets + // c0/c1 to the corresponding bf16 (stored in u16) centers from tbl0/tbl1. + template , class D8H = D8HFromD16, + class V8H = hn::Vec> + static HWY_INLINE void TableLookups(DU du, VU tbl0, VU tbl1, const V8H packed, + VU& c0, VU& c1) { + const VU idx0 = NibbleCodec::OrderedUnpackU16<0>(du, packed); + const VU idx1 = NibbleCodec::OrderedUnpackU16<1>(du, packed); const auto indices0 = hn::IndicesFromVec(du, idx0); const auto indices1 = hn::IndicesFromVec(du, idx1); - HWY_IF_CONSTEXPR(NumTables(du) == 1) { + if constexpr (NumTables(du) == 1) { (void)tbl1; c0 = hn::TableLookupLanes(tbl0, indices0); c1 = hn::TableLookupLanes(tbl0, indices1); } - HWY_IF_CONSTEXPR(NumTables(du) == 2) { // `else` is poorly formatted. + if constexpr (NumTables(du) == 2) { // `else` is poorly formatted. c0 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0); c1 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices1); } } + // As above, but returns a single 16-bit output vector for f32 Dec2, thus + // packed is only a quarter-vector. + template , + class D8Q = hn::Half>, class V8Q = hn::Vec> + static HWY_INLINE VU TableLookups(DU du, VU tbl0, VU tbl1, const V8Q packed) { + const D8HFromD16 d8h; + // OrderedUnpackU16 expects a half-vector, but will only use the lower half + // of it. 
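+ // ZeroExtendVector supplies that half-vector: the quarter-vector we were
+ // given becomes its lower half and the upper half is zeroed (and ignored by
+ // the <0> unpack).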
+ const hn::Vec packed_h = hn::ZeroExtendVector(d8h, packed); + const VU idx0 = NibbleCodec::OrderedUnpackU16<0>(du, packed_h); + + const auto indices0 = hn::IndicesFromVec(du, idx0); + + if constexpr (NumTables(du) == 1) { + (void)tbl1; + return hn::TableLookupLanes(tbl0, indices0); + } + if constexpr (NumTables(du) == 2) { // `else` is poorly formatted. + return hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0); + } + } + public: - // Encodes `num` floats starting from `in`. `out` points to compressed - // storage for `out_capacity` values and `out_ofs` indicates the destination - // offset within it, in units of float values, for parallel encoding by - // multiple threads. `num`, `out_capacity`, and `out_ofs` must all be - // multiples of `kGroupSize`. Returns the total number of unused clusters, - // which is expected to be zero. + // Encodes `num` floats from `raw`. `packed` points to compressed storage and + // `packed_ofs` indicates the destination offset within it, in units of float + // values, for parallel encoding by multiple threads. Returns the total + // number of unused clusters, which is typically zero. template - static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num, - ClusterBuf& buf, const size_t out_capacity, - NuqStream* const out, const size_t out_ofs) { + static HWY_INLINE size_t Enc(DF df, const float* HWY_RESTRICT raw, + const size_t num, NuqStream::ClusterBuf& buf, + const PackedSpan& packed, + size_t packed_ofs) { const hn::Repartition d16; + const hn::Repartition d8; using V16 = hn::Vec; - + using V8 = hn::Vec; const size_t N16 = hn::Lanes(d16); - HWY_ASSERT(kGroupSize >= 4 * N16); - HWY_ASSERT(out_ofs + num <= out_capacity); - buf.Resize(num); - HWY_ASSERT(num % kGroupSize == 0); - HWY_ASSERT(out_capacity % kGroupSize == 0); - HWY_ASSERT(out_ofs % kGroupSize == 0); - const size_t num_groups = num / kGroupSize; - const size_t ofs_groups = out_ofs / kGroupSize; + HWY_ASSERT(packed_ofs % kGroupSize == 0); + const size_t ofs_groups = packed_ofs / kGroupSize; + const size_t num_groups = hwy::DivCeil(num, kGroupSize); + buf.Resize(num_groups); size_t unused_clusters = 0; for (size_t g = 0; g < num_groups; ++g) { - const float* HWY_RESTRICT g_in = in + g * kGroupSize; + const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize); + const float* HWY_RESTRICT g_in = raw + g * kGroupSize; float* HWY_RESTRICT g_centers = buf.centers.get() + g * kClusters; uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize; - unused_clusters += NuqClustering::ClusterExactL2(df, g_in, kGroupSize, - buf, g_centers, g_idx); + unused_clusters += + NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx); } - uint8_t* centers = &out->byte + ofs_groups * kClusters; + uint8_t* centers = &packed.ptr->byte + ofs_groups * kClusters; SfpCodec::Enc(df, buf.centers.get(), num_groups * kClusters, reinterpret_cast(centers)); - uint8_t* packed_start = &out->byte + NuqStream::PackedStart(out_capacity) + + uint8_t* packed_start = &packed.ptr->byte + + NuqStream::PackedStart(packed.num) + ofs_groups * kGroupSize / 2; + // All but the last group have no remainders. 
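+ // kGroupSize is a multiple of 4 * N16 (asserted below), so whole groups are
+ // packed in full four-vector steps; only the final, possibly short, group
+ // takes the remainder path.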
+ HWY_DASSERT(kGroupSize % (4 * N16) == 0); HWY_UNROLL(1) - for (size_t g = 0; g < num_groups; ++g) { + for (size_t g = 0; g < num_groups - 1; ++g) { const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize; uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; HWY_UNROLL(1) for (size_t i = 0; i < kGroupSize; i += 4 * N16) { - const V16 idx0 = hn::LoadU(d16, g_idx + i + N16 * 0); - const V16 idx1 = hn::LoadU(d16, g_idx + i + N16 * 1); - const V16 idx2 = hn::LoadU(d16, g_idx + i + N16 * 2); - const V16 idx3 = hn::LoadU(d16, g_idx + i + N16 * 3); - NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3, - g_packed + i / 2); + const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16); + const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16); + const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16); + const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16); + const V8 nibbles = + NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3); + hn::StoreU(nibbles, d8, g_packed + i / 2); + } + } + + // Last group may have remainders. + { + HWY_DASSERT(num_groups != 0); + const size_t g = num_groups - 1; + const size_t g_num = num - g * kGroupSize; + HWY_DASSERT(g_num <= kGroupSize); + const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize; + uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; + + size_t i = 0; + if (g_num >= 4 * N16) { + HWY_UNROLL(1) + for (; i <= g_num - 4 * N16; i += 4 * N16) { + const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16); + const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16); + const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16); + const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16); + const V8 nibbles = + NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3); + hn::StoreU(nibbles, d8, g_packed + i / 2); + } + } + + const size_t remaining = g_num - i; + HWY_DASSERT(remaining < 4 * N16); + if (HWY_UNLIKELY(remaining != 0)) { + const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16); + const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16); + const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16); + const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16); + const V8 nibbles = + NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3); + // i is even, but remaining might not be. + hn::StoreN(nibbles, d8, g_packed + i / 2, hwy::DivCeil(remaining, 2)); } } return unused_clusters; } - // Decodes `num` values from the stream `in`, starting at the offset `in_ofs` - // (in units of values), to bf16 in `out`. `in_capacity`, `in_ofs` and `num` - // must all be multiples of `kGroupSize`. + // Decompresses to two bf16 vectors. `packed_ofs` must be a multiple of two + // vectors so that we only have to load one group's table. 
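+ // Per the stream layout, group g's kClusters SFP-encoded centers start at
+ // byte offset g * kClusters, and its packed nibbles follow
+ // PackedStart(packed.num) at byte offset g * kGroupSize / 2; the offsets
+ // computed below are exactly these plus the within-group position.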
template - static HWY_INLINE void Dec(DBF dbf, const size_t in_capacity, - const NuqStream* const in, const size_t in_ofs, - hwy::bfloat16_t* const out, const size_t num) { + static HWY_INLINE void Dec2(DBF dbf, + const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { const hn::RebindToUnsigned d16; + const D8HFromD16 d8h; using V16 = hn::Vec; + using V8H = hn::Vec; - const size_t N16 = hn::Lanes(d16); - HWY_DASSERT(kGroupSize >= 4 * N16); + const size_t within_group = packed_ofs % kGroupSize; + HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0); + const size_t ofs_in_groups = packed_ofs / kGroupSize; + const uint8_t* table = &packed.ptr->byte + ofs_in_groups * kClusters; + const uint8_t* indices = + &packed.ptr->byte + NuqStream::PackedStart(packed.num) + + hwy::DivCeil(ofs_in_groups * kGroupSize + within_group, 2); - HWY_DASSERT(in_ofs + num <= in_capacity); - HWY_DASSERT(in_capacity % kGroupSize == 0); - HWY_DASSERT(in_ofs % kGroupSize == 0); - HWY_DASSERT(num % kGroupSize == 0); - const size_t num_groups = num / kGroupSize; - const size_t ofs_groups = in_ofs / kGroupSize; - const uint8_t* tables = &in->byte + ofs_groups * kClusters; - const uint8_t* packed_start = &in->byte + - NuqStream::PackedStart(in_capacity) + - ofs_groups * kGroupSize / 2; + V16 tbl1 = Zero(d16); + const V16 tbl0 = LoadTable(d16, table, &tbl1); - HWY_UNROLL(1) - for (size_t g = 0; g < num_groups; ++g) { - const uint8_t* g_centers = tables + g * kClusters; - const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; - hwy::bfloat16_t* HWY_RESTRICT g_out = out + g * kGroupSize; + const V8H nibbles = hn::LoadU(d8h, indices); - V16 tbl1 = Zero(d16); - const V16 tbl0 = LoadTable(d16, g_centers, &tbl1); - - HWY_UNROLL(1) - for (size_t i = 0; i < kGroupSize; i += 2 * N16) { - V16 c0, c1; - TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1); - hn::StoreU(BitCast(dbf, c0), dbf, g_out + i + N16 * 0); - hn::StoreU(BitCast(dbf, c1), dbf, g_out + i + N16 * 1); - } - } + V16 c0, c1; + TableLookups(d16, tbl0, tbl1, nibbles, c0, c1); + raw0 = BitCast(dbf, c0); + raw1 = BitCast(dbf, c1); } - // Decodes `num` values from the stream `in`, starting at the offset - // `in_ofs` (in units of values), to f32 in `out`. `in_capacity`, - // `in_ofs` and `num` must all be multiples of `kGroupSize`. + // Decompresses to two f32 vectors. `packed_ofs` must be a multiple of two + // vectors so that we only have to load one group's table. 
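+ // Here both f32 output vectors are promoted from a single vector of bf16
+ // centers, so only a quarter-vector of nibbles is loaded.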
template - static HWY_INLINE void Dec(DF df, const size_t in_capacity, - const NuqStream* const in, const size_t in_ofs, - float* const out, const size_t num) { - const hn::Repartition dbf; + static HWY_INLINE void Dec2(DF df, const PackedSpan& packed, + const size_t packed_ofs, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Repartition dbf; const hn::RebindToUnsigned d16; + const hn::Half> d8q; + using V8Q = hn::Vec; using V16 = hn::Vec; - using VF = hn::Vec; - - const size_t NF = hn::Lanes(df); - HWY_DASSERT(kGroupSize >= 4 * NF); - - HWY_DASSERT(in_ofs + num <= in_capacity); - HWY_DASSERT(in_capacity % kGroupSize == 0); - HWY_DASSERT(in_ofs % kGroupSize == 0); - HWY_DASSERT(num % kGroupSize == 0); - const size_t ofs_groups = in_ofs / kGroupSize; - const size_t num_groups = num / kGroupSize; - const uint8_t* tables = &in->byte + ofs_groups * kClusters; - const uint8_t* packed_start = &in->byte + - NuqStream::PackedStart(in_capacity) + - ofs_groups * kGroupSize / 2; - - HWY_UNROLL(1) - for (size_t g = 0; g < num_groups; ++g) { - const uint8_t* g_centers = tables + g * kClusters; - const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; - float* HWY_RESTRICT g_out = out + g * kGroupSize; - - V16 tbl1 = Zero(d16); - const V16 tbl0 = LoadTable(d16, g_centers, &tbl1); - - HWY_UNROLL(1) - for (size_t i = 0; i < kGroupSize; i += 4 * NF) { - V16 c0, c1; - TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1); - const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0)); - const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0)); - const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1)); - const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1)); - hn::StoreU(f0, df, g_out + i + NF * 0); - hn::StoreU(f1, df, g_out + i + NF * 1); - hn::StoreU(f2, df, g_out + i + NF * 2); - hn::StoreU(f3, df, g_out + i + NF * 3); - } - } - } - - // Accumulates into `sum0..3` dot products of decoded values with `num` bf16 - // from `vec_aligned`. DF is f32 because sum0..3 are also f32. `in_capacity`, - // `in_ofs` and `num` must all be multiples of `kGroupSize`. 
- template - static HWY_INLINE void Dot(DF df, const size_t in_capacity, - const NuqStream* const in, const size_t in_ofs, - const hwy::bfloat16_t* const vec_aligned, - const size_t num, hn::Vec& sum0, - hn::Vec& sum1, hn::Vec& sum2, - hn::Vec& sum3) { - const hn::Repartition dbf; - const hn::RebindToUnsigned d16; - using VBF = hn::Vec; - using V16 = hn::Vec; - const size_t N16 = hn::Lanes(d16); - HWY_DASSERT(kGroupSize >= 4 * N16); - - HWY_DASSERT(in_ofs + num <= in_capacity); - HWY_DASSERT(in_capacity % kGroupSize == 0); - HWY_DASSERT(in_ofs % kGroupSize == 0); - HWY_DASSERT(num % kGroupSize == 0); - const size_t ofs_groups = in_ofs / kGroupSize; - const size_t num_groups = num / kGroupSize; - const uint8_t* tables = &in->byte + ofs_groups * kClusters; - const uint8_t* packed_start = &in->byte + - NuqStream::PackedStart(in_capacity) + - ofs_groups * kGroupSize / 2; - - HWY_UNROLL(1) - for (size_t g = 0; g < num_groups; ++g) { - const uint8_t* g_centers = tables + g * kClusters; - const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; - const hwy::bfloat16_t* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize; - - V16 tbl1 = Zero(d16); - const V16 tbl0 = LoadTable(d16, g_centers, &tbl1); - - HWY_UNROLL(1) - for (size_t i = 0; i < kGroupSize; i += 2 * N16) { - V16 c0, c1; - TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1); - const VBF in0 = hn::Load(dbf, g_in + i + N16 * 0); - const VBF in1 = hn::Load(dbf, g_in + i + N16 * 1); - sum0 = hn::ReorderWidenMulAccumulate(df, in0, BitCast(dbf, c0), sum0, - sum1); - sum2 = hn::ReorderWidenMulAccumulate(df, in1, BitCast(dbf, c1), sum2, - sum3); - } - } - } - - // Accumulates into `sum0..3` dot products of decoded values with `num` f32 - // from `vec_aligned`. `in_capacity`, `in_ofs` and `num` must all be - // multiples of `kGroupSize`. - template - static HWY_INLINE void Dot(DF df, const size_t in_capacity, - const NuqStream* const in, const size_t in_ofs, - const float* const vec_aligned, const size_t num, - hn::Vec& sum0, hn::Vec& sum1, - hn::Vec& sum2, hn::Vec& sum3) { - const hn::Repartition dbf; - const hn::RebindToUnsigned d16; using VF = hn::Vec; - using V16 = hn::Vec; - const size_t NF = hn::Lanes(df); - HWY_DASSERT(kGroupSize >= 4 * NF); - HWY_DASSERT(in_ofs + num <= in_capacity); - HWY_DASSERT(in_capacity % kGroupSize == 0); - HWY_DASSERT(in_ofs % kGroupSize == 0); - HWY_DASSERT(num % kGroupSize == 0); - const size_t ofs_groups = in_ofs / kGroupSize; - const size_t num_groups = num / kGroupSize; - const uint8_t* tables = &in->byte + ofs_groups * kClusters; - const uint8_t* packed_start = &in->byte + - NuqStream::PackedStart(in_capacity) + - ofs_groups * kGroupSize / 2; + const size_t within_group = packed_ofs % kGroupSize; + HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0); + const size_t ofs_groups = packed_ofs / kGroupSize; + const uint8_t* table = &packed.ptr->byte + ofs_groups * kClusters; + const uint8_t* indices = + &packed.ptr->byte + NuqStream::PackedStart(packed.num) + + hwy::DivCeil(ofs_groups * kGroupSize + within_group, 2); + + V16 tbl1 = Zero(d16); + const V16 tbl0 = LoadTable(d16, table, &tbl1); + + // The single-vector TableLookups overload only calls OrderedUnpackU16<0>, + // which expects a quarter vector of bytes. 
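+ // (Internally, TableLookups zero-extends that quarter-vector to the
+ // half-vector argument that OrderedUnpackU16 takes; see the overload above.)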
+ const V8Q nibbles = hn::LoadU(d8q, indices); + + const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles); + raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0)); + raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0)); + } + + // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num` + // elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to + // round `num` up to one vector, if it is not already. + template > + static HWY_INLINE void DecompressAndZeroPad( + D d, const PackedSpan& packed, size_t packed_ofs, + Raw* HWY_RESTRICT raw, size_t num) { + // If unaligned, load elements from the first group and update the args, + // from which we compute new tables/indices below. + if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) { + const size_t ofs_in_groups = packed_ofs / kGroupSize; + const uint8_t* tables = &packed.ptr->byte + ofs_in_groups * kClusters; + const uint8_t* indices = + &packed.ptr->byte + NuqStream::PackedStart(packed.num) + + hwy::DivCeil(ofs_in_groups * kGroupSize + within_group, 2); + const size_t remaining = HWY_MIN(num, kGroupSize - within_group); + DecPartialGroup(d, tables, indices, raw, remaining); + packed_ofs += remaining; + raw += remaining; + num -= remaining; + if (num == 0) return; + } + + HWY_DASSERT(packed_ofs % kGroupSize == 0); + const size_t ofs_in_groups = packed_ofs / kGroupSize; + const uint8_t* tables = &packed.ptr->byte + ofs_in_groups * kClusters; + const uint8_t* indices = &packed.ptr->byte + + NuqStream::PackedStart(packed.num) + + hwy::DivCeil(ofs_in_groups * kGroupSize, 2); + + const size_t num_groups = hwy::DivCeil(num, kGroupSize); + HWY_UNROLL(1) + for (size_t g = 0; g < num_groups - 1; ++g) { + DecWholeGroup(d, tables + g * kClusters, indices + g * kGroupSize / 2, + raw + g * kGroupSize); + } + + const size_t g = num_groups - 1; + DecPartialGroup(d, tables + g * kClusters, indices + g * kGroupSize / 2, + raw + g * kGroupSize, num - g * kGroupSize); + } + + private: + template + static HWY_INLINE void DecWholeGroup(DBF dbf, + const uint8_t* HWY_RESTRICT table, + const uint8_t* HWY_RESTRICT indices, + BF16* HWY_RESTRICT raw_bf) { + const hn::RebindToUnsigned d16; + const D8HFromD16 d8h; + using V16 = hn::Vec; + using V8H = hn::Vec; + const size_t N16 = hn::Lanes(d16); + + V16 tbl1 = Zero(d16); + const V16 tbl0 = LoadTable(d16, table, &tbl1); HWY_UNROLL(1) - for (size_t g = 0; g < num_groups; ++g) { - const uint8_t* g_centers = tables + g * kClusters; - const uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2; - const float* HWY_RESTRICT g_in = vec_aligned + g * kGroupSize; + for (size_t i = 0; i < kGroupSize; i += 2 * N16) { + const V8H nibbles = hn::LoadU(d8h, indices + i / 2); + V16 c0, c1; + TableLookups(d16, tbl0, tbl1, nibbles, c0, c1); + hn::StoreU(BitCast(dbf, c0), dbf, raw_bf + i + 0 * N16); + hn::StoreU(BitCast(dbf, c1), dbf, raw_bf + i + 1 * N16); + } + } - V16 tbl1 = Zero(d16); - const V16 tbl0 = LoadTable(d16, g_centers, &tbl1); + // Called for first and last group. 
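+ // `num` may be anything up to kGroupSize; the tail loads nibbles with LoadN,
+ // zeroes the lanes at and beyond `num`, and stores whole vectors, which
+ // never extend more than one vector past `num` (the caller's padding
+ // guarantee).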
+ template + static HWY_INLINE void DecPartialGroup(DBF dbf, + const uint8_t* HWY_RESTRICT table, + const uint8_t* HWY_RESTRICT indices, + BF16* HWY_RESTRICT raw_bf, + size_t num) { + HWY_DASSERT(num <= kGroupSize); + const hn::RebindToUnsigned d16; + const D8HFromD16 d8h; + using V16 = hn::Vec; + using V8H = hn::Vec; + const size_t N16 = hn::Lanes(d16); + + V16 tbl1 = Zero(d16); + const V16 tbl0 = LoadTable(d16, table, &tbl1); + + size_t i = 0; + + if (num >= 2 * N16) { HWY_UNROLL(1) - for (size_t i = 0; i < kGroupSize; i += 4 * NF) { + for (; i <= num - 2 * N16; i += 2 * N16) { + const V8H nibbles = hn::LoadU(d8h, indices + i / 2); V16 c0, c1; - TableLookups(d16, tbl0, tbl1, g_packed + i / 2, c0, c1); - const VF in0 = hn::LoadU(df, g_in + i + NF * 0); - const VF in1 = hn::LoadU(df, g_in + i + NF * 1); - const VF in2 = hn::LoadU(df, g_in + i + NF * 2); - const VF in3 = hn::LoadU(df, g_in + i + NF * 3); + TableLookups(d16, tbl0, tbl1, nibbles, c0, c1); + hn::StoreU(BitCast(dbf, c0), dbf, raw_bf + i + 0 * N16); + hn::StoreU(BitCast(dbf, c1), dbf, raw_bf + i + 1 * N16); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * N16); + if (HWY_UNLIKELY(remaining != 0)) { + // i is even, but remaining might not be. + const V8H nibbles = + hn::LoadN(d8h, indices + i / 2, hwy::DivCeil(remaining, 2)); + + V16 c0, c1; + TableLookups(d16, tbl0, tbl1, nibbles, c0, c1); + // Out of bounds `nibbles` are 0, but this does not yet guarantee + // c0/c1 are, because centers[0] might not be 0. + c0 = hn::IfThenElseZero(hn::FirstN(d16, remaining), c0); + hn::StoreU(BitCast(dbf, c0), dbf, raw_bf + i); + // Callers only pad to one vector, so check before storing the second. + if (remaining > N16) { + c1 = hn::IfThenElseZero(hn::FirstN(d16, remaining - N16), c1); + hn::StoreU(BitCast(dbf, c1), dbf, raw_bf + i + N16); + } + } + } + + template + static HWY_INLINE void DecWholeGroup(DF df, const uint8_t* HWY_RESTRICT table, + const uint8_t* HWY_RESTRICT indices, + float* HWY_RESTRICT raw_f) { + const hn::Repartition dbf; + const hn::RebindToUnsigned d16; + const D8HFromD16 d8h; + using V16 = hn::Vec; + using V8H = hn::Vec; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + V16 tbl1 = Zero(d16); + const V16 tbl0 = LoadTable(d16, table, &tbl1); + + HWY_UNROLL(1) + for (size_t i = 0; i < kGroupSize; i += 4 * NF) { + const V8H nibbles = hn::LoadU(d8h, indices + i / 2); + V16 c0, c1; + TableLookups(d16, tbl0, tbl1, nibbles, c0, c1); + const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0)); + const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0)); + const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1)); + const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1)); + hn::StoreU(f0, df, raw_f + i + 0 * NF); + hn::StoreU(f1, df, raw_f + i + 1 * NF); + hn::StoreU(f2, df, raw_f + i + 2 * NF); + hn::StoreU(f3, df, raw_f + i + 3 * NF); + } + } + + // Called for first and last group. 
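+ // As above, but each step decodes four f32 vectors, so the tail below must
+ // decide per vector whether it still fits, zero-padding only the partially
+ // valid one.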
+ template + static HWY_INLINE void DecPartialGroup(DF df, + const uint8_t* HWY_RESTRICT table, + const uint8_t* HWY_RESTRICT indices, + float* HWY_RESTRICT raw_f, + const size_t num) { + HWY_DASSERT(num <= kGroupSize); + + const hn::Repartition dbf; + const hn::RebindToUnsigned d16; + const D8HFromD16 d8h; + using V16 = hn::Vec; + using V8H = hn::Vec; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + V16 tbl1 = Zero(d16); + const V16 tbl0 = LoadTable(d16, table, &tbl1); + + size_t i = 0; + + if (num >= 4 * NF) { + HWY_UNROLL(1) + for (; i <= num - 4 * NF; i += 4 * NF) { + const V8H nibbles = hn::LoadU(d8h, indices + i / 2); + V16 c0, c1; + TableLookups(d16, tbl0, tbl1, nibbles, c0, c1); const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0)); const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0)); const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1)); const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1)); - sum0 = hn::MulAdd(in0, f0, sum0); - sum1 = hn::MulAdd(in1, f1, sum1); - sum2 = hn::MulAdd(in2, f2, sum2); - sum3 = hn::MulAdd(in3, f3, sum3); + hn::StoreU(f0, df, raw_f + i + 0 * NF); + hn::StoreU(f1, df, raw_f + i + 1 * NF); + hn::StoreU(f2, df, raw_f + i + 2 * NF); + hn::StoreU(f3, df, raw_f + i + 3 * NF); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 4 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + // i is even, but remaining might not be. + const V8H nibbles = + hn::LoadN(d8h, indices + i / 2, hwy::DivCeil(remaining, 2)); + + V16 c0, c1; + TableLookups(d16, tbl0, tbl1, nibbles, c0, c1); + const VF f0 = hn::PromoteLowerTo(df, BitCast(dbf, c0)); + const VF f1 = hn::PromoteUpperTo(df, BitCast(dbf, c0)); + const VF f2 = hn::PromoteLowerTo(df, BitCast(dbf, c1)); + const VF f3 = hn::PromoteUpperTo(df, BitCast(dbf, c1)); + // `raw_f` is only guaranteed to padded to NF, hence we cannot store all + // four vectors. We could conditionally store vectors either to `raw_f` + // or a buffer. However, we still have to mask because only `nibbles` + // are guaranteed to be 0, not c0/c1. Copying also involves branches, + // so we fully unroll the copy loop to avoid a buffer. We could also + // change the contract to pad to four vectors, but it would anyway be + // better to decompress to bf16. + if (remaining <= 1 * NF) { + const hn::Mask mask = hn::FirstN(df, remaining); + hn::StoreU(hn::IfThenElseZero(mask, f0), df, raw_f + i + 0 * NF); + return; + } + hn::StoreU(f0, df, raw_f + i + 0 * NF); + if (remaining <= 2 * NF) { + const hn::Mask mask = hn::FirstN(df, remaining - NF); + hn::StoreU(hn::IfThenElseZero(mask, f1), df, raw_f + i + 1 * NF); + return; + } + hn::StoreU(f1, df, raw_f + i + 1 * NF); + if (remaining <= 3 * NF) { + const hn::Mask mask = hn::FirstN(df, remaining - 2 * NF); + hn::StoreU(hn::IfThenElseZero(mask, f2), df, raw_f + i + 2 * NF); + return; + } + hn::StoreU(f2, df, raw_f + i + 2 * NF); + { + const hn::Mask mask = hn::FirstN(df, remaining - 3 * NF); + hn::StoreU(hn::IfThenElseZero(mask, f3), df, raw_f + i + 3 * NF); } } } diff --git a/compression/nuq.h b/compression/nuq.h deleted file mode 100644 index d7ae814..0000000 --- a/compression/nuq.h +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2023 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_ -#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_ - -// Non-uniform quantization: a compressed representation of f32 inputs that -// supports seeking at a granularity of kGroupSize, decoding to bf16/f32, and a -// fused decode/dot product with bf16/f32 vectors. - -#include -#include - -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" // HWY_INLINE - -namespace gcpp { - -// 4-bit indices are a sweet spot in terms of quality per size. -static constexpr size_t kClusters = 16; - -// Number of weights that share a table. Larger = slower encode, higher error, -// smaller size (table amortized over more weights). This is the minimum -// granularity for seeking/decoding in the stream, and must be at least four -// times the number of bf16 elements per vector. -static constexpr size_t kGroupSize = 256; - -// Points to the *start* of a NUQ stream. Aligning the allocation (see -// aligned_allocator.h) may be speed up decoding but is not required. -// -// Layout: first one table of kClusters entries per group, in ascending order -// of group index, then two packed indices per byte. -// -// Indices are stored in-order to enable vector-length agnostic decode, because -// streams may be persisted to disk and used by other CPUs. -// -// To enable parallel encoding and decoding, Enc/Dec have `offset` parameters -// which refer to the stream, NOT the raw from/to pointers, which point directly -// to the source/destination. Offsets are in units of values, NOT compressed -// bytes within the stream. -#pragma pack(push, 1) -struct NuqStream { - // Returns offset of packed indices from the start of the stream. This matches - // the (padded) total table size because table entries are bytes. - static constexpr size_t PackedStart(size_t capacity) { - // Round up to avoid cache-line splits when loading indices. No effect on - // size as long as capacity / kGroupSize is a multiple of 4. - return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64); - } - - // Returns number of NuqStream to allocate for the stream, which matches its - // size in bytes. `capacity` is already a multiple of `kGroupSize`. - static constexpr size_t PackedEnd(size_t capacity) { - return PackedStart(capacity) + hwy::DivCeil(capacity, 2); // 2x 4-bit/byte - } - - uint8_t byte; -}; -#pragma pack(pop) - -// Storage for dynamic programming. There are two matrices; we use separate -// allocations to avoid type punning. -template -class AlignedMatrix { - public: - AlignedMatrix() : mem_(hwy::AllocateAligned(kClusters * kGroupSize)) {} - - HWY_INLINE const T& operator()(size_t row, size_t col) const { - return mem_[row * kGroupSize + col]; - } - - HWY_INLINE T& operator()(size_t row, size_t col) { - return mem_[row * kGroupSize + col]; - } - - private: - hwy::AlignedFreeUniquePtr mem_; -}; - -// Reuse memory across calls to Enc to avoid per-call allocations. 
-struct ClusterBuf { - void Resize(size_t new_num) { - if (new_num < num) return; - - num = new_num; - const size_t num_groups = hwy::DivCeil(num, kGroupSize); - centers = hwy::AllocateAligned(num_groups * kClusters); - idx = hwy::AllocateAligned(hwy::RoundUpTo(num, kGroupSize)); - } - - AlignedMatrix d; - AlignedMatrix t; - - size_t num = 0; - hwy::AlignedFreeUniquePtr centers; - hwy::AlignedFreeUniquePtr idx; -}; - -} // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_H_ diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index 8c6175c..8cbce6c 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -18,8 +18,6 @@ #define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SVE) #endif -#include "compression/nuq.h" - #include #include #include @@ -28,9 +26,11 @@ #include #include "compression/distortion.h" +#include "compression/shared.h" #include "util/test_util.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" +#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/test_util.h" #include "hwy/timer.h" @@ -39,10 +39,9 @@ #define HWY_TARGET_INCLUDE "compression/nuq_test.cc" // NOLINT // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep -// Other headers that include Highway must come after foreach_target.h -#include "compression/nuq-inl.h" #include "hwy/highway.h" -#include "hwy/tests/hwy_gtest.h" +// After highway.h +#include "compression/nuq-inl.h" #include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); @@ -50,6 +49,8 @@ namespace gcpp { namespace HWY_NAMESPACE { static constexpr size_t kTimingReps = hn::AdjustedReps(3); +static constexpr size_t kClusters = NuqStream::kClusters; +static constexpr size_t kGroupSize = NuqStream::kGroupSize; // All-equal inputs: only one cluster struct TestFlat { @@ -65,7 +66,7 @@ struct TestFlat { for (size_t i = 0; i < kGroupSize; ++i) { in[i] = 0.5f; } - ClusterBuf buf; + NuqStream::ClusterBuf buf; float centers[kClusters]; uint16_t indices[kGroupSize]; const size_t unused_clusters = NuqClustering::ClusterExactL2( @@ -107,7 +108,7 @@ struct TestPlateaus { std::mt19937 rng(rd()); std::shuffle(in.get(), in.get() + kGroupSize, rng); - ClusterBuf buf; + NuqStream::ClusterBuf buf; float centers[kClusters]; uint16_t indices[kGroupSize]; const size_t unused_clusters = NuqClustering::ClusterExactL2( @@ -154,7 +155,7 @@ struct TestRamp { std::mt19937 rng(rd()); std::shuffle(in.get(), in.get() + kGroupSize, rng); - ClusterBuf buf; + NuqStream::ClusterBuf buf; float centers[kClusters]; uint16_t indices[kGroupSize]; const size_t unused_clusters = NuqClustering::ClusterExactL2( @@ -199,7 +200,7 @@ struct TestNormal { } VerifyGaussian(in_stats); - ClusterBuf buf; + NuqStream::ClusterBuf buf; float centers[kClusters]; uint16_t indices[kGroupSize]; double elapsed = hwy::HighestValue(); @@ -239,7 +240,7 @@ struct TestOffset { template HWY_INLINE void operator()(T /*unused*/, D d) { const hn::Repartition df; - const size_t total = 10 * kGroupSize; + const size_t total = 10 * kGroupSize; // already padded const size_t kMidLen = 2 * kGroupSize; // length of middle piece auto in = hwy::AllocateAligned(total); // Enc() requires f32 @@ -247,6 +248,7 @@ struct TestOffset { auto dec2 = hwy::AllocateAligned(kMidLen); auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(total)); HWY_ASSERT(in && dec1 && dec2 && nuq); + const auto nuq_span = MakeSpan(nuq.get(), total); hwy::RandomState rng; for (size_t i = 0; i < total; ++i) { @@ -254,53 +256,72 @@ struct TestOffset { } // Encode + decode everything - ClusterBuf buf; - 
(void)NuqCodec::Enc(df, in.get(), total, buf, total, nuq.get(), 0); - NuqCodec::Dec(d, total, nuq.get(), 0, dec1.get(), total); + NuqStream::ClusterBuf buf; + (void)NuqCodec::Enc(df, in.get(), total, buf, nuq_span, 0); + NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec1.get(), + total); // Overwrite middle with first inputs const size_t offset = 5 * kGroupSize; - (void)NuqCodec::Enc(df, in.get(), kMidLen, buf, total, nuq.get(), offset); + (void)NuqCodec::Enc(df, in.get(), kMidLen, buf, nuq_span, offset); // Decoded middle now matches previously decoded first - NuqCodec::Dec(d, total, nuq.get(), offset, dec2.get(), kMidLen); + NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), offset, dec2.get(), + kMidLen); for (size_t i = 0; i < kMidLen; ++i) { HWY_ASSERT(dec1[i] == dec2[i]); } } }; -void TestAllOffsetF32() { - const hn::ForGEVectors<128, TestOffset> test; - test(float()); -} - -void TestAllOffsetBF16() { - const hn::ForGEVectors<128, TestOffset> test; - test(hwy::bfloat16_t()); -} +void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); } +void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); } struct TestNibble { template HWY_INLINE void operator()(T /*unused*/, D d) { + const hn::Repartition d8; + const hn::Half d8h; using V = hn::Vec; - const size_t N = hn::Lanes(d); - const size_t num = 4 * N; - auto bytes = hwy::AllocateAligned(num / 2); - HWY_ASSERT(bytes); - const V v0 = hn::And(hn::Iota(d, 0), hn::Set(d, 15)); - const V v1 = hn::Set(d, 1); - const V v2 = hn::OddEven(v1, hn::Zero(d)); - const V v3 = hn::Reverse(d, v0); - NibbleCodec::OrderedPackU16(d, v0, v1, v2, v3, bytes.get()); - const V out0 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 0 * N / 2); - const V out1 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 1 * N / 2); - const V out2 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 2 * N / 2); - const V out3 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 3 * N / 2); - HWY_ASSERT_VEC_EQ(d, v0, out0); - HWY_ASSERT_VEC_EQ(d, v1, out1); - HWY_ASSERT_VEC_EQ(d, v2, out2); - HWY_ASSERT_VEC_EQ(d, v3, out3); + using V8 = hn::Vec; + using V8H = hn::Vec; + const V mask = hn::Set(d, 15); + + { + const V v0 = hn::And(hn::Iota(d, 0), mask); + const V v1 = hn::Set(d, 1); + const V v2 = hn::OddEven(v1, hn::Zero(d)); + const V v3 = hn::Reverse(d, v0); + const V8 nibbles = NibbleCodec::OrderedPackU16(d, v0, v1, v2, v3); + const V8H nibbles0 = hn::LowerHalf(d8h, nibbles); + const V8H nibbles1 = hn::UpperHalf(d8h, nibbles); + const V out0 = NibbleCodec::OrderedUnpackU16<0>(d, nibbles0); + const V out1 = NibbleCodec::OrderedUnpackU16<1>(d, nibbles0); + const V out2 = NibbleCodec::OrderedUnpackU16<0>(d, nibbles1); + const V out3 = NibbleCodec::OrderedUnpackU16<1>(d, nibbles1); + HWY_ASSERT_VEC_EQ(d, v0, out0); + HWY_ASSERT_VEC_EQ(d, v1, out1); + HWY_ASSERT_VEC_EQ(d, v2, out2); + HWY_ASSERT_VEC_EQ(d, v3, out3); + } + // Same, but with different values in each lane. 
+ { + const V v0 = hn::And(hn::Iota(d, 0), mask); + const V v1 = hn::And(hn::Iota(d, 1), mask); + const V v2 = hn::And(hn::Iota(d, 2), mask); + const V v3 = hn::And(hn::Iota(d, 3), mask); + const V8 nibbles = NibbleCodec::OrderedPackU16(d, v0, v1, v2, v3); + const V8H nibbles0 = hn::LowerHalf(d8h, nibbles); + const V8H nibbles1 = hn::UpperHalf(d8h, nibbles); + const V out0 = NibbleCodec::OrderedUnpackU16<0>(d, nibbles0); + const V out1 = NibbleCodec::OrderedUnpackU16<1>(d, nibbles0); + const V out2 = NibbleCodec::OrderedUnpackU16<0>(d, nibbles1); + const V out3 = NibbleCodec::OrderedUnpackU16<1>(d, nibbles1); + HWY_ASSERT_VEC_EQ(d, v0, out0); + HWY_ASSERT_VEC_EQ(d, v1, out1); + HWY_ASSERT_VEC_EQ(d, v2, out2); + HWY_ASSERT_VEC_EQ(d, v3, out3); + } } }; @@ -309,15 +330,19 @@ void TestAllNibble() { test(uint16_t()); } -struct TestStream { +// Checks the distortion from an encode and decode round trip. Unlike +// `TestShortLengthsT` in compress_test, this covers large `num` and +// prints the enc/dec throughput. +struct TestEncDec { template HWY_INLINE void operator()(T /*unused*/, D d) { const hn::Repartition df; const size_t num = 4 * kGroupSize; auto in = hwy::AllocateAligned(num); // Enc() requires f32 - auto out = hwy::AllocateAligned(num); + auto out = hwy::AllocateAligned(num); // already padded auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(num)); HWY_ASSERT(in && out && nuq); + const auto nuq_span = MakeSpan(nuq.get(), num); hwy::RandomState rng; hwy::Stats in_stats; @@ -327,12 +352,12 @@ struct TestStream { } VerifyGaussian(in_stats); - ClusterBuf buf; + NuqStream::ClusterBuf buf; double elapsed = hwy::HighestValue(); for (size_t rep = 0; rep < kTimingReps; ++rep) { const double t0 = hwy::platform::Now(); const size_t unused_clusters = - NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0); + NuqCodec::Enc(df, in.get(), num, buf, nuq_span, 0); HWY_ASSERT(unused_clusters == 0); const double t1 = hwy::platform::Now(); elapsed = HWY_MIN(elapsed, t1 - t0); @@ -343,7 +368,7 @@ struct TestStream { elapsed = hwy::HighestValue(); for (size_t rep = 0; rep < kTimingReps; ++rep) { const double t0 = hwy::platform::Now(); - NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num); + NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, out.get(), num); const double t1 = hwy::platform::Now(); elapsed = HWY_MIN(elapsed, t1 - t0); } @@ -367,129 +392,8 @@ struct TestStream { } }; -void TestAllStreamF32() { - const hn::ForGEVectors<128, TestStream> test; - test(float()); -} - -void TestAllStreamBF16() { - const hn::ForGEVectors<128, TestStream> test; - test(hwy::bfloat16_t()); -} - -struct TestDot { - template - HWY_INLINE void operator()(T /*unused*/, D d) { - const hn::Repartition df; - const size_t num = 4 * kGroupSize; - auto in = hwy::AllocateAligned(num); - auto dec = hwy::AllocateAligned(num); - auto vec = hwy::AllocateAligned(num); - auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(num)); - HWY_ASSERT(in && dec && vec && nuq); - - // Generate inputs and verify their distribution. 
- hwy::RandomState rng; - hwy::Stats in_stats; - for (size_t i = 0; i < num; ++i) { - in[i] = static_cast(RandomGaussian(rng)); - in_stats.Notify(in[i]); - } - for (size_t i = 0; i < num; ++i) { - const float r = static_cast(RandomGaussian(rng)); - in_stats.Notify(r); - vec[i] = hwy::ConvertScalarTo(r); - } - VerifyGaussian(in_stats); - - ClusterBuf buf; - const size_t unused_clusters = - NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0); - HWY_ASSERT(unused_clusters == 0); - - // Compute dot product without decompression. - float actual = 0.0f; - double elapsed = hwy::HighestValue(); - for (size_t rep = 0; rep < kTimingReps; ++rep) { - hn::Vec sum0 = hn::Zero(df); - hn::Vec sum1 = hn::Zero(df); - hn::Vec sum2 = hn::Zero(df); - hn::Vec sum3 = hn::Zero(df); - const double t0 = hwy::platform::Now(); - NuqCodec::Dot(df, num, nuq.get(), 0, vec.get(), num, sum0, sum1, sum2, - sum3); - const double t1 = hwy::platform::Now(); - elapsed = HWY_MIN(elapsed, t1 - t0); - sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3)); - actual = hn::ReduceSum(df, sum0); - } - - NuqCodec::Dec(df, num, nuq.get(), 0, dec.get(), num); - fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T), - num * sizeof(in[0]) * 1E-6 / elapsed); - - // Exact and decompressed dot products for comparison. - float exact = 0.0f; // using original input - float expected = 0.0f; // using decoded NUQ - DistortionStats dec_stats; - hwy::Stats ratios; - for (size_t i = 0; i < num; ++i) { - dec_stats.Notify(in[i], dec[i]); - const float v1 = hwy::ConvertScalarTo(vec[i]); - exact += in[i] * v1; - expected += dec[i] * v1; - if (expected != 0.0f) { - ratios.Notify(exact / expected); - } - } - const bool isBF = sizeof(T) == 2; - const double dec_snr = dec_stats.GeomeanValueDivL1(); - const double dec_wl1 = dec_stats.WeightedAverageL1(); - const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean()); - // exact and actual fluctuate due to the combination of NUQ imprecision, - // and whether vec[i] is negative or positive, so this is quite loose. - const float final_ratio = HWY_MIN(exact / actual, actual / exact); - if (HWY_ONCE) { - fprintf(stderr, "ratios %s\n", ratios.ToString().c_str()); - fprintf(stderr, - "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f " - "dot_snr %.2f dec_wl1 %.4f\n", - exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1); - } - // Final values are not too far apart. - HWY_ASSERT(gcpp::IsInside(0.88f, 1.0f, final_ratio)); - // Decompressed and uncompressed dot should match exactly. - HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f)); - // Geomean of ratios for each i should be very close to one. - HWY_ASSERT(dot_snr >= (isBF ? 17.7 : 14.3)); - - // dec[] is close to in[], but we already check that in TestStream with the - // same input distribution. - HWY_ASSERT(gcpp::IsNear(13.1, dec_snr, 0.1)); - HWY_ASSERT(gcpp::IsNear(0.034, dec_wl1, 0.001)); - HWY_ASSERT(gcpp::IsNear(23.5, dec_stats.SumL1(), 0.1)); - HWY_ASSERT(dec_stats.NumSignFlip() < num / kClusters); - HWY_ASSERT_EQ(0, dec_stats.NumExact()); - HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero()); - HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded()); - // Absolute decode errors are in [0, 0.11], and somewhat right-tailed. 
- HWY_ASSERT(gcpp::IsInside(0.0f, 2E-5f, dec_stats.L1().Min())); - HWY_ASSERT(gcpp::IsInside(0.09f, 0.11f, dec_stats.L1().Max())); - HWY_ASSERT(gcpp::IsInside(0.02, 0.03, dec_stats.L1().Mean())); - HWY_ASSERT(gcpp::IsInside(1.0, 1.1, dec_stats.L1().Skewness())); - HWY_ASSERT(gcpp::IsInside(4.0, 5.0, dec_stats.L1().Kurtosis())); - static_assert(kGroupSize == 256, "Update expected*"); - } -}; - -void TestAllDotF32() { - const hn::ForGEVectors<128, TestDot> test; - test(float()); -} -void TestAllDotBF16() { - const hn::ForGEVectors<128, TestDot> test; - test(hwy::bfloat16_t()); -} +void TestEncDecBF16() { hn::ForGEVectors<128, TestEncDec>()(BF16()); } +void TestEncDecF32() { hn::ForGEVectors<128, TestEncDec>()(float()); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -497,23 +401,17 @@ void TestAllDotBF16() { HWY_AFTER_NAMESPACE(); #if HWY_ONCE - namespace gcpp { HWY_BEFORE_TEST(NuqTest); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllFlat); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllPlateaus); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal); -HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32); -HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16); +HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetBF16); +HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetF32); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble); -HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32); -HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16); -HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32); -HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotBF16); -#ifdef HWY_AFTER_TEST +HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16); +HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32); HWY_AFTER_TEST(); -#endif } // namespace gcpp - -#endif +#endif // HWY_ONCE diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h index 78c941f..1be84e9 100644 --- a/compression/sfp-inl.h +++ b/compression/sfp-inl.h @@ -52,9 +52,6 @@ HWY_INLINE hn::Mask SignedLt(DU du, hn::Vec a, hn::Vec b) { return SignedGt(du, b, a); } -// Saturated subtraction; returns 0 if the result would be negative. -static inline size_t SubOr0(size_t a, size_t b) { return a > b ? a - b : 0; } - // Encode/decode functions. class SfpCodec { public: @@ -260,9 +257,9 @@ class SfpCodec { } // Encodes `num` bf16 values from `in_bf` to `out_packed`. Their magnitude - // must be at most 1.875. + // must be at most SfpStream::kMax. template - static HWY_INLINE void Enc(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in_bf, + static HWY_INLINE void Enc(DBF dbf, const BF16* HWY_RESTRICT in_bf, size_t num, SfpStream* HWY_RESTRICT out_packed) { const hn::Repartition d8; using V8 = hn::Vec; @@ -280,7 +277,7 @@ class SfpCodec { const size_t remaining = num - i; HWY_DASSERT(remaining < 2 * N16); if (remaining != 0) { - HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)]; + HWY_ALIGN BF16 padded[2 * hn::MaxLanes(dbf)]; hwy::ZeroBytes(padded, sizeof(padded)); hwy::CopyBytes(in_bf + i, padded, remaining * sizeof(padded[0])); const V8 packed = Enc2B(dbf, padded); @@ -289,7 +286,7 @@ class SfpCodec { } // Encodes `num` f32 values from `in_f` to `packed`. Their magnitude - // must be at most 1.875. + // must be at most SfpStream::kMax. template static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT in_f, size_t num, SfpStream* HWY_RESTRICT out_packed) { @@ -317,148 +314,112 @@ class SfpCodec { } } - // Decodes `num` values from `in_packed` to `out_bf`. 
+ template >> + static HWY_INLINE void Dec2(DBF16 dbf16, V8 packed, hn::Vec& raw0, + hn::Vec& raw1) { + Dec2B(dbf16, packed, raw0, raw1); + } + + template >>> + static HWY_INLINE void Dec2(DF df, V8 packed, hn::Vec& raw0, + hn::Vec& raw1) { + const hn::Rebind dbf; // half-vector + using VBF = hn::Vec; + VBF bf0, bf1; + Dec2B(dbf, packed, bf0, bf1); + raw0 = hn::PromoteTo(df, bf0); + raw1 = hn::PromoteTo(df, bf1); + } + + // Decompresses to (arbitrary) `num` BF16 elements in `raw_bf`, then appends + // `[0, hn::Lanes(dbf))` zeroes as required to round `num` up to one vector, + // if it is not already. DBF argument is provided by nuq-inl.h. template - static HWY_INLINE void Dec(DBF dbf, const SfpStream* HWY_RESTRICT in_packed, - size_t num, hwy::bfloat16_t* HWY_RESTRICT out_bf) { + static HWY_INLINE void DecompressAndZeroPad( + DBF dbf, const PackedSpan& packed, size_t packed_ofs, + BF16* HWY_RESTRICT raw_bf, size_t num) { const hn::Repartition d8; using V8 = hn::Vec; using VBF = hn::Vec; const size_t N16 = hn::Lanes(dbf); + const uint8_t* HWY_RESTRICT base = &packed.ptr->byte + packed_ofs; + size_t i = 0; if (num >= 2 * N16) { HWY_UNROLL(1) for (; i <= num - 2 * N16; i += 2 * N16) { - const V8 packed = hn::LoadU(d8, &in_packed->byte + i); + const V8 packed = hn::LoadU(d8, base + i); VBF bf0, bf1; Dec2B(dbf, packed, bf0, bf1); - hn::StoreU(bf0, dbf, out_bf + i); - hn::StoreU(bf1, dbf, out_bf + i + N16); + hn::StoreU(bf0, dbf, raw_bf + i); + hn::StoreU(bf1, dbf, raw_bf + i + N16); } } const size_t remaining = num - i; HWY_DASSERT(remaining < 2 * N16); if (remaining != 0) { - const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining); + const V8 packed = hn::LoadN(d8, base + i, remaining); VBF bf0, bf1; Dec2B(dbf, packed, bf0, bf1); - hn::StoreN(bf0, dbf, out_bf + i, remaining); - hn::StoreN(bf1, dbf, out_bf + i + N16, SubOr0(remaining, N16)); + // If at most one vector, the first store adds zero padding. Check before + // storing the second, because callers only pad to one vector. + hn::StoreU(bf0, dbf, raw_bf + i); + if (remaining > N16) hn::StoreU(bf1, dbf, raw_bf + i + N16); } } - // Decodes `num` values from `in_packed` to `out_f`. + // Decompresses to (arbitrary) `num` float elements in `raw_f`, then appends + // `[0, hn::Lanes(df))` zeroes as required to round `num` up to one vector, + // if it is not already. 
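+ // SFP stores one byte per value, so `packed_ofs` is also the byte offset
+ // into the stream.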
template - static HWY_INLINE void Dec(DF df, const SfpStream* HWY_RESTRICT in_packed, - size_t num, float* HWY_RESTRICT out_f) { + static HWY_INLINE void DecompressAndZeroPad( + DF df, const PackedSpan& packed, size_t packed_ofs, + float* HWY_RESTRICT raw_f, size_t num) { const hn::Repartition d8; using V8 = hn::Vec; using VF = hn::Vec; const size_t NF = hn::Lanes(df); + const uint8_t* HWY_RESTRICT base = &packed.ptr->byte + packed_ofs; + size_t i = 0; if (num >= 4 * NF) { HWY_UNROLL(1) for (; i <= num - 4 * NF; i += 4 * NF) { - const V8 packed = hn::LoadU(d8, &in_packed->byte + i); + const V8 packed = hn::LoadU(d8, base + i); VF f0, f1, f2, f3; Dec4F(df, packed, f0, f1, f2, f3); - hn::StoreU(f0, df, out_f + i + NF * 0); - hn::StoreU(f1, df, out_f + i + NF * 1); - hn::StoreU(f2, df, out_f + i + NF * 2); - hn::StoreU(f3, df, out_f + i + NF * 3); + hn::StoreU(f0, df, raw_f + i + NF * 0); + hn::StoreU(f1, df, raw_f + i + NF * 1); + hn::StoreU(f2, df, raw_f + i + NF * 2); + hn::StoreU(f3, df, raw_f + i + NF * 3); } } const size_t remaining = num - i; HWY_DASSERT(remaining < 4 * NF); - if (remaining != 0) { - const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining); + if (HWY_UNLIKELY(remaining != 0)) { + const V8 packed = hn::LoadN(d8, base + i, remaining); VF f0, f1, f2, f3; Dec4F(df, packed, f0, f1, f2, f3); - hn::StoreN(f0, df, out_f + i + 0 * NF, remaining); - hn::StoreN(f1, df, out_f + i + 1 * NF, SubOr0(remaining, 1 * NF)); - hn::StoreN(f2, df, out_f + i + 2 * NF, SubOr0(remaining, 2 * NF)); - hn::StoreN(f3, df, out_f + i + 3 * NF, SubOr0(remaining, 3 * NF)); + // We are only guaranteed one vector of padding, so cannot unconditionally + // store four vectors. `StoreN` would work, at the cost of saturated + // subtraction and creating masks. Because we know that `raw_f` is padded + // to at least one vector, we can instead store entire vectors and only + // make the address conditional, which potentially avoids branches. + // Separate per-vector storage may avoid conflicts. + HWY_ALIGN float buf[4 * hn::MaxLanes(df)]; + hn::StoreU(f0, df, raw_f + i); + hn::StoreU(f1, df, (remaining > 1 * NF ? (raw_f + i) : buf) + 1 * NF); + hn::StoreU(f2, df, (remaining > 2 * NF ? (raw_f + i) : buf) + 2 * NF); + hn::StoreU(f3, df, (remaining > 3 * NF ? (raw_f + i) : buf) + 3 * NF); } } - // Fused decode and dot product with even-odd bf16 into four f32 accumulators. - template - static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed, - size_t num, - const hwy::bfloat16_t* HWY_RESTRICT vec_aligned, - hn::Vec& sum0, hn::Vec& sum1, - hn::Vec& sum2, hn::Vec& sum3) { - const hn::Repartition d8; - const hn::Repartition dbf; - using V8 = hn::Vec; - using VBF = hn::Vec; - const size_t N16 = hn::Lanes(dbf); - HWY_DASSERT(num % (2 * N16) == 0); // whole SFP vector -> 2x bf16 - - HWY_UNROLL(1) - for (size_t i = 0; i < num; i += 2 * N16) { - const V8 packed = hn::LoadU(d8, &in_packed->byte + i); - const VBF ve = hn::LoadU(dbf, vec_aligned + i); - const VBF vo = hn::LoadU(dbf, vec_aligned + i + N16); - VBF be, bo; - DecEvenOdd(dbf, packed, be, bo); - sum0 = hn::ReorderWidenMulAccumulate(df, be, ve, sum0, sum1); - sum2 = hn::ReorderWidenMulAccumulate(df, bo, vo, sum2, sum3); - } - } - - // Fused decode and dot product with even-odd f32 into four f32 accumulators. 
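// A minimal usage sketch of the new `DecompressAndZeroPad` (hypothetical names
// `num` and `packed`; analogous to `TestOrder` in sfp_test.cc below). The
// destination is rounded up to a whole vector so the zero padding always fits:
//
//   const hn::ScalableTag<float> df;
//   const size_t padded = hwy::RoundUpTo(num, hn::Lanes(df));
//   auto raw = hwy::AllocateAligned<float>(padded);
//   SfpCodec::DecompressAndZeroPad(df, MakeConstSpan(packed, num),
//                                  /*packed_ofs=*/0, raw.get(), num);
//   // raw[num, padded) is now zero, so whole-vector loads need no masking.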
- template - static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed, - size_t num, - const float* HWY_RESTRICT vec_aligned, - hn::Vec& sum0, hn::Vec& sum1, - hn::Vec& sum2, hn::Vec& sum3) { - const hn::Repartition d8; - using V8 = hn::Vec; - using VF = hn::Vec; - const size_t NF = hn::Lanes(df); - HWY_DASSERT(num % (4 * NF) == 0); // whole SFP vector -> 4x f32 - - HWY_UNROLL(1) - for (size_t i = 0; i < num; i += 4 * NF) { - const V8 packed = hn::LoadU(d8, &in_packed->byte + i); - const VF ve0 = hn::LoadU(df, vec_aligned + i + NF * 0); - const VF vo0 = hn::LoadU(df, vec_aligned + i + NF * 1); - const VF ve1 = hn::LoadU(df, vec_aligned + i + NF * 2); - const VF vo1 = hn::LoadU(df, vec_aligned + i + NF * 3); - VF fe0, fo0, fe1, fo1; - DecEvenOddF(df, packed, fe0, fo0, fe1, fo1); - sum0 = hn::MulAdd(fe0, ve0, sum0); - sum1 = hn::MulAdd(fo0, vo0, sum1); - sum2 = hn::MulAdd(fe1, ve1, sum2); - sum3 = hn::MulAdd(fo1, vo1, sum3); - } - } - - template >>> - static HWY_INLINE void Dec2(DF df, V8 packed, hn::Vec& f0, - hn::Vec& f1) { - const hn::Rebind dbf; - using VBF = hn::Vec; - VBF bf0, bf1; - Dec2B(dbf, packed, bf0, bf1); - f0 = hn::PromoteTo(df, bf0); - f1 = hn::PromoteTo(df, bf1); - } - - template >> - static HWY_INLINE void Dec2(DBF16 dbf16, V8 packed, hn::Vec& bf0, - hn::Vec& bf1) { - Dec2B(dbf16, packed, bf0, bf1); - } - private: // Wrappers to avoid code duplication across float/bf16 input types and // the main loop/remainder. @@ -479,7 +440,7 @@ class SfpCodec { template >> - static HWY_INLINE V8 Enc2B(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in) { + static HWY_INLINE V8 Enc2B(DBF dbf, const BF16* HWY_RESTRICT in) { const hn::Repartition d16; const size_t N16 = hn::Lanes(d16); using V16 = hn::Vec; @@ -505,7 +466,7 @@ class SfpCodec { class V8 = hn::Vec>> static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) { const hn::Repartition d16; - const hn::Repartition dbf; + const hn::Repartition dbf; using VF = hn::Vec; using V16 = hn::Vec; const size_t NF = hn::Lanes(df); @@ -549,7 +510,7 @@ class SfpCodec { static HWY_INLINE void Dec4F(DF df, V8 packed, hn::Vec& f0, hn::Vec& f1, hn::Vec& f2, hn::Vec& f3) { - const hn::Repartition dbf; + const hn::Repartition dbf; using VBF = hn::Vec; VBF bf0, bf1; Dec2B(dbf, packed, bf0, bf1); @@ -559,6 +520,7 @@ class SfpCodec { f3 = hn::PromoteUpperTo(df, bf1); } + // TODO: currently unused, but keep for potential later MatMul packing. 
template >> static HWY_INLINE void DecEvenOdd(DBF dbf, V8 packed, hn::Vec& even, @@ -576,7 +538,7 @@ class SfpCodec { static HWY_INLINE void DecEvenOddF(DF df, V8 packed, hn::Vec& even0, hn::Vec& odd0, hn::Vec& even1, hn::Vec& odd1) { - const hn::Repartition dbf; + const hn::Repartition dbf; using VBF = hn::Vec; VBF even_bf, odd_bf; DecEvenOdd(dbf, packed, even_bf, odd_bf); diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index da4f220..f79e600 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -39,7 +39,6 @@ #include "hwy/highway.h" // After highway.h #include "compression/sfp-inl.h" -#include "ops/dot-inl.h" #include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); @@ -128,7 +127,7 @@ void TestAllFastDecode() { // Encode HWY_INLINE uint32_t SFP8FromF32(float f) { - HWY_ASSERT(-1.875f <= f && f <= 1.875f); + HWY_ASSERT(-SfpStream::kMax <= f && f <= SfpStream::kMax); constexpr uint32_t kMaskM = hwy::MantissaMask(); uint32_t binary32; @@ -182,7 +181,7 @@ struct TestDecEnc { template HWY_INLINE void operator()(T /*unused*/, D d) { const hn::RepartitionToWide d16; - const hn::Rebind dbf; + const hn::Rebind dbf; const hn::Repartition df; for (uint32_t encoded = 0; encoded < 256; ++encoded) { if (encoded == 0x80) continue; // -0 is reserved @@ -215,7 +214,7 @@ struct TestGolden { template HWY_INLINE void operator()(T /*unused*/, D d) { const hn::Repartition df; - const hn::Repartition dbf; + const hn::Repartition dbf; const hn::RebindToUnsigned d16; struct Golden { @@ -294,9 +293,53 @@ void TestAllGolden() { TestGolden()(uint8_t(), hn::ScalableTag()); } +// ------------------------------ Order + +// Store 8-bit iota, decode, encode, check iota == packed. This ensures +// Enc/Dec are preserving the order independent of vector length. +struct TestOrder { + template + HWY_INLINE void operator()(T /*unused*/, DBF dbf) { + const size_t N16 = hn::Lanes(dbf); + + for (size_t num = 1; num < 6 * N16; ++num) { + const size_t padded = hwy::RoundUpTo(num, N16); + + auto iota = hwy::AllocateAligned(num); + auto packed = hwy::AllocateAligned(num); + auto bf = hwy::AllocateAligned(padded); + HWY_ASSERT(iota && packed && bf); + for (size_t i = 0; i < num; ++i) { + // Clear sign bit so we can also check that bf is in ascending order. + iota[i].byte = i & 127; + } + + SfpCodec::DecompressAndZeroPad(dbf, MakeConstSpan(iota.get(), num), 0, + bf.get(), num); + for (size_t i = num; i < padded; ++i) { + if (hwy::ConvertScalarTo(bf[i]) != 0.0f) { + HWY_ABORT("num %zu padded %zu i %zu: not padded", num, padded, i); + } + } + + SfpCodec::Enc(dbf, bf.get(), num, packed.get()); + + for (size_t i = 0; i < num; ++i) { + if (iota[i].byte != packed[i].byte) { + HWY_ABORT("@%zu: %d %d\n", i, iota[i].byte, packed[i].byte); + } + } + } + } +}; + +void TestAllOrder() { hn::ForGEVectors<32, TestOrder>()(BF16()); } + // ------------------------------ Foreach bf16 input -// Generate all values, encode, decode back. +// Checks the distortion from an encode and decode round trip. Unlike +// `TestShortLengthsT` in compress_test, this covers large `num` and +// prints the enc/dec throughput. 
struct TestEncDec { template HWY_INLINE void operator()(T /*unused*/, DBF dbf) { @@ -309,14 +352,14 @@ struct TestEncDec { auto in = hwy::AllocateAligned(max); auto packed = hwy::AllocateAligned(max); - auto dec = hwy::AllocateAligned(max); + auto dec = hwy::AllocateAligned(max); // already padded HWY_ASSERT(in && packed && dec); size_t num = 0; for (size_t i = 0; i < max; ++i) { const uint16_t bits = i * kStep; const float f = hwy::F32FromBF16(hwy::BitCastScalar(bits)); // Keep if within range - if (hwy::ScalarIsFinite(f) && f <= 1.875f) { + if (hwy::ScalarIsFinite(f) && f <= SfpStream::kMax) { in[num] = hwy::BF16FromF32(f); in[num + 1] = hwy::BF16FromF32(-f); num += 2; @@ -329,7 +372,8 @@ struct TestEncDec { const double t0 = hwy::platform::Now(); SfpCodec::Enc(dbf, in.get(), num, packed.get()); const double t1 = hwy::platform::Now(); - SfpCodec::Dec(dbf, packed.get(), num, dec.get()); + SfpCodec::DecompressAndZeroPad(dbf, MakeConstSpan(packed.get(), num), 0, + dec.get(), num); const double t2 = hwy::platform::Now(); enc_elapsed = HWY_MIN(enc_elapsed, t1 - t0); dec_elapsed = HWY_MIN(dec_elapsed, t2 - t1); @@ -358,9 +402,10 @@ struct TestEncDec { stats.SumL1Rounded(), snr, wl1); } HWY_ASSERT(stats.Original().Count() == stats.L1().Count()); - // Inputs are in [-1.875, 1.875], symmetric, and heavy-tailed. - HWY_ASSERT(stats.Original().Min() == -1.875f); - HWY_ASSERT(stats.Original().Max() == 1.875f); + // Inputs are in [-SfpStream::kMax, SfpStream::kMax], symmetric, and + // heavy-tailed. + HWY_ASSERT(stats.Original().Min() == -SfpStream::kMax); + HWY_ASSERT(stats.Original().Max() == SfpStream::kMax); HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Mean())); HWY_ASSERT(gcpp::IsInside(-1E-6, 1E-6, stats.Original().Skewness())); HWY_ASSERT(gcpp::IsInside(80.0, 100.0, stats.Original().Kurtosis())); @@ -382,179 +427,7 @@ struct TestEncDec { } }; -void TestAllEncDec() { hn::ForGEVectors<32, TestEncDec>()(hwy::bfloat16_t()); } - -// ------------------------------ Order - -// Store 8-bit iota, decode, encode, check iota == packed. This ensures -// Enc/Dec are preserving the order independent of vector length. -struct TestOrder { - template - HWY_INLINE void operator()(T /*unused*/, DBF dbf) { - const hn::Repartition du8; - - const size_t num = 10 * hn::Lanes(du8) / 3; - - auto iota = hwy::AllocateAligned(num); - auto packed = hwy::AllocateAligned(num); - auto bf = hwy::AllocateAligned(num); - HWY_ASSERT(iota && packed && bf); - for (size_t i = 0; i < num; ++i) { - // Clear sign bit so we can also check that bf is in ascending order. - iota[i].byte = i & 127; - } - - SfpCodec::Dec(dbf, iota.get(), num, bf.get()); - SfpCodec::Enc(dbf, bf.get(), num, packed.get()); - - for (size_t i = 0; i < num; ++i) { - if (iota[i].byte != packed[i].byte) { - HWY_ABORT("@%zu: %d %d\n", i, iota[i].byte, packed[i].byte); - } - } - } -}; - -void TestAllOrder() { hn::ForGEVectors<32, TestOrder>()(hwy::bfloat16_t()); } - -// ------------------------------ Dot - -struct TestDot { - template - HWY_INLINE void operator()(T /*unused*/, D d) { - const hn::Repartition df; - const size_t num = 1024; // not too many for GeometricMean overflow. - const size_t N = hn::Lanes(d); - auto in = hwy::AllocateAligned(num); - auto dec = hwy::AllocateAligned(num); - auto vec = hwy::AllocateAligned(num); - auto vec_eo = hwy::AllocateAligned(num); - auto sfp = hwy::AllocateAligned(num); - HWY_ASSERT(in && dec && vec && vec_eo && sfp); - - // Generate inputs and verify their distribution. 
- hwy::RandomState rng; - hwy::Stats in_stats; - for (size_t i = 0; i < num; ++i) { - const float r = static_cast(RandomGaussian(rng)); - in_stats.Notify(r); - in[i] = hwy::ConvertScalarTo(r); - } - for (size_t i = 0; i < num; ++i) { - const float r = static_cast(RandomGaussian(rng)); - in_stats.Notify(r); - vec[i] = hwy::ConvertScalarTo(r); - } - VerifyGaussian(in_stats); - - // Convert vec to even/odd for DotEO - for (size_t i = 0; i < num; i += 2 * N) { - hn::Vec ve, vo; - hn::LoadInterleaved2(d, vec.get() + i, ve, vo); - hn::Store(ve, d, vec_eo.get() + i + 0); - hn::Store(vo, d, vec_eo.get() + i + N); - } - - SfpCodec::Enc(d, in.get(), num, sfp.get()); - - // Compute dot product without decompression. - float actual = 0.0f; - float actual_eo = 0.0f; - double elapsed = hwy::HighestValue(); - double elapsed_eo = hwy::HighestValue(); - for (size_t rep = 0; rep < 200; ++rep) { - { - const double t0 = hwy::platform::Now(); - actual = SimpleDot(df, sfp.get(), 0, vec.get(), num); - const double t1 = hwy::platform::Now(); - elapsed = HWY_MIN(elapsed, t1 - t0); - } - { - hn::Vec sum0 = hn::Zero(df); - hn::Vec sum1 = hn::Zero(df); - hn::Vec sum2 = hn::Zero(df); - hn::Vec sum3 = hn::Zero(df); - const double t0 = hwy::platform::Now(); - SfpCodec::DotEO(df, sfp.get(), num, vec_eo.get(), sum0, sum1, sum2, - sum3); - const double t1 = hwy::platform::Now(); - elapsed_eo = HWY_MIN(elapsed_eo, t1 - t0); - sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3)); - actual_eo = hn::ReduceSum(df, sum0); - } - } - - SfpCodec::Dec(d, sfp.get(), num, dec.get()); - fprintf(stderr, "Vec %zu Dot %zu-bit %.2f ; %.2f MB/s\n", - Lanes(d) * sizeof(T), sizeof(T) * 8, - num * sizeof(T) * 1E-6 / elapsed, - num * sizeof(T) * 1E-6 / elapsed_eo); - - // Exact and decompressed dot products for comparison. - float exact = 0.0f; // using original input - float expected = 0.0f; // using decoded SFP - DistortionStats dec_stats; - hwy::Stats ratios; - for (size_t i = 0; i < num; ++i) { - const float in1 = hwy::ConvertScalarTo(in[i]); - const float dec1 = hwy::ConvertScalarTo(dec[i]); - const float vec1 = hwy::ConvertScalarTo(vec[i]); - dec_stats.Notify(in1, dec1); - - exact += in1 * vec1; - expected += dec1 * vec1; - if (expected != 0.0f) { - ratios.Notify(exact / expected); - } - } - const bool isBF = sizeof(T) == 2; - const double dec_snr = dec_stats.GeomeanValueDivL1(); - const double dec_wl1 = dec_stats.WeightedAverageL1(); - const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean()); - // exact and actual fluctuate due to the combination of SFP imprecision, - // and whether vec[i] is negative or positive, so this is quite loose. - const float final_ratio = HWY_MIN(exact / actual, actual / exact); - if (HWY_ONCE) { - fprintf(stderr, "ratios %s\n", ratios.ToString().c_str()); - fprintf(stderr, - "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f " - "dot_snr %.2f dec_wl1 %.5f\n", - exact, expected, actual, final_ratio, dec_snr, dot_snr, dec_wl1); - } - // Final values are not too far apart. - HWY_ASSERT(gcpp::IsInside(0.87f, 1.0f, final_ratio)); - // Decompressed and uncompressed dot should match exactly. - HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f)); - // Even/odd dot should also match - HWY_ASSERT(gcpp::IsNear(actual, actual_eo, 1E-4f)); - // Geomean of ratios for each i should be very close to one. - HWY_ASSERT(dot_snr >= (isBF ? 70.0 : 1000.0)); - - // dec[] is close to in[]. We also check that in TestEncDec, but for much - // smaller input magnitudes. - HWY_ASSERT(gcpp::IsNear(isBF ? 
51.0 : 64.0, dec_snr, 1.0)); - HWY_ASSERT(gcpp::IsNear(isBF ? 0.013 : 0.012, dec_wl1, 0.001)); - HWY_ASSERT(gcpp::IsNear(isBF ? 6.2 : 6.3, dec_stats.SumL1(), 0.1)); - HWY_ASSERT_EQ(0, dec_stats.NumSignFlip()); - HWY_ASSERT_EQ(0, dec_stats.NumRoundedToZero()); - HWY_ASSERT_EQ(0.0, dec_stats.SumL1Rounded()); - // Absolute decode errors are in [0, 5E-2], and somewhat right-tailed. - HWY_ASSERT(gcpp::IsInside(0.0f, 2E-6f, dec_stats.L1().Min())); - HWY_ASSERT(gcpp::IsInside(3E-2f, 5E-2f, dec_stats.L1().Max())); - HWY_ASSERT(gcpp::IsInside(4E-3, 7E-3, dec_stats.L1().Mean())); - HWY_ASSERT(gcpp::IsInside(1.8, 1.9, dec_stats.L1().Skewness())); - HWY_ASSERT(gcpp::IsInside(6.0, 7.0, dec_stats.L1().Kurtosis())); - } -}; - -void TestAllDotF32() { - const hn::ForGEVectors<128, TestDot> test; - test(float()); -} -void TestAllDotBF16() { - const hn::ForGEVectors<128, TestDot> test; - test(hwy::bfloat16_t()); -} +void TestAllEncDec() { hn::ForGEVectors<32, TestEncDec>()(BF16()); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE @@ -562,7 +435,6 @@ void TestAllDotBF16() { HWY_AFTER_NAMESPACE(); #if HWY_ONCE - namespace gcpp { HWY_BEFORE_TEST(SfpTest); HWY_EXPORT_AND_TEST_P(SfpTest, PrintTables); @@ -570,13 +442,8 @@ HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllFastDecode); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden); -HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllOrder); -HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotF32); -HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotBF16); -#ifdef HWY_AFTER_TEST +HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec); HWY_AFTER_TEST(); -#endif } // namespace gcpp - -#endif +#endif // HWY_ONCE diff --git a/compression/shared.h b/compression/shared.h index 5f8b173..166cd29 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -20,8 +20,12 @@ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_ #include +#include -#include "hwy/base.h" // hwy::bfloat16_t +#include + +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" // HWY_INLINE namespace gcpp { @@ -35,25 +39,172 @@ using BF16 = hwy::bfloat16_t; // - 24-bit dynamic range, with max exponent 2^0. // - 3 bit mantissa for values >= 2^-7, otherwise 2. // -// A pointer to this is the *start* of an SFP stream. Values are stored -// in-order to enable vector-length agnostic seeking, because streams may be -// written to disk for loading on other CPUs. +// A pointer to this is the *start* of an SFP stream. Aligning the allocation +// (see aligned_allocator.h) may speed up decoding but is not required. +// +// Layout: Values are stored in-order to enable vector-length agnostic seeking, +// because streams may be written to disk for loading on other CPUs. // // This is faster to decode than a straightforward implementation of eXmY, in // part because SFP does not require subnormals. Unlike OCP MX, it also does not // require side information (shared exponents). // // Although the representation could probably be shrunk to 6-7 bits, more -// savings can be had by non-uniform clustering - see nuq.h. +// savings can be had by non-uniform clustering - see NuqStream. #pragma pack(push, 1) struct SfpStream { + // Largest possible input magnitude: 1.111 * 2^0. This could be increased by + // shifting the value range (exponent bias). + static constexpr float kMax = 1.875f; + uint8_t byte; }; #pragma pack(pop) -// Largest possible input magnitude: 1.111 * 2^0. 
This could be increased by -// shifting the value range (exponent bias). -constexpr float kMaxSFP = 1.875f; +// Returns 1.0f if all magnitudes are <= SfpStream::kMax, otherwise scales them +// such that the largest magnitude is SfpStream::kMax, and returns the +// multiplier with which to restore the original values. This is only necessary +// before compressing to SfpStream. +// TODO: vectorize +static inline float ScaleWeights(float* HWY_RESTRICT raw, size_t num) { + float maxabs = 0.0; + for (size_t i = 0; i < num; ++i) { + maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i])); + } + if (maxabs <= SfpStream::kMax) { + return 1.0f; + } + const float scale = maxabs / SfpStream::kMax; + const float inv_scale = static_cast(1.0 / static_cast(scale)); + for (size_t i = 0; i < num; ++i) { + // Clamp because kMax may still be exceeded. + const float magn = + HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale)); + raw[i] = hwy::ScalarCopySign(magn, raw[i]); + } + return scale; +} + +// Non-uniform quantization: a compressed representation of f32 inputs that +// supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or +// two vectors (for `Decompress2`), and decoding to bf16/f32. +// +// A pointer to this is the *start* of a NUQ stream. Aligning the allocation +// (see aligned_allocator.h) may speed up decoding but is not required. +// +// Layout: first one table of kClusters entries per group, in ascending order +// of group index, then two packed indices per byte. Indices are stored +// in-order to enable vector-length agnostic decode, because streams may be +// persisted to disk and used by other CPUs. +// +// To enable parallel encoding and decoding, Enc/Dec have `offset` parameters +// which refer to the stream, NOT the raw from/to pointers, which point directly +// to the source/destination. Offsets are in units of values, NOT compressed +// bytes within the stream. +#pragma pack(push, 1) +struct NuqStream { + // 4-bit indices are a sweet spot in terms of quality per size. + static constexpr size_t kClusters = 16; + + // Number of weights that share a table. Larger = slower encode, higher error, + // smaller size (table amortized over more weights). + static constexpr size_t kGroupSize = 256; + + // Storage for dynamic programming. There are two matrices; we use separate + // allocations to avoid type punning. + template + class AlignedMatrix { + public: + AlignedMatrix() : mem_(hwy::AllocateAligned(kClusters * kGroupSize)) {} + + HWY_INLINE const T& operator()(size_t row, size_t col) const { + return mem_[row * kGroupSize + col]; + } + + HWY_INLINE T& operator()(size_t row, size_t col) { + return mem_[row * kGroupSize + col]; + } + + private: + hwy::AlignedFreeUniquePtr mem_; + }; + + // Reuse memory across calls to Enc to avoid per-call allocations. + struct ClusterBuf { + // Move-only (stored inside vector in CompressWorkingSet). + ClusterBuf() = default; + ClusterBuf(const ClusterBuf&) = delete; + ClusterBuf& operator=(const ClusterBuf&) = delete; + ClusterBuf(ClusterBuf&&) = default; + ClusterBuf& operator=(ClusterBuf&&) = default; + + void Resize(size_t new_num_groups) { + if (new_num_groups < num_groups) return; + + num_groups = new_num_groups; + centers = hwy::AllocateAligned(num_groups * kClusters); + idx = hwy::AllocateAligned(num_groups * kGroupSize); + } + + // Independent of num_groups. 
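// A worked size example, assuming the `PackedStart`/`PackedEnd` helpers
// declared below: for capacity = 4096 weights there are 4096 / kGroupSize = 16
// groups, so the per-group tables occupy 16 * kClusters = 256 bytes (already a
// multiple of 64) and the packed 4-bit indices occupy 4096 / 2 = 2048 bytes,
// i.e. about 4.5 bits per weight.
//
//   static_assert(NuqStream::PackedStart(4096) == 256, "table bytes");
//   static_assert(NuqStream::PackedEnd(4096) == 2304, "total stream bytes");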
+ AlignedMatrix costs; + AlignedMatrix argmin; + + size_t num_groups = 0; + hwy::AlignedFreeUniquePtr centers; + hwy::AlignedFreeUniquePtr idx; + }; + + // Returns offset of packed indices from the start of the stream. This matches + // the (padded) total table size because table entries are bytes. + static constexpr size_t PackedStart(size_t capacity) { + // Round up to avoid cache-line splits when loading indices. No effect on + // size as long as capacity / kGroupSize is a multiple of 4. + return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64); + } + + // Returns number of NuqStream to allocate for the stream, which matches its + // size in bytes. + static constexpr size_t PackedEnd(size_t capacity) { + return PackedStart(capacity) + hwy::DivCeil(capacity, 2); // 2x 4-bit/byte + } + + uint8_t byte; +}; +#pragma pack(pop) + +template +const char* TypeName() { + using Packed = hwy::RemoveCvRef; + if constexpr (hwy::IsSame()) { + return "f32"; + } else if constexpr (hwy::IsSame()) { + return "b16"; + } else if constexpr (hwy::IsSame()) { + return "sfp"; + } else if constexpr (hwy::IsSame()) { + return "nuq"; + } else { + HWY_DASSERT(false); + return "unknown"; + } +} + +template +constexpr bool IsCompressed() { + return hwy::IsSameEither, SfpStream, NuqStream>(); +} + +// Returns the number of `MatT` elements required to store `capacity` values, +// which must not be zero. +template +constexpr size_t CompressedArrayElements(size_t capacity) { + if constexpr (hwy::IsSame, NuqStream>()) { + return NuqStream::PackedEnd(capacity); + } else { + return capacity; + } +} // Non-owning view of packed elements. Shortens argument lists. // @@ -63,13 +214,19 @@ constexpr float kMaxSFP = 1.875f; // reusing `hwy::Span`. template struct PackedSpan { - void BoundsCheck(size_t packed_ofs, size_t num) const { - HWY_DASSERT(packed_ofs + num <= size); - (void)size; + // Ensures callers can read or write `num_accessible` elements starting at + // `packed_ofs`. + void BoundsCheck(size_t packed_ofs, size_t num_accessible) const { + // For NUQ, there can be fewer Packed than the number of elements, hence + // check the compressed count and ensure we have that many. + const size_t required = + CompressedArrayElements(packed_ofs + num_accessible); + HWY_DASSERT(num >= required); + (void)required; } Packed* HWY_RESTRICT ptr; - size_t size; // for BoundsCheck and nuq-inl.h HWY_ASSERT. + size_t num; // for BoundsCheck and nuq-inl.h HWY_ASSERT. }; // Avoids spelling out the template parameter in every call. @@ -87,7 +244,7 @@ HWY_INLINE PackedSpan MakeConstSpan(Packed* ptr, size_t size) { // `RMSNormInplace` and compression tests. template HWY_INLINE PackedSpan MakeConst(PackedSpan packed) { - return {packed.ptr, packed.size}; + return {packed.ptr, packed.num}; } } // namespace gcpp diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h new file mode 100644 index 0000000..1f591ff --- /dev/null +++ b/compression/test_util-inl.h @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard for headers. +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ + +// IWYU pragma: begin_exports +#include "compression/compress.h" +#include "compression/distortion.h" +// IWYU pragma: end_exports + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ + +// Include guard for (potentially) SIMD code. +#if defined(THIRD_PARTY_GEMMA_CPP_COMPRESS_TEST_UTIL_TOGGLE) == \ + defined(HWY_TARGET_TOGGLE) // NOLINT +#ifdef THIRD_PARTY_GEMMA_CPP_COMPRESS_TEST_UTIL_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_COMPRESS_TEST_UTIL_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_COMPRESS_TEST_UTIL_TOGGLE +#endif + +#include "hwy/highway.h" +// After highway.h +#include "compression/compress-inl.h" +#include "hwy/tests/test_util-inl.h" // IWYU pragma: export + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +// `Packed` is the type passed to `TestT`. +template class TestT> +void ForeachRawType() { + const hn::ForGEVectors<128, TestT> test; + // The argument selects the type to decode to: BF16 or float. + test(BF16()); + test(float()); +} + +template