diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h
index c03ce30..14e3849 100644
--- a/compression/nuq-inl.h
+++ b/compression/nuq-inl.h
@@ -35,9 +35,10 @@
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
 #endif
 
+#include "hwy/highway.h"
+// After highway.h
 #include "compression/sfp-inl.h"
 #include "hwy/contrib/sort/vqsort-inl.h"
-#include "hwy/highway.h"
 
 #ifndef HWY_IF_CONSTEXPR
 #define HWY_IF_CONSTEXPR if
@@ -77,92 +78,156 @@ class NuqClustering {
   // Cumulative sums for O(1) mean and interval sums.
   class ClusterCost {
+    // Ensures it is safe to load a vector from the last element.
+    static constexpr size_t kMaxLanes = hn::MaxLanes(hn::ScalableTag<float>());
+
+    // Initialization value for table elements where `valid` is false.
+    static constexpr float kSentinel = -1.0f;
+
    public:
-    explicit ClusterCost(const float* sorted) {
+    explicit ClusterCost(const float* HWY_RESTRICT sorted) {
+      double cumsum = 0.0;
+      double cumsum2 = 0.0;
       cumsum_[0] = cumsum2_[0] = 0.0;
       for (size_t i = 0; i < kGroupSize; ++i) {
         const float x = FloatPayload::Clear(sorted[i]);
-        cumsum_[1 + i] = x + cumsum_[i];
-        cumsum2_[1 + i] = x * x + cumsum2_[i];
+        cumsum += x;
+        cumsum2 += static_cast<double>(x) * x;
+        cumsum_[1 + i] = static_cast<float>(cumsum);
+        cumsum2_[1 + i] = static_cast<float>(cumsum2);
       }
-      inv_len_[0] = 0.0f;  // unused
-      for (size_t i = 1; i <= kGroupSize; ++i) {
-        inv_len_[i] = 1.0f / static_cast<float>(i);
+      const hn::ScalableTag<float> df;
+      using VF = hn::Vec<decltype(df)>;
+      const VF k1 = hn::Set(df, 1.0f);
+      const size_t N = hn::Lanes(df);
+      HWY_DASSERT(kGroupSize % N == 0);
+
+      // Precomputed lengths and their reciprocals.
+      for (size_t len = 0; len < kGroupSize; len += N) {
+        const VF vlen = hn::Iota(df, static_cast<float>(len));
+        hn::StoreU(vlen, df, len_ + kMaxLanes + len);
+        hn::StoreU(hn::Div(k1, vlen), df, inv_len_ + kMaxLanes + len);
+      }
+      // len = kGroupSize is legitimate, e.g., for all-equal weights.
+      len_[kMaxLanes + kGroupSize] = static_cast<float>(kGroupSize);
+      inv_len_[kMaxLanes + kGroupSize] = 1.0f / static_cast<float>(kGroupSize);
+      // len = 0 can happen, but `valid` is false for that lane.
+      len_[kMaxLanes + 0] = kSentinel;
+      inv_len_[kMaxLanes + 0] = kSentinel;
+
+      // Ensure it is safe to load a vector from the last element.
+      for (size_t i = 0; i < kMaxLanes; ++i) {
+        constexpr size_t kEnd = kGroupSize + 1;
+        cumsum_[kEnd + i] = cumsum_[kGroupSize];
+        cumsum2_[kEnd + i] = cumsum2_[kGroupSize];
+        len_[kMaxLanes + kEnd + i] = len_[kMaxLanes + kGroupSize];
+        inv_len_[kMaxLanes + kEnd + i] = inv_len_[kMaxLanes + kGroupSize];
+      }
+      // len_ and inv_len_ are also prepended with kMaxLanes sentinels in case
+      // first > last.
+      for (size_t i = 0; i < kMaxLanes; ++i) {
+        len_[i] = kSentinel;
+        inv_len_[i] = kSentinel;
       }
     }
 
+    // Returns the sum of sorted[first..last]; used during backtracking to
+    // compute cluster means.
     float SumOfSorted(size_t first, size_t last) const {
       return cumsum_[last + 1] - cumsum_[first];
     }
 
-    // Returns cost of clustering first..last with their mean, for a vector
-    // of last. O(1) thanks to cumulative sums, which works for Lp-norms with
-    // p > 1; we choose p=2 for simplicity (fewer terms).
-    template <class DF>
-    hn::Vec<DF> operator()(DF df, size_t first, size_t last) const {
-      // Callers are responsible for ignoring lanes where last < first.
+    // Returns vector of costs of clustering first..last + i with their means.
+    // O(1) thanks to cumulative sums, which works for Lp-norms with p > 1; we
+    // choose p=2 for simplicity (fewer terms). Caller ignores lanes where
+    // `!valid[i]`, i.e. `first > last + i`.
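+    // Worked example (illustrative values): for sorted = {1, 2, 3, 4} and
+    // the lane with first = 1, last = 2, we have sum = 2 + 3 = 5,
+    // sum2 = 4 + 9 = 13, len = 2 and mu = 2.5; the cost
+    // 13 - 2 * 2.5 * 5 + 2 * 2.5^2 = 0.5 equals (2 - 2.5)^2 + (3 - 2.5)^2.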
+    template <class DF, class MF = hn::Mask<DF>, class VF = hn::Vec<DF>>
+    VF SumCosts(DF df, size_t first, size_t last, MF valid) const {
       HWY_DASSERT(first < kGroupSize);
       HWY_DASSERT(last < kGroupSize);
 
-      const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
-      const hn::Vec<DF> vlen = hn::Iota(df, static_cast<float>(len));
-      const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
-      const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
-      const hn::Vec<DF> hi = hn::LoadU(df, cumsum_ + last + 1);
-      const hn::Vec<DF> hi2 = hn::LoadU(df, cumsum2_ + last + 1);
-      const hn::Vec<DF> sum = hn::Sub(hi, u_lo);
-      const hn::Vec<DF> sum2 = hn::Sub(hi2, u_lo2);
+      VF inv_len;
+      const VF vlen = Lengths(df, first, last, valid, inv_len);
 
-      // Compute mean: table lookup is faster than division.
-      const hn::Vec<DF> mu = hn::Mul(sum, hn::LoadU(df, inv_len_ + len));
+      const VF u_lo = hn::Set(df, cumsum_[first]);
+      const VF u_lo2 = hn::Set(df, cumsum2_[first]);
+      const VF hi = hn::LoadU(df, cumsum_ + last + 1);
+      const VF hi2 = hn::LoadU(df, cumsum2_ + last + 1);
+      const VF sum = hn::Sub(hi, u_lo);
+      const VF sum2 = hn::Sub(hi2, u_lo2);
 
-      // (x - mu)^2 = sum2 - 2mu*sum + mu^2
-      const hn::Vec<DF> mu2 = hn::Mul(mu, mu);
-      const hn::Vec<DF> two_mu = hn::Add(mu, mu);
-      return hn::NegMulAdd(two_mu, sum, hn::MulAdd(vlen, mu2, sum2));
+      // Sum over i in [first, last] of (x[i] - mu)^2. `sum` and `sum2` are
+      // the interval sums of x and x^2, so expand to
+      // `sum x^2 - 2 * mu * sum x + sum mu^2`. The last term is the sum of a
+      // constant, hence `mu^2 * len`. Thus we have
+      // `sum2 + mu * (mu * len - 2 * sum)`; we avoid a (-)2 constant by
+      // adding `sum` to itself.
+      const VF mu = hn::Mul(sum, inv_len);  // mean := sum[i] / len[i]
+      const VF two_sum = hn::Add(sum, sum);
+      const VF l2 = hn::MulAdd(mu, hn::MulSub(mu, vlen, two_sum), sum2);
+      // mu can have some roundoff error. To avoid multiple redundant
+      // clusters, clamp to zero.
+      return hn::ZeroIfNegative(l2);
     }
 
    private:
+    // Returns precomputed lengths of [first, last + i] and their reciprocals.
+    template <class DF, class VF = hn::Vec<DF>, class MF = hn::Mask<DF>>
+    VF Lengths(DF df, size_t first, size_t last, MF valid, VF& inv_len) const {
+      const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
+      HWY_DASSERT(static_cast<int>(kMaxLanes) + len >= 0);
+      HWY_DASSERT(len <= static_cast<int>(kGroupSize));
+      // last + i are contiguous, hence single loads instead of gather.
+      const VF vlen = hn::LoadU(df, len_ + kMaxLanes + len);
+      inv_len = hn::LoadU(df, inv_len_ + kMaxLanes + len);
 
+      if constexpr (HWY_IS_DEBUG_BUILD) {
+        // Sanity check: no valid lanes are sentinels, all invalid lanes are.
+        const VF sentinel = hn::Set(df, kSentinel);
+        const MF bad = hn::Eq(vlen, sentinel);
+        const MF inv_bad = hn::Eq(inv_len, sentinel);
+        HWY_DASSERT(hn::AllFalse(df, hn::And(valid, bad)));
+        HWY_DASSERT(hn::AllFalse(df, hn::And(valid, inv_bad)));
+        HWY_DASSERT(hn::AllTrue(df, hn::Or(valid, bad)));
+        HWY_DASSERT(hn::AllTrue(df, hn::Or(valid, inv_bad)));
+      }
+
+      return vlen;
+    }
+
     // Float has enough precision for our relatively small kGroupSize (256).
-    float cumsum_[kGroupSize + 1];
-    float cumsum2_[kGroupSize + 1];
-    float inv_len_[kGroupSize + 1];
+    // Element i = sums of [0..i-1].
+    float cumsum_[kGroupSize + 1 + kMaxLanes];
+    float cumsum2_[kGroupSize + 1 + kMaxLanes];
+    float len_[kMaxLanes + kGroupSize + 1 + kMaxLanes];      // = vlen[i]
+    float inv_len_[kMaxLanes + kGroupSize + 1 + kMaxLanes];  // = 1 / vlen[i]
   };
 
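+  // The recurrence implemented below (see the reference in ClusterExactL2):
+  // D[k][m] = min over j of (D[k-1][j-1] + cost(j..m)), where T[k][m]
+  // records the argmin j, i.e. the start of the rightmost cluster.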
-  // Cost of clustering 0..last, where the rightmost cluster is j..last. This
-  // is called in a loop over j, and we return the vector of costs for a
-  // batch of last = [last, last + N).
-  template <class DF>
-  static HWY_INLINE hn::Vec<DF> ClusterDynProg(
-      DF df, const AlignedMatrix<float>& D, const ClusterCost& cc,
-      const size_t num_clusters, const size_t last, const size_t j) {
+  // Dynamic programming step: returns costs of clustering 0..last + i, where
+  // the rightmost cluster starts at `first`. Called for each `idx_cluster`,
+  // `first`, and `last`; vectorized across `last`. `first` may be greater
+  // than `last`. `valid[i]` is `first <= last + i`.
+  template <class DF, class VF = hn::Vec<DF>, class MF = hn::Mask<DF>>
+  static HWY_INLINE VF ClusterDynProg(DF df, const AlignedMatrix<float>& D,
+                                      const ClusterCost& cc,
+                                      const size_t idx_cluster,
+                                      const size_t first, const size_t last,
+                                      const MF valid) {
+    HWY_DASSERT(idx_cluster != 0);
+    HWY_DASSERT(0 != first && first < kGroupSize);
     HWY_DASSERT(last < kGroupSize);
-    HWY_DASSERT(0 != j && j < kGroupSize);
+    HWY_DASSERT(last % hn::Lanes(df) == 0);  // Called in steps of N.
 
-    const hn::RebindToSigned<DF> di;
-    using VF = hn::Vec<DF>;
-    using VI = hn::Vec<decltype(di)>;
-    using MI = hn::Mask<decltype(di)>;
-
-    const VI vlast = hn::Iota(di, static_cast<int32_t>(last));
-
-    // We have a non-empty rightmost cluster if j <= last <==> j - 1 < last.
-    const MI valid = hn::Lt(hn::Set(di, static_cast<int32_t>(j) - 1), vlast);
-    // If not valid, return an arbitrary high cost, which will not be the min.
-    const VF max = hn::Set(df, 1E38f);
-    // Cost of clustering 0..j-1 with one fewer cluster than now.
-    const VF vd = hn::Set(df, D(num_clusters - 1, j - 1));
-    // Eq2: add to that the cost of another cluster from j..last.
-    return hn::MaskedAddOr(max, RebindMask(df, valid), vd, cc(df, j, last));
+    // Cost of clustering 0..first-1 with one fewer cluster than now.
+    const VF prev = hn::Set(df, D(idx_cluster - 1, first - 1));
+    // Eq2: add to that the cost of another cluster from first..last.
+    return hn::Add(prev, cc.SumCosts(df, first, last, valid));
   }
 
  public:
-  // Clusters `kGroupSize` values in `x`, which need not be sorted already
-  // nor aligned, by choosing and filling `centers` (size `kClusters`,
-  // ascending order, not necessarily equal to one of the `x`). Fills
-  // `indices` with the index of the cluster to which each `x` belongs
-  // (16-bit for bit-packing). `buf` is per-thread.
+  // Clusters `num <= kGroupSize` values in `x`, which need not be sorted
+  // already nor aligned, by choosing and filling `centers` (size `kClusters`,
+  // ascending order, not necessarily equal to one of the `x`). Fills
+  // `indices` with the index of the cluster to which each `x` belongs (16-bit
+  // for bit-packing). `buf` is per-thread.
   //
   // Returns the number of unused clusters, i.e., the starting index within
   // `centers`; prior centers are zero-initialized.
@@ -171,22 +236,36 @@ class NuqClustering {
   // that this is about 5 times as fast as the O(kClusters * kGroupSize) SMAWK
   // as implemented in FAISS, for our kGroupSize of 256.
   template <class DF>
-  static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x,
-                                            ClusterBuf& buf,
+  static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* HWY_RESTRICT x,
+                                            size_t num, ClusterBuf& buf,
                                             float* HWY_RESTRICT centers,
                                             uint16_t* HWY_RESTRICT indices) {
+    HWY_DASSERT(num <= kGroupSize);
     const hn::RebindToSigned<DF> di;
     using VF = hn::Vec<DF>;
+    using MF = hn::Mask<DF>;
     using VI = hn::Vec<decltype(di)>;
     const VI k1 = hn::Set(di, 1);
     const size_t N = hn::Lanes(df);
+    HWY_DASSERT(kGroupSize % N == 0);
 
     HWY_ALIGN float sorted_and_i[kGroupSize];
-    for (size_t i = 0; i < kGroupSize; ++i) {
+    for (size_t i = 0; i < num; ++i) {
       sorted_and_i[i] = FloatPayload::Set(x[i], i);
     }
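+    // The payload attaches each value's original index, so cluster
+    // assignments can be mapped back to input positions after sorting.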
+    if (num != kGroupSize) {
+      // Initialize the rest of the group. Use an existing value so we do not
+      // waste a cluster on a sentinel value. Arbitrarily choose the largest.
+      float max = -1E38f;
+      for (size_t i = 0; i < num; ++i) {
+        max = HWY_MAX(max, x[i]);
+      }
+      for (size_t i = num; i < kGroupSize; ++i) {
+        sorted_and_i[i] = FloatPayload::Set(max, i);
+      }
+    }
     hn::VQSortStatic(sorted_and_i, kGroupSize, hwy::SortAscending());
 
-    ClusterCost cc(sorted_and_i);
+    ClusterCost cc(sorted_and_i);  // Ignores payload bits.
 
     // Reference: https://arxiv.org/abs/1701.07204
     // D[k-1][m] is the lowest cost of clustering x1..m into k clusters.
@@ -194,29 +273,44 @@ class NuqClustering {
     // T[k][m] is the starting index within sorted_and_i[] of the k-th cluster.
     AlignedMatrix<int32_t>& T = buf.t;
 
-    // Initialize the first rows for a single cluster.
-    for (size_t last = 0; last < kGroupSize; last += N) {
-      hn::Store(cc(df, 0, last), df, &D(0, last));  // Cost of 0..last
-      hn::Store(Zero(di), di, &T(0, last));         // Cluster index = 0
+    // Fill first row of `D` and `T`: single cluster, iterate over all `last`.
+    {
+      const size_t cluster_idx = 0;
+      const size_t first = 0;
+      const VI vfirst = hn::Set(di, static_cast<int32_t>(first));
+      const MF all_valid = hn::FirstN(df, N);  // first <= last always holds.
+      for (size_t last = 0; last < kGroupSize; last += N) {
+        const VF costs = cc.SumCosts(df, first, last, all_valid);
+        hn::Store(costs, df, &D(cluster_idx, last));
+        hn::Store(vfirst, di, &T(cluster_idx, last));
+      }
     }
 
-    for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
-      // For each batch starting at `last`, one per lane:
+    for (size_t cluster_idx = 1; cluster_idx < kClusters; ++cluster_idx) {
+      // For vectors of `last + i` with `i < N`:
       for (size_t last = 0; last < kGroupSize; last += N) {
-        VF min = hn::LoadU(df, &D(0, last));
-        VI arg = hn::Zero(di);
-        // For each j (start of rightmost cluster):
-        VI vj = k1;
-        for (size_t j = 1; j < last + N; ++j, vj = hn::Add(vj, k1)) {
-          const VF c = ClusterDynProg(df, D, cc, num_clusters, last, j);
+        const VI vlast = hn::Iota(di, static_cast<int32_t>(last));
+        const VF prev_cost = hn::LoadU(df, &D(cluster_idx - 1, last));
+        VF min = prev_cost;
+        VI arg = hn::LoadU(di, &T(cluster_idx - 1, last));
+        // For each `first` (j), which is the start of the rightmost of at
+        // least two clusters, hence never zero. `first` also continues past
+        // `last` because the last `vlast` lane is `last + N - 1`.
+        for (size_t first = 1; first < last + N; ++first) {
+          const VI vfirst = hn::Set(di, static_cast<int32_t>(first));
+          const MF valid = hn::RebindMask(df, hn::Le(vfirst, vlast));
+          const VF c =
+              ClusterDynProg(df, D, cc, cluster_idx, first, last, valid);
 
-          // Retain the min cost and the j index that caused it.
-          const auto less = hn::Lt(c, min);
+          // Retain the min cost and the `first` that caused it.
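+          // Lanes where `valid` is false must not win the min, hence the
+          // comparison below is masked with `valid`.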
+          const MF less = hn::And(valid, hn::Lt(c, min));
           min = hn::IfThenElse(less, c, min);
-          arg = hn::IfThenElse(RebindMask(di, less), vj, arg);
+          arg = hn::IfThenElse(RebindMask(di, less), vfirst, arg);
         }
-        hn::Store(min, df, &D(num_clusters, last));
-        hn::Store(arg, di, &T(num_clusters, last));
+        HWY_DASSERT(hn::AllTrue(df, hn::Le(min, prev_cost)));
+
+        hn::Store(min, df, &D(cluster_idx, last));
+        hn::Store(arg, di, &T(cluster_idx, last));
       }
     }
 
@@ -490,8 +584,8 @@ class NuqCodec {
       const float* HWY_RESTRICT g_in = in + g * kGroupSize;
       float* HWY_RESTRICT g_centers = buf.centers.get() + g * kClusters;
       uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
-      unused_clusters +=
-          NuqClustering::ClusterExactL2(df, g_in, buf, g_centers, g_idx);
+      unused_clusters += NuqClustering::ClusterExactL2(df, g_in, kGroupSize,
+                                                       buf, g_centers, g_idx);
     }
 
     uint8_t* centers = &out->byte + ofs_groups * kClusters;
diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc
index dbb0980..70c9119 100644
--- a/compression/nuq_test.cc
+++ b/compression/nuq_test.cc
@@ -66,8 +66,8 @@ struct TestFlat {
     ClusterBuf buf;
     float centers[kClusters];
     uint16_t indices[kGroupSize];
-    const size_t unused_clusters =
-        NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
+    const size_t unused_clusters = NuqClustering::ClusterExactL2(
+        df, in.get(), kGroupSize, buf, centers, indices);
     HWY_ASSERT(unused_clusters == kClusters - 1);
 
     for (size_t i = 0; i < unused_clusters; ++i) {
@@ -108,8 +108,8 @@ struct TestPlateaus {
     ClusterBuf buf;
     float centers[kClusters];
     uint16_t indices[kGroupSize];
-    const size_t unused_clusters =
-        NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
+    const size_t unused_clusters = NuqClustering::ClusterExactL2(
+        df, in.get(), kGroupSize, buf, centers, indices);
     HWY_ASSERT(unused_clusters == 0);
 
     DistortionStats stats;
@@ -155,8 +155,8 @@ struct TestRamp {
     ClusterBuf buf;
     float centers[kClusters];
     uint16_t indices[kGroupSize];
-    const size_t unused_clusters =
-        NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
+    const size_t unused_clusters = NuqClustering::ClusterExactL2(
+        df, in.get(), kGroupSize, buf, centers, indices);
     HWY_ASSERT(unused_clusters == 0);
 
     DistortionStats stats;
@@ -203,8 +203,8 @@ struct TestNormal {
     double elapsed = hwy::HighestValue<double>();
     for (size_t rep = 0; rep < 100; ++rep) {
       const double t0 = hwy::platform::Now();
-      const size_t unused_clusters =
-          NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
+      const size_t unused_clusters = NuqClustering::ClusterExactL2(
+          df, in.get(), kGroupSize, buf, centers, indices);
       HWY_ASSERT(unused_clusters == 0);
       const double t1 = hwy::platform::Now();
       elapsed = HWY_MIN(elapsed, t1 - t0);
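Usage sketch for the new `num` parameter (not part of the patch; assumes the ClusterBuf, kGroupSize and kClusters definitions used by nuq_test.cc above):

// Clustering a partially filled group: only the first `num` inputs are real;
// ClusterExactL2 pads the remainder internally with the largest input value.
const hn::ScalableTag<float> df;
HWY_ALIGN float x[kGroupSize] = {};  // hypothetical inputs
const size_t num = kGroupSize / 2;   // hypothetical partial fill
ClusterBuf buf;
float centers[kClusters];
uint16_t indices[kGroupSize];
const size_t unused_clusters =
    NuqClustering::ClusterExactL2(df, x, num, buf, centers, indices);
// centers[unused_clusters..kClusters-1] hold the used centers in ascending
// order; indices[i] for i < num identifies the cluster of x[i].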