Base interleaved handling for 4.5-bit NUQ, specifically Enc, DecompressAndZeroPad, and Dec2. Includes tests.

PiperOrigin-RevId: 721821577

Author: Phil Culliton, 2025-01-31 10:34:57 -08:00 (committed by Copybara-Service)
Commit: 8a6edff319 (parent: 23dac72463)
3 changed files with 179 additions and 80 deletions
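The "4.5-bit" figure follows from the per-group layout this change introduces: a small table of cluster centers plus one nibble per element. A minimal arithmetic sketch (not from the commit), assuming the usual NUQ constants kClusters = 16, kGroupSize = 256, and one byte per SFP-encoded center:

#include <cstddef>
#include <cstdio>

// Assumed values; the real constants live in the compression headers.
constexpr size_t kClusters = 16;        // centers per group
constexpr size_t kGroupSize = 256;      // elements per group
constexpr size_t kBytesPerCenter = 1;   // one SFP-encoded center

// Table bytes plus one nibble (4 bits) per element.
constexpr size_t kBytesPerGroup =
    kClusters * kBytesPerCenter + kGroupSize / 2;  // 16 + 128 = 144

int main() {
  // 144 bytes * 8 bits / 256 elements = 4.5 bits per element.
  std::printf("bits/element = %.2f\n", 8.0 * kBytesPerGroup / kGroupSize);
  return 0;
}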

File 1 of 3 (decompression test with TestDecompress2T):

@@ -81,9 +81,13 @@ struct TestDecompress2T {
       }
       if constexpr (false) {
-        fprintf(stderr, "%s %s: %zu: %f %f %f %f\n", TypeName<Packed>(),
-                TypeName<T>(), num, stats.SumL1(), stats.GeomeanValueDivL1(),
-                stats.WeightedAverageL1(), stats.L1().Max());
+        fprintf(stderr,
+                "TypeName<Packed>() %s TypeName<T>() %s: num %zu: stats.SumL1() "
+                "%f stats.GeomeanValueDivL1() %f stats.WeightedAverageL1() %f "
+                "stats.L1().Max() %f\n",
+                TypeName<Packed>(), TypeName<T>(), num, stats.SumL1(),
+                stats.GeomeanValueDivL1(), stats.WeightedAverageL1(),
+                stats.L1().Max());
       }
       constexpr bool kFromFloat = hwy::IsSame<Packed, float>();

File 2 of 3 (the NUQ codec header, class NuqCodec):

@@ -21,6 +21,8 @@
 #include <stdint.h>
 #include <stdio.h>
+#include <cstdio>
+
 #include "compression/shared.h"
 #include "util/basics.h"
 #include "hwy/base.h"
@@ -529,6 +531,14 @@ class NuqCodec {
     return (!HWY_HAVE_SCALABLE && du.MaxBytes() >= 32) ? 1 : 2;
   }
 
+  // Offset (in bytes) of a group's table for packed_ofs (in elements) within a
+  // set of groups.
+  static constexpr size_t TableByteOffset(size_t packed_ofs) {
+    const size_t kBytesPerGroup =
+        (kClusters * sizeof(SfpStream)) + kGroupSize / 2;
+    return (packed_ofs / kGroupSize) * kBytesPerGroup;
+  }
+
   // Unpacks `centers` from SFP into bf16 and loads them into one or two vectors
   // for use by [Two]TableLookups. Returns as u16 because TableLookupLanes might
   // not be available for bf16.
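With the assumed constants kClusters = 16, kGroupSize = 256 and sizeof(SfpStream) == 1, each interleaved group occupies 144 bytes, so TableByteOffset simply advances in fixed 144-byte strides. A standalone sketch of that arithmetic (hypothetical names, not the real header):

#include <cstddef>

// Assumed constants; in the real header they come from the compression code.
constexpr size_t kClusters = 16;
constexpr size_t kGroupSize = 256;
constexpr size_t kSfpBytes = 1;  // stand-in for sizeof(SfpStream)

// Mirrors NuqCodec::TableByteOffset: whole groups before packed_ofs, each
// holding its 16-byte table followed by kGroupSize/2 bytes of nibble indices.
constexpr size_t TableByteOffset(size_t packed_ofs) {
  constexpr size_t kBytesPerGroup = kClusters * kSfpBytes + kGroupSize / 2;  // 144
  return (packed_ofs / kGroupSize) * kBytesPerGroup;
}

int main() {
  static_assert(TableByteOffset(0) == 0);
  static_assert(TableByteOffset(255) == 0);    // still inside group 0
  static_assert(TableByteOffset(256) == 144);  // group 1 starts after 144 bytes
  static_assert(TableByteOffset(3 * 256 + 100) == 3 * 144);
  return 0;
}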
@@ -606,10 +616,11 @@ class NuqCodec {
   }
 
  public:
-  // Encodes `num` floats from `raw`. `packed` points to compressed storage and
-  // `packed_ofs` indicates the destination offset within it, in units of float
-  // values, for parallel encoding by multiple threads. Returns the total
-  // number of unused clusters, which is typically zero.
+  // Encodes `num` floats from `raw` into `packed`. `packed` points to
+  // compressed storage and `packed_ofs` indicates the destination offset within
+  // it, in number of elements. Tables are interleaved with indices (clustered
+  // elements) to allow for easier unpacking. Returns the total number of
+  // unused clusters, which is typically zero.
   template <class DF, HWY_IF_F32_D(DF)>
   static HWY_INLINE size_t Enc(DF df, const float* HWY_RESTRICT raw,
                                const size_t num, NuqStream::ClusterBuf& buf,
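The interleaved stream described by the new comment can be pictured as one fixed-size record per group: the group's SFP-encoded table immediately followed by its nibble-packed indices. A purely illustrative sketch under the assumed constants above (the real stream is not declared as a struct):

#include <cstdint>

// Hypothetical, for illustration only: the shape of each group's bytes.
struct InterleavedGroup {
  uint8_t sfp_centers[16];  // kClusters SFP-encoded cluster centers (table)
  uint8_t nibble_idx[128];  // kGroupSize / 2 bytes; two 4-bit indices per byte
};
static_assert(sizeof(InterleavedGroup) == 144, "matches kBytesPerGroup");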
@@ -622,71 +633,45 @@ class NuqCodec {
     const size_t N16 = hn::Lanes(d16);
 
     HWY_ASSERT(packed_ofs % kGroupSize == 0);
-    const size_t ofs_groups = packed_ofs / kGroupSize;
     const size_t num_groups = hwy::DivCeil(num, kGroupSize);
-    buf.Resize(num_groups);
+    // TODO: dynamic resize should be removed; it is no longer necessary as
+    // interleaved encoding uses only a single buffer of the same size.
+    buf.Resize(1);
 
     size_t unused_clusters = 0;
+    size_t current_offset = packed_ofs;
     for (size_t g = 0; g < num_groups; ++g) {
       const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
       const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
-      float* HWY_RESTRICT g_centers = buf.centers.get() + g * kClusters;
-      uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
+
+      float* HWY_RESTRICT g_centers = buf.centers.get();
+      uint16_t* HWY_RESTRICT g_idx = buf.idx.get();
       unused_clusters +=
           NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);
-    }
 
-    uint8_t* centers = &packed.ptr->byte + ofs_groups * kClusters;
-    SfpCodec::Enc(df, buf.centers.get(), num_groups * kClusters,
-                  reinterpret_cast<SfpStream*>(centers));
-
-    uint8_t* packed_start = &packed.ptr->byte +
-                            NuqStream::PackedStart(packed.num) +
-                            ofs_groups * kGroupSize / 2;
-
-    // All but the last group have no remainders.
-    HWY_DASSERT(kGroupSize % (4 * N16) == 0);
-    HWY_UNROLL(1)
-    for (size_t g = 0; g < num_groups - 1; ++g) {
-      const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
-      uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
+      uint8_t* centers = &packed.ptr->byte + TableByteOffset(current_offset);
+      SfpCodec::Enc(df, buf.centers.get(), kClusters,
+                    reinterpret_cast<SfpStream*>(centers));
 
+      uint8_t* packed_start = centers + kClusters;
+
+      current_offset += g_num;
+
+      size_t i = 0;
       HWY_UNROLL(1)
-      for (size_t i = 0; i < kGroupSize; i += 4 * N16) {
+      for (; i < g_num; i += 4 * N16) {
         const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
         const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
         const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
         const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
         const V8 nibbles =
             NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
-        hn::StoreU(nibbles, d8, g_packed + i / 2);
-      }
-    }
-
-    // Last group may have remainders.
-    {
-      HWY_DASSERT(num_groups != 0);
-      const size_t g = num_groups - 1;
-      const size_t g_num = num - g * kGroupSize;
-      HWY_DASSERT(g_num <= kGroupSize);
-      const uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
-      uint8_t* HWY_RESTRICT g_packed = packed_start + g * kGroupSize / 2;
-
-      size_t i = 0;
-      if (g_num >= 4 * N16) {
-        HWY_UNROLL(1)
-        for (; i <= g_num - 4 * N16; i += 4 * N16) {
-          const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
-          const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
-          const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
-          const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
-          const V8 nibbles =
-              NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
-          hn::StoreU(nibbles, d8, g_packed + i / 2);
-        }
-      }
+        hn::StoreU(nibbles, d8, packed_start + i / 2);
+      }
 
       const size_t remaining = g_num - i;
-      HWY_DASSERT(remaining < 4 * N16);
       if (HWY_UNLIKELY(remaining != 0)) {
         const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
         const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
@@ -695,10 +680,10 @@ class NuqCodec {
         const V8 nibbles =
             NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
         // i is even, but remaining might not be.
-        hn::StoreN(nibbles, d8, g_packed + i / 2, hwy::DivCeil(remaining, 2));
+        hn::StoreN(nibbles, d8, packed_start + i / 2,
+                   hwy::DivCeil(remaining, 2));
       }
     }
     return unused_clusters;
   }
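For readers who do not want to trace the SIMD path: the per-group work above amounts to writing the group's 16 table bytes and then two 4-bit indices per byte. A scalar sketch (not from this commit) under the assumed constants, with a stand-in byte table instead of SfpCodec and an assumed low-nibble-first order:

#include <cstddef>
#include <cstdint>

// Assumed constants for the sketch.
constexpr size_t kClusters = 16;
constexpr size_t kGroupSize = 256;

// `idx` holds one 4-bit cluster index per element; `table` holds the group's
// 16 already-encoded center bytes. Writes the table, then the packed indices.
void EncodeGroupInterleaved(const uint8_t table[kClusters],
                            const uint8_t* idx, size_t g_num,
                            uint8_t* group_out) {
  for (size_t c = 0; c < kClusters; ++c) group_out[c] = table[c];
  uint8_t* packed = group_out + kClusters;
  for (size_t i = 0; i < g_num; i += 2) {
    const uint8_t lo = idx[i] & 0xF;
    const uint8_t hi = (i + 1 < g_num) ? (idx[i + 1] & 0xF) : 0;
    packed[i / 2] = static_cast<uint8_t>(lo | (hi << 4));
  }
}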
@@ -716,11 +701,8 @@ class NuqCodec {
     const size_t within_group = packed_ofs % kGroupSize;
     HWY_DASSERT(within_group % (2 * hn::Lanes(d16)) == 0);
 
-    const size_t ofs_in_groups = packed_ofs / kGroupSize;
-    const uint8_t* table = &packed.ptr->byte + ofs_in_groups * kClusters;
-    const uint8_t* indices =
-        &packed.ptr->byte + NuqStream::PackedStart(packed.num) +
-        hwy::DivCeil(ofs_in_groups * kGroupSize + within_group, 2);
+    const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
 
     V16 tbl1 = Zero(d16);
     const V16 tbl0 = LoadTable(d16, table, &tbl1);
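Both Dec2 overloads now derive the table and index pointers from the same interleaved group record instead of two separate regions of the stream. A scalar sketch of that pointer math (hypothetical helper, assumed constants):

#include <cstddef>
#include <cstdint>
#include <utility>

// Assumed constants: kClusters = 16, kGroupSize = 256, 1 byte per SFP center.
constexpr size_t kClusters = 16;
constexpr size_t kGroupSize = 256;
constexpr size_t kBytesPerGroup = kClusters + kGroupSize / 2;  // 144

// Returns {table, indices}: the group's 16 SFP center bytes and the nibble
// byte holding element `packed_ofs`.
std::pair<const uint8_t*, const uint8_t*> LocateGroup(const uint8_t* stream,
                                                      size_t packed_ofs) {
  const uint8_t* table = stream + (packed_ofs / kGroupSize) * kBytesPerGroup;
  const size_t within_group = packed_ofs % kGroupSize;
  const uint8_t* indices = table + kClusters + within_group / 2;
  return {table, indices};
}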
@@ -747,11 +729,8 @@ class NuqCodec {
     const size_t within_group = packed_ofs % kGroupSize;
     HWY_DASSERT(within_group % (2 * hn::Lanes(df)) == 0);
 
-    const size_t ofs_groups = packed_ofs / kGroupSize;
-    const uint8_t* table = &packed.ptr->byte + ofs_groups * kClusters;
-    const uint8_t* indices =
-        &packed.ptr->byte + NuqStream::PackedStart(packed.num) +
-        hwy::DivCeil(ofs_groups * kGroupSize + within_group, 2);
+    const uint8_t* table = &packed.ptr->byte + TableByteOffset(packed_ofs);
+    const uint8_t* indices = table + kClusters + hwy::DivCeil(within_group, 2);
 
     V16 tbl1 = Zero(d16);
     const V16 tbl0 = LoadTable(d16, table, &tbl1);
@@ -760,51 +739,53 @@ class NuqCodec {
     // which expects a quarter vector of bytes.
     const V8Q nibbles = hn::LoadU(d8q, indices);
 
+    // TODO(janwas): From janwas: on AVX-512 I imagine we can get a
+    // bit more speed for this function by changing LoadTable to return floats,
+    // then we could have a single lookup here instead of PromoteUpperTo which
+    // is not cheap.
     const V16 c0 = TableLookups(d16, tbl0, tbl1, nibbles);
     raw0 = hn::PromoteLowerTo(df, BitCast(dbf, c0));
     raw1 = hn::PromoteUpperTo(df, BitCast(dbf, c0));
   }
 
+  // Decompresses from `packed`, starting at (any) `packed_ofs`, to (any) `num`
+  // elements in `raw`, then appends `[0, hn::Lanes(d))` zeroes as required to
+  // round `num` up to one vector, if it is not already.
   template <class D, typename Raw = hn::TFromD<D>>
   static HWY_INLINE void DecompressAndZeroPad(
       D d, const PackedSpan<const NuqStream>& packed, size_t packed_ofs,
       Raw* HWY_RESTRICT raw, size_t num) {
     // If unaligned, load elements from the first group and update the args,
     // from which we compute new tables/indices below.
+    size_t current_offset = packed_ofs;
     if (size_t within_group = packed_ofs % kGroupSize; within_group != 0) {
-      const size_t ofs_in_groups = packed_ofs / kGroupSize;
-      const uint8_t* tables = &packed.ptr->byte + ofs_in_groups * kClusters;
-      const uint8_t* indices =
-          &packed.ptr->byte + NuqStream::PackedStart(packed.num) +
-          hwy::DivCeil(ofs_in_groups * kGroupSize + within_group, 2);
+      const uint8_t* tables =
+          &packed.ptr->byte + TableByteOffset(current_offset);
+      const uint8_t* indices = tables + kClusters + within_group / 2;
       const size_t remaining = HWY_MIN(num, kGroupSize - within_group);
       DecPartialGroup(d, tables, indices, raw, remaining);
       packed_ofs += remaining;
+      current_offset += remaining;
       raw += remaining;
       num -= remaining;
       if (num == 0) return;
     }
 
     HWY_DASSERT(packed_ofs % kGroupSize == 0);
 
-    const size_t ofs_in_groups = packed_ofs / kGroupSize;
-    const uint8_t* tables = &packed.ptr->byte + ofs_in_groups * kClusters;
-    const uint8_t* indices = &packed.ptr->byte +
-                             NuqStream::PackedStart(packed.num) +
-                             hwy::DivCeil(ofs_in_groups * kGroupSize, 2);
     const size_t num_groups = hwy::DivCeil(num, kGroupSize);
     HWY_UNROLL(1)
     for (size_t g = 0; g < num_groups - 1; ++g) {
-      DecWholeGroup(d, tables + g * kClusters, indices + g * kGroupSize / 2,
-                    raw + g * kGroupSize);
+      const uint8_t* tables =
+          &packed.ptr->byte + TableByteOffset(current_offset);
+      const uint8_t* indices = tables + kClusters;
+      DecWholeGroup(d, tables, indices, raw + g * kGroupSize);
+      current_offset += kGroupSize;
     }
 
     const size_t g = num_groups - 1;
-    DecPartialGroup(d, tables + g * kClusters, indices + g * kGroupSize / 2,
-                    raw + g * kGroupSize, num - g * kGroupSize);
+    const uint8_t* tables = &packed.ptr->byte + TableByteOffset(current_offset);
+    const uint8_t* indices = tables + kClusters;
+    DecPartialGroup(d, tables, indices, raw + g * kGroupSize,
+                    num - g * kGroupSize);
   }
 
  private:
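The group walk above (partial leading group, whole groups, then a partial trailing group plus zero padding) can be summarized in scalar form. A sketch not from this commit, with assumed constants, per-group tables passed in already decoded to float, and an assumed low-nibble-first order:

#include <cstddef>
#include <cstdint>

// Assumed constants for the sketch.
constexpr size_t kClusters = 16;
constexpr size_t kGroupSize = 256;
constexpr size_t kBytesPerGroup = kClusters + kGroupSize / 2;

// `tables` holds kClusters decoded floats per group, concatenated. Decodes
// `num` elements starting at `packed_ofs`, then zero-pads up to `pad_to`
// (one SIMD vector in the real implementation).
void DecompressAndZeroPadScalar(const uint8_t* stream, const float* tables,
                                size_t packed_ofs, float* raw, size_t num,
                                size_t pad_to) {
  for (size_t i = 0; i < num; ++i) {
    const size_t ofs = packed_ofs + i;
    const size_t group = ofs / kGroupSize;
    const size_t within = ofs % kGroupSize;
    const uint8_t* indices = stream + group * kBytesPerGroup + kClusters;
    const uint8_t byte = indices[within / 2];
    const uint8_t idx = (within % 2 == 0) ? (byte & 0xF) : (byte >> 4);
    raw[i] = tables[group * kClusters + idx];
  }
  const size_t padded = (num + pad_to - 1) / pad_to * pad_to;
  for (size_t i = num; i < padded; ++i) raw[i] = 0.0f;
}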
@@ -955,6 +936,7 @@ class NuqCodec {
     }
 
     const size_t remaining = num - i;
+    HWY_DASSERT(remaining < 4 * NF);
     if (HWY_UNLIKELY(remaining != 0)) {
       // i is even, but remaining might not be.

File 3 of 3 (the NUQ tests):

@@ -277,6 +277,115 @@ struct TestOffset {
 
 void TestOffsetBF16() { hn::ForGEVectors<128, TestOffset>()(BF16()); }
 void TestOffsetF32() { hn::ForGEVectors<128, TestOffset>()(float()); }
 
+// Can encode and decode sub-regions. Tests unaligned offsets - i.e. offsets
+// within groups / that are not a multiple of the group size.
+struct TestUnalignedOffset {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    const size_t total = 10 * kGroupSize;  // already padded
+    constexpr size_t kNumUnalignedOffsets = 4;
+    const std::array<size_t, kNumUnalignedOffsets> unaligned_offsets = {
+        4, kGroupSize + 100, 2 * kGroupSize + 100, 3 * kGroupSize + 100};
+    const std::array<size_t, kNumUnalignedOffsets> num = {4, 16, 32, 64};
+
+    for (int i = 0; i < kNumUnalignedOffsets; ++i) {
+      const size_t unaligned_offset = unaligned_offsets[i];
+      const size_t num_decompressed = num[i];
+
+      auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+      auto dec1 = hwy::AllocateAligned<T>(total);
+      auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+      auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
+      HWY_ASSERT(in && dec1 && dec2 && nuq);
+      const auto nuq_span = MakeSpan(nuq.get(), total);
+
+      hwy::RandomState rng;
+      for (size_t i = 0; i < total; ++i) {
+        in[i] = static_cast<float>(RandomGaussian(rng));
+      }
+
+      // Encode + decode everything
+      NuqStream::ClusterBuf buf;
+      (void)NuqCodec::Enc(df, in.get(), total, buf, nuq_span, 0);
+      NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec1.get(),
+                                     total);
+      NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), unaligned_offset,
+                                     dec2.get(), num_decompressed);
+
+      for (size_t i = 0; i < num_decompressed; ++i) {
+        T expected = hwy::ConvertScalarTo<T>(dec1[unaligned_offset + i]);
+        T actual = hwy::ConvertScalarTo<T>(dec2[i]);
+        HWY_ASSERT_EQ(expected, actual);
+      }
+    }
+  }
+};
+
+void TestUnalignedOffsetBF16() {
+  hn::ForGEVectors<128, TestUnalignedOffset>()(BF16());
+}
+void TestUnalignedOffsetF32() {
+  hn::ForGEVectors<128, TestUnalignedOffset>()(float());
+}
+
+// Can encode and decode sub-regions.
+// Uses Dec2 to decode all elements in the packed buffer, then
+// compares against DecompressAndZeroPad.
+struct TestDec2 {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    const hn::Repartition<float, D> df;
+    // incl. partial group to test partial group handling
+    const size_t total = 2 * kGroupSize + (kGroupSize / 2);
+    const size_t kMidLen = 2 * kGroupSize;  // length of middle piece
+
+    auto in = hwy::AllocateAligned<float>(total);  // Enc() requires f32
+    auto dec0 = hwy::AllocateAligned<T>(total);
+    auto dec1 = hwy::AllocateAligned<T>(total);
+    auto dec2 = hwy::AllocateAligned<T>(kMidLen);
+    auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
+    HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq);
+    const auto nuq_span = MakeSpan(nuq.get(), total);
+
+    hwy::RandomState rng;
+    for (size_t i = 0; i < total; ++i) {
+      in[i] = static_cast<float>(RandomGaussian(rng));
+    }
+
+    // Non-interleaved encode + decode for comparison
+    NuqStream::ClusterBuf buf0;
+    (void)NuqCodec::Enc(df, in.get(), total, buf0, nuq_span, 0);
+    NuqCodec::DecompressAndZeroPad(d, MakeConst(nuq_span), 0, dec0.get(),
+                                   total);
+
+    // Encode + decode everything
+    NuqStream::ClusterBuf buf;
+    (void)NuqCodec::Enc(df, in.get(), total, buf, nuq_span, 0);
+
+    using V = hn::Vec<decltype(d)>;
+    const size_t N = Lanes(d);
+    for (size_t i = 0; i < total; i += 2 * N) {
+      V f0, f1;
+      NuqCodec::Dec2(d, MakeConst(nuq_span), i, f0, f1);
+      hn::StoreU(f0, d, dec1.get() + i + 0 * N);
+      hn::StoreU(f1, d, dec1.get() + i + 1 * N);
+    }
+
+    for (size_t i = 0; i < total; ++i) {
+      HWY_ASSERT(dec0[i] == dec1[i]);
+    }
+  }
+};
+
+void TestDec2BF16() { hn::ForGEVectors<128, TestDec2>()(BF16()); }
+void TestDec2F32() { hn::ForGEVectors<128, TestDec2>()(float()); }
+
 struct TestNibble {
   template <typename T, class D>
   HWY_INLINE void operator()(T /*unused*/, D d) {
@@ -409,6 +518,10 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestOffsetF32);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestDec2BF16);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestDec2F32);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetBF16);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestUnalignedOffsetF32);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestEncDecF32);