diff --git a/compression/compress_test.cc b/compression/compress_test.cc
index 52883d4..0beba1a 100644
--- a/compression/compress_test.cc
+++ b/compression/compress_test.cc
@@ -81,9 +81,13 @@ struct TestDecompress2T {
     }

     if constexpr (false) {
-      fprintf(stderr, "%s %s: %zu: %f %f %f %f\n", TypeName(),
-              TypeName(), num, stats.SumL1(), stats.GeomeanValueDivL1(),
-              stats.WeightedAverageL1(), stats.L1().Max());
+      fprintf(stderr,
+              "TypeName() %s TypeName() %s: num %zu: stats.SumL1() "
+              "%f stats.GeomeanValueDivL1() %f stats.WeightedAverageL1() %f "
+              "stats.L1().Max() %f\n",
+              TypeName(), TypeName(), num, stats.SumL1(),
+              stats.GeomeanValueDivL1(), stats.WeightedAverageL1(),
+              stats.L1().Max());
     }

     constexpr bool kFromFloat = hwy::IsSame();
diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h
index 63c4255..d12a630 100644
--- a/compression/nuq-inl.h
+++ b/compression/nuq-inl.h
@@ -21,6 +21,8 @@
 #include
 #include

+#include
+
 #include "compression/shared.h"
 #include "util/basics.h"
 #include "hwy/base.h"
@@ -529,6 +531,14 @@ class NuqCodec {
     return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
   }

+  // Byte offset of the table for the group containing element `packed_ofs`
+  // (in units of elements) within a set of groups.
+  static constexpr size_t TableByteOffset(size_t packed_ofs) {
+    const size_t kBytesPerGroup =
+        (kClusters * sizeof(SfpStream)) + kGroupSize / 2;
+    return (packed_ofs / kGroupSize) * kBytesPerGroup;
+  }
+
   // Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
   // for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
   // not be available for bf16.
@@ -606,10 +616,11 @@ class NuqCodec {
   }

  public:
-  // Encodes `num` floats from `raw`. `packed` points to compressed storage and
-  // `packed_ofs` indicates the destination offset within it, in units of float
-  // values, for parallel encoding by multiple threads. Returns the total
-  // number of unused clusters, which is typically zero.
+  // Encodes `num` floats from `raw` into `packed`. `packed` points to
+  // compressed storage and `packed_ofs` indicates the destination offset within
+  // it, in number of elements. Tables are interleaved with indices (clustered
+  // elements) to allow for easier unpacking. Returns the total number of
+  // unused clusters, which is typically zero.
   template <class DF, HWY_IF_F32_D(DF)>
   static HWY_INLINE size_t Enc(DF df, const float* HWY_RESTRICT raw,
                                const size_t num, NuqStream::ClusterBuf& buf,
@@ -622,71 +633,45 @@ class NuqCodec {
     const size_t N16 = hn::Lanes(d16);

     HWY_ASSERT(packed_ofs % kGroupSize == 0);
-    const size_t ofs_groups = packed_ofs / kGroupSize;
+
     const size_t num_groups = hwy::DivCeil(num, kGroupSize);
-    buf.Resize(num_groups);
+    // TODO: remove the dynamic resize; it is no longer necessary because
+    // interleaved encoding uses only a single buffer of the same size.
+    buf.Resize(1);

     size_t unused_clusters = 0;
+    size_t current_offset = packed_ofs;
     for (size_t g = 0; g < num_groups; ++g) {
       const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
       const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
-      float* HWY_RESTRICT g_centers = buf.centers.get() + g * kClusters;
-      uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
+
+      float* HWY_RESTRICT g_centers = buf.centers.get();
+      uint16_t* HWY_RESTRICT g_idx = buf.idx.get();
+
       unused_clusters +=
           NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);
-    }

-    uint8_t* centers = &packed.ptr->byte + ofs_groups * kClusters;
-    SfpCodec::Enc(df, buf.centers.get(), num_groups * kClusters,
-                  reinterpret_cast<SfpStream*>(centers));
+      uint8_t* centers = &packed.ptr->byte + TableByteOffset(current_offset);
+      SfpCodec::Enc(df, buf.centers.get(), kClusters,
+                    reinterpret_cast<SfpStream*>(centers));

-    uint8_t* packed_start = &packed.ptr->byte +
-                            NuqStream::PackedStart(packed.num) +
-                            ofs_groups * kGroupSize / 2;
+      uint8_t* packed_start = centers + kClusters;

-    // All but the last group have no remainders.
-    HWY_DASSERT(kGroupSize % (4 * N16) == 0);
-    HWY_UNROLL(1)
-    for (size_t g = 0; g < num_groups - 1; ++g) {
-      const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
-      uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
+      current_offset += g_num;
+      size_t i = 0;
       HWY_UNROLL(1)
-      for (size_t i = 0; i < kGroupSize; i += 4 * N16) {
+      for (; i < g_num; i += 4 * N16) {
         const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
         const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
         const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
         const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
         const V8 nibbles =
             NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
-        hn::StoreU(nibbles, d8, g_packed + i / 2);
-      }
-    }
-
-    // Last group may have remainders.
-    {
-      HWY_DASSERT(num_groups != 0);
-      const size_t g = num_groups - 1;
-      const size_t g_num = num - g * kGroupSize;
-      HWY_DASSERT(g_num <= kGroupSize);
-      const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
-      uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
-
-      size_t i = 0;
-      if (g_num >= 4 * N16) {
-        HWY_UNROLL(1)
-        for (; i <= g_num - 4 * N16; i += 4 * N16) {
-          const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
-          const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
-          const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
-          const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
-          const V8 nibbles =
-              NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
-          hn::StoreU(nibbles, d8, g_packed + i / 2);
-        }
+        hn::StoreU(nibbles, d8, packed_start + i / 2);
       }

       const size_t remaining = g_num - i;
-      HWY_DASSERT(remaining < 4 * N16);
+
       if (HWY_UNLIKELY(remaining != 0)) {
         const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
         const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
@@ -695,10 +680,10 @@ class NuqCodec {
         const V8 nibbles =
             NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
         // i is even, but remaining might not be.
-        hn::StoreN(nibbles, d8, g_packed + i / 2, hwy::DivCeil(remaining, 2));
+        hn::StoreN(nibbles, d8, packed_start + i / 2,
+                   hwy::DivCeil(remaining, 2));
       }
     }
-
     return unused_clusters;
   }

@@ -716,11 +701,8 @@ class NuqCodec {
     const size_t within_group = packed_ofs % kGroupSize;
     HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0);

-    const size_t ofs_in_groups = packed_ofs / kGroupSize;
-    const uint8_t* table = &packed.ptr->byte + ofs_in_groups * kClusters;
-    const uint8_t* indices =
-        &packed.ptr->byte + NuqStream::PackedStart(packed.num) +
-        hwy::DivCeil(ofs_in_groups * kGroupSize + within_group, 2);
+    const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);

     V16 tbl1 = Zero(d16);
     const V16 tbl0 = LoadTable(d16, table, &tbl1);
@@ -747,11 +729,8 @@ class NuqCodec {
     const size_t within_group = packed_ofs % kGroupSize;
     HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);

-    const size_t ofs_groups = packed_ofs / kGroupSize;
-    const uint8_t* table = &packed.ptr->byte + ofs_groups * kClusters;
-    const uint8_t* indices =
-        &packed.ptr->byte + NuqStream::PackedStart(packed.num) +
-        hwy::DivCeil(ofs_groups * kGroupSize + within_group, 2);
+    const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);

     V16 tbl1 = Zero(d16);
     const V16 tbl0 = LoadTable(d16, table, &tbl1);
@@ -760,51 +739,53 @@ class NuqCodec {
     // which expects a quarter vector of bytes.
     const V8Q nibbles = hn::LoadU(d8q, indices);

+    // TODO(janwas): On AVX-512 we can likely get a bit more speed for this
+    // function by changing LoadTable to return floats; then we could have a
+    // single lookup here instead of PromoteUpperTo, which is not cheap.
     const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles);
     raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
     raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
   }

-  // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
-  // elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
-  // round `num` up to one vector, if it is not already.
   template <class D, typename Raw = hn::TFromD<D>>
   static HWY_INLINE void DecompressAndZeroPad(
       D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs,
       Raw* HWY_RESTRICT raw, size_t num) {
     // If unaligned, load elements from the first group and update the args,
     // from which we compute new tables/indices below.
+    size_t current_offset = packed_ofs;
     if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) {
-      const size_t ofs_in_groups = packed_ofs / kGroupSize;
-      const uint8_t* tables = &packed.ptr->byte + ofs_in_groups * kClusters;
-      const uint8_t* indices =
-          &packed.ptr->byte + NuqStream::PackedStart(packed.num) +
-          hwy::DivCeil(ofs_in_groups * kGroupSize + within_group, 2);
+      const uint8_t* tables =
+          &packed.ptr->byte + TableByteOffset(current_offset);
+      const uint8_t* indices = tables + kClusters + within_group / 2;
       const size_t remaining = HWY_MIN(num, kGroupSize - within_group);
+
       DecPartialGroup(d, tables, indices, raw, remaining);
       packed_ofs += remaining;
+      current_offset += remaining;
       raw += remaining;
       num -= remaining;
       if (num == 0) return;
     }

     HWY_DASSERT(packed_ofs % kGroupSize == 0);

-    const size_t ofs_in_groups = packed_ofs / kGroupSize;
-    const uint8_t* tables = &packed.ptr->byte + ofs_in_groups * kClusters;
-    const uint8_t* indices = &packed.ptr->byte +
-                             NuqStream::PackedStart(packed.num) +
-                             hwy::DivCeil(ofs_in_groups * kGroupSize, 2);
     const size_t num_groups = hwy::DivCeil(num, kGroupSize);
     HWY_UNROLL(1)
     for (size_t g = 0; g < num_groups - 1; ++g) {
-      DecWholeGroup(d, tables + g * kClusters, indices + g * kGroupSize / 2,
-                    raw + g * kGroupSize);
+      const uint8_t* tables =
+          &packed.ptr->byte + TableByteOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      DecWholeGroup(d, tables, indices, raw + g * kGroupSize);
+      current_offset += kGroupSize;
     }

     const size_t g = num_groups - 1;
-    DecPartialGroup(d, tables + g * kClusters, indices + g * kGroupSize / 2,
-                    raw + g * kGroupSize, num - g * kGroupSize);
+    const uint8_t* tables = &packed.ptr->byte + TableByteOffset(current_offset);
+    const uint8_t* indices = tables + kClusters;
+    DecPartialGroup(d, tables, indices, raw + g * kGroupSize,
+                    num - g * kGroupSize);
   }

  private:
@@ -955,6 +936,7 @@ class NuqCodec {
     }

     const size_t remaining = num - i;
+    HWY_DASSERT(remaining < 4 * NF);

     if (HWY_UNLIKELY(remaining != 0)) {
       // i is even, but remaining might not be.
diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc
index 8cbce6c..6dd5982 100644
--- a/compression/nuq_test.cc
+++ b/compression/nuq_test.cc
@@ -277,6 +277,115 @@ struct TestOffset {

 void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); }
 void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); }

+// Can encode and decode sub-regions. Tests unaligned offsets, i.e. offsets
+// within a group that are not a multiple of the group size.
+struct TestUnalignedOffset {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    const size_t total = 10 * kGroupSize;  // already padded
+
+    constexpr size_t kNumUnalignedOffsets = 4;
+    const std::array<size_t, kNumUnalignedOffsets> unaligned_offsets = {
+        4, kGroupSize + 100, 2 * kGroupSize + 100, 3 * kGroupSize + 100};
+    const std::array<size_t, kNumUnalignedOffsets> num = {4, 16, 32, 64};
+
+    for (size_t i = 0; i < kNumUnalignedOffsets; ++i) {
+      const size_t unaligned_offset = unaligned_offsets[i];
+      const size_t num_decompressed = num[i];
+
+      auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+      auto dec1 = hwy::AllocateAligned<T>(total);
+      auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+      auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
+      HWY_ASSERT(in && dec1 && dec2 && nuq);
+      const auto nuq_span = MakeSpan(nuq.get(), total);
+
+      hwy::RandomState rng;
+      for (size_t j = 0; j < total; ++j) {
+        in[j] = static_cast<float>(RandomGaussian(rng));
+      }
+
+      // Encode + decode everything.
+      NuqStream::ClusterBuf buf;
+      (void)NuqCodec::Enc(df, in.get(), total, buf, nuq_span, 0);
+      NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec1.get(),
+                                     total);
+
+      // Decode only the sub-region starting at the unaligned offset.
+      NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), unaligned_offset,
+                                     dec2.get(), num_decompressed);
+
+      for (size_t j = 0; j < num_decompressed; ++j) {
+        const T expected = hwy::ConvertScalarTo<T>(dec1[unaligned_offset + j]);
+        const T actual = hwy::ConvertScalarTo<T>(dec2[j]);
+        HWY_ASSERT_EQ(expected, actual);
+      }
+    }
+  }
+};
+
+void TestUnalignedOffsetBF16() {
+  hn::ForGEVectors<128, TestUnalignedOffset>()(BF16());
+}
+void TestUnalignedOffsetF32() {
+  hn::ForGEVectors<128, TestUnalignedOffset>()(float());
+}
+
+// Uses Dec2 to decode all elements in the packed buffer, then compares
+// against DecompressAndZeroPad.
+struct TestDec2 {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    // Includes a partial final group to test partial-group handling.
+    const size_t total = 2 * kGroupSize + (kGroupSize / 2);
+
+    auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+    auto dec0 = hwy::AllocateAligned<T>(total);
+    auto dec1 = hwy::AllocateAligned<T>(total);
+    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+    HWY_ASSERT(in && dec0 && dec1 && nuq);
+    const auto nuq_span = MakeSpan(nuq.get(), total);
+
+    hwy::RandomState rng;
+    for (size_t i = 0; i < total; ++i) {
+      in[i] = static_cast<float>(RandomGaussian(rng));
+    }
+
+    // Encode once, then decode with DecompressAndZeroPad for reference.
+    NuqStream::ClusterBuf buf;
+    (void)NuqCodec::Enc(df, in.get(), total, buf, nuq_span, 0);
+    NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec0.get(),
+                                   total);
+
+    // Decode the same buffer two vectors at a time via Dec2.
+    using V = hn::Vec<D>;
+    const size_t N = hn::Lanes(d);
+
+    for (size_t i = 0; i < total; i += 2 * N) {
+      V f0, f1;
+      NuqCodec::Dec2(d, MakeConst(nuq_span), i, f0, f1);
+      hn::StoreU(f0, d, dec1.get() + i + 0 * N);
+      hn::StoreU(f1, d, dec1.get() + i + 1 * N);
+    }
+
+    for (size_t i = 0; i < total; ++i) {
+      HWY_ASSERT(dec0[i] == dec1[i]);
+    }
+  }
+};
+
+void TestDec2BF16() { hn::ForGEVectors<128, TestDec2>()(BF16()); }
+void TestDec2F32() { hn::ForGEVectors<128, TestDec2>()(float()); }
+
 struct TestNibble {
   template <typename T, class D>
   HWY_INLINE void operator()(T /*unused*/, D d) {
@@ -409,6 +518,10 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetF32);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestDec2BF16);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestDec2F32);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetBF16);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);
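---
Reviewer notes: the sketches below illustrate the patch; they are not part of it.

The interleaved layout stores, per group, the kClusters-entry SFP table
immediately followed by that group's kGroupSize / 2 packed index bytes, so
TableByteOffset only needs the group number. A minimal standalone sketch of
the arithmetic, assuming kClusters = 16 (consistent with the 4-bit nibble
packing here), kGroupSize = 256, and sizeof(SfpStream) == 1; all three values
are assumptions for illustration:

#include <cstddef>
#include <cstdio>

constexpr size_t kClusters = 16;    // assumption: 4-bit cluster indices
constexpr size_t kGroupSize = 256;  // assumption
constexpr size_t kSfpBytes = 1;     // assumption: sizeof(SfpStream) == 1

// Mirrors NuqCodec::TableByteOffset: each group is kClusters * kSfpBytes
// table bytes plus kGroupSize / 2 index bytes, i.e. 16 + 128 = 144 bytes.
constexpr size_t TableByteOffset(size_t packed_ofs) {
  const size_t kBytesPerGroup = kClusters * kSfpBytes + kGroupSize / 2;
  return (packed_ofs / kGroupSize) * kBytesPerGroup;
}

int main() {
  for (size_t ofs : {size_t{0}, size_t{256}, size_t{300}, size_t{1024}}) {
    // The group's indices start right after its table, as in Enc/Dec2.
    const size_t table = TableByteOffset(ofs);
    const size_t index_byte = table + kClusters + (ofs % kGroupSize) / 2;
    std::printf("ofs %4zu -> table @ %4zu, index byte @ %4zu\n", ofs, table,
                index_byte);
  }
  return 0;
}

For example, element 300 lives in group 1, whose table starts at byte 144;
its index byte is 144 + 16 + 22 = 182.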
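The encoder emits two 4-bit cluster indices per byte, which is why a group's
indices occupy kGroupSize / 2 bytes and the remainder store uses
hwy::DivCeil(remaining, 2). A scalar reference of that packing (the
low-nibble-first order is an assumption for illustration, not taken from
NibbleCodec):

#include <cstddef>
#include <cstdint>
#include <vector>

// Packs `num` 4-bit indices (each < 16) into DivCeil(num, 2) bytes.
std::vector<uint8_t> PackNibbles(const uint16_t* idx, size_t num) {
  std::vector<uint8_t> packed((num + 1) / 2, 0);
  for (size_t i = 0; i < num; ++i) {
    const uint8_t nibble = static_cast<uint8_t>(idx[i] & 0xF);
    // Even elements into the low nibble, odd elements into the high nibble.
    packed[i / 2] |= (i % 2 == 0) ? nibble : static_cast<uint8_t>(nibble << 4);
  }
  return packed;
}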
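DecompressAndZeroPad now walks the stream group by group: a partial head group
when packed_ofs is unaligned, whole groups in the middle, and a partial tail.
Because tables are interleaved, each step advances by one group's byte size
instead of consulting NuqStream::PackedStart. A scalar sketch of that
traversal, simplified by assuming the table entries are already-decoded floats
rather than SFP bytes, so only the offset walking and nibble extraction mirror
the patch:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>

constexpr size_t kClusters = 16;                           // assumption
constexpr size_t kGroupSize = 256;                         // assumption
constexpr size_t kTableBytes = kClusters * sizeof(float);  // simplified
constexpr size_t kBytesPerGroup = kTableBytes + kGroupSize / 2;

void DecompressScalar(const uint8_t* stream, size_t packed_ofs, float* raw,
                      size_t num) {
  size_t within = packed_ofs % kGroupSize;
  const uint8_t* group = stream + (packed_ofs / kGroupSize) * kBytesPerGroup;
  while (num != 0) {
    float table[kClusters];  // stands in for LoadTable + SFP decode
    std::memcpy(table, group, kTableBytes);
    const uint8_t* indices = group + kTableBytes;
    const size_t n = std::min(num, kGroupSize - within);  // head/tail partial
    for (size_t i = 0; i < n; ++i) {
      const size_t e = within + i;  // element index within this group
      const uint8_t byte = indices[e / 2];
      raw[i] = table[(e % 2 == 0) ? (byte & 0xF) : (byte >> 4)];
    }
    raw += n;
    num -= n;
    within = 0;               // all later groups start aligned
    group += kBytesPerGroup;  // interleaving: next table follows these indices
  }
}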