Fix nuq Enc() to handle groups < kGroupSize.

Also remove no longer required dynamic allocation.

PiperOrigin-RevId: 725203824
This commit is contained in:
Jan Wassenberg 2025-02-10 07:17:10 -08:00 committed by Copybara-Service
parent 5563d94811
commit 953c877658
2 changed files with 18 additions and 28 deletions

View File

@ -632,11 +632,12 @@ class NuqCodec {
const size_t N16 = hn::Lanes(d16);
HWY_ASSERT(packed_ofs % kGroupSize == 0);
HWY_ASSERT(kGroupSize % (4 * N16) == 0);
const size_t num_groups = hwy::DivCeil(num, kGroupSize);
// TODO: dynamic resize should be removed; it is no longer necessary as
// interleaved encoding uses only a single buffer of the same size.
buf.Resize(1);
HWY_ALIGN float g_centers[kClusters];
// Zero-initialize in case of remainders (g_num != kGroupSize).
HWY_ALIGN uint16_t g_idx[kGroupSize] = {};
size_t unused_clusters = 0;
size_t current_offset = packed_ofs;
@ -644,34 +645,35 @@ class NuqCodec {
const size_t g_num = HWY_MIN(num - g * kGroupSize, kGroupSize);
const float* HWY_RESTRICT g_in = raw + g * kGroupSize;
float* HWY_RESTRICT g_centers = buf.centers.get();
uint16_t* HWY_RESTRICT g_idx = buf.idx.get();
unused_clusters +=
NuqClustering::ClusterExactL2(df, g_in, g_num, buf, g_centers, g_idx);
uint8_t* centers = &packed.ptr->byte + TableByteOffset(current_offset);
SfpCodec::Enc(df, buf.centers.get(), kClusters,
SfpCodec::Enc(df, g_centers, kClusters,
reinterpret_cast<SfpStream*>(centers));
uint8_t* packed_start = centers + kClusters;
current_offset += g_num;
size_t i = 0;
HWY_UNROLL(1)
for (; i < g_num; i += 4 * N16) {
const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
const V8 nibbles =
NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
hn::StoreU(nibbles, d8, packed_start + i / 2);
if (g_num >= 4 * N16) {
HWY_UNROLL(1)
for (; i <= g_num - 4 * N16; i += 4 * N16) {
const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);
const V16 idx3 = hn::LoadU(d16, g_idx + i + 3 * N16);
const V8 nibbles =
NibbleCodec::OrderedPackU16(d16, idx0, idx1, idx2, idx3);
hn::StoreU(nibbles, d8, packed_start + i / 2);
}
}
const size_t remaining = g_num - i;
if (HWY_UNLIKELY(remaining != 0)) {
// Safe to load all 4 vectors: g_idx zero-initialized and its size is
// a multiple of 4 vectors.
const V16 idx0 = hn::LoadU(d16, g_idx + i + 0 * N16);
const V16 idx1 = hn::LoadU(d16, g_idx + i + 1 * N16);
const V16 idx2 = hn::LoadU(d16, g_idx + i + 2 * N16);

View File

@ -140,21 +140,9 @@ struct NuqStream {
ClusterBuf(ClusterBuf&&) = default;
ClusterBuf& operator=(ClusterBuf&&) = default;
void Resize(size_t new_num_groups) {
if (new_num_groups < num_groups) return;
num_groups = new_num_groups;
centers = hwy::AllocateAligned<float>(num_groups * kClusters);
idx = hwy::AllocateAligned<uint16_t>(num_groups * kGroupSize);
}
// Independent of num_groups.
AlignedMatrix<float> costs;
AlignedMatrix<int32_t> argmin;
size_t num_groups = 0;
hwy::AlignedFreeUniquePtr<float[]> centers;
hwy::AlignedFreeUniquePtr<uint16_t[]> idx;
};
// Returns offset of packed indices from the start of the stream. This matches