mirror of https://github.com/google/gemma.cpp.git

1.22x NUQ compress speedup, fix out of bounds access, improve numerics

Also clarify the cost computation and move toward non-group-multiple num.

PiperOrigin-RevId: 670544245

parent 437e0eb9af
commit aa11ddf5fc
compression/nuq-inl.h:

@@ -35,9 +35,10 @@
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_TOGGLE
 #endif

+#include "hwy/highway.h"
+// After highway.h
 #include "compression/sfp-inl.h"
 #include "hwy/contrib/sort/vqsort-inl.h"
-#include "hwy/highway.h"

 #ifndef HWY_IF_CONSTEXPR
 #define HWY_IF_CONSTEXPR if
@@ -77,92 +78,156 @@ class NuqClustering {

   // Cumulative sums for O(1) mean and interval sums.
   class ClusterCost {
+    // Ensures it is safe to load a vector from the last element.
+    static constexpr size_t kMaxLanes = hn::MaxLanes(hn::ScalableTag<float>());
+
+    // Initialization value for table elements where `valid` is false.
+    static constexpr float kSentinel = -1.0f;
+
    public:
-    explicit ClusterCost(const float* sorted) {
+    explicit ClusterCost(const float* HWY_RESTRICT sorted) {
+      double cumsum = 0.0;
+      double cumsum2 = 0.0;
       cumsum_[0] = cumsum2_[0] = 0.0;
       for (size_t i = 0; i < kGroupSize; ++i) {
         const float x = FloatPayload::Clear(sorted[i]);
-        cumsum_[1 + i] = x + cumsum_[i];
-        cumsum2_[1 + i] = x * x + cumsum2_[i];
+        cumsum += x;
+        cumsum2 += static_cast<double>(x) * x;
+        cumsum_[1 + i] = static_cast<float>(cumsum);
+        cumsum2_[1 + i] = static_cast<float>(cumsum2);
       }

-      inv_len_[0] = 0.0f;  // unused
-      for (size_t i = 1; i <= kGroupSize; ++i) {
-        inv_len_[i] = 1.0f / static_cast<float>(i);
+      const hn::ScalableTag<float> df;
+      using VF = hn::Vec<decltype(df)>;
+      const VF k1 = hn::Set(df, 1.0f);
+      const size_t N = hn::Lanes(df);
+      HWY_DASSERT(kGroupSize % N == 0);
+
+      // Precomputed length and reciprocal.
+      for (size_t len = 0; len < kGroupSize; len += N) {
+        const VF vlen = hn::Iota(df, static_cast<int32_t>(len));
+        hn::StoreU(vlen, df, len_ + kMaxLanes + len);
+        hn::StoreU(hn::Div(k1, vlen), df, inv_len_ + kMaxLanes + len);
       }
+      // len = kGroupSize is legitimate, e.g., for all-equal weights.
+      len_[kMaxLanes + kGroupSize] = static_cast<float>(kGroupSize);
+      inv_len_[kMaxLanes + kGroupSize] = 1.0f / static_cast<float>(kGroupSize);
+      // len = 0 can happen, but valid is false for that lane.
+      len_[kMaxLanes + 0] = kSentinel;
+      inv_len_[kMaxLanes + 0] = kSentinel;
+
+      // Ensure it is safe to load a vector from the last element.
+      for (size_t i = 0; i < kMaxLanes; ++i) {
+        constexpr size_t kEnd = kGroupSize + 1;
+        cumsum_[kEnd + i] = cumsum_[kGroupSize];
+        cumsum2_[kEnd + i] = cumsum2_[kGroupSize];
+        len_[kMaxLanes + kEnd + i] = len_[kMaxLanes + kGroupSize];
+        inv_len_[kMaxLanes + kEnd + i] = inv_len_[kMaxLanes + kGroupSize];
+      }
+      // For inv_len_ we also prepend MaxLanes in case first > last.
+      for (size_t i = 0; i < kMaxLanes; ++i) {
+        len_[i] = kSentinel;
+        inv_len_[i] = kSentinel;
+      }
     }

     // Returns cost (L2 norm) for a single cluster, used for backtracking.
     float SumOfSorted(size_t first, size_t last) const {
       return cumsum_[last + 1] - cumsum_[first];
     }

-    // Returns cost of clustering first..last with their mean, for a vector of
-    // last. O(1) thanks to cumulative sums, which works for Lp-norms with p >
-    // 1; we choose p=2 for simplicity (fewer terms).
-    template <class DF>
-    hn::Vec<DF> operator()(DF df, size_t first, size_t last) const {
-      // Callers are responsible for ignoring lanes where last < first.
+    // Returns vector of costs of clustering first..last + i with their means.
+    // O(1) thanks to cumulative sums, which works for Lp-norms with p > 1; we
+    // choose p=2 for simplicity (fewer terms). Caller ignores lanes where
+    // `!valid[i]`, i.e. `first > last + i`.
+    template <class DF, class MF = hn::Mask<DF>, class VF = hn::Vec<DF>>
+    VF SumCosts(DF df, size_t first, size_t last, MF valid) const {
       HWY_DASSERT(first < kGroupSize);
       HWY_DASSERT(last < kGroupSize);
-      const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
-      const hn::Vec<DF> vlen = hn::Iota(df, static_cast<float>(len));
-
-      const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
-      const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
-      const hn::Vec<DF> hi = hn::LoadU(df, cumsum_ + last + 1);
-      const hn::Vec<DF> hi2 = hn::LoadU(df, cumsum2_ + last + 1);
-      const hn::Vec<DF> sum = hn::Sub(hi, u_lo);
-      const hn::Vec<DF> sum2 = hn::Sub(hi2, u_lo2);
+      VF inv_len;
+      const VF vlen = Lengths(df, first, last, valid, inv_len);

-      // Compute mean: table lookup is faster than division.
-      const hn::Vec<DF> mu = hn::Mul(sum, hn::LoadU(df, inv_len_ + len));
+      const VF u_lo = hn::Set(df, cumsum_[first]);
+      const VF u_lo2 = hn::Set(df, cumsum2_[first]);
+      const VF hi = hn::LoadU(df, cumsum_ + last + 1);
+      const VF hi2 = hn::LoadU(df, cumsum2_ + last + 1);
+      const VF sum = hn::Sub(hi, u_lo);
+      const VF sum2 = hn::Sub(hi2, u_lo2);

-      // (x - mu)^2 = sum2 - 2mu*sum + mu^2
-      const hn::Vec<DF> mu2 = hn::Mul(mu, mu);
-      const hn::Vec<DF> two_mu = hn::Add(mu, mu);
-      return hn::NegMulAdd(two_mu, sum, hn::MulAdd(vlen, mu2, sum2));
+      // Sum of L2 over i in [first, last] = (x[i] - mu)^2. `sum` and `sum2` are
+      // the cumulative sums of x and x^2, so expand to `sum x^2 + sum x * -2 *
+      // mu + sum mu^2`. The last term is the sum of a constant, hence `mu^2 *
+      // len`. Thus we have: `sum2 + mu * (-2 * sum + mu * len)`. We avoid a
+      // (-)2 constant by adding.
+      const VF mu = hn::Mul(sum, inv_len);  // mean := sum[i] / len[i]
+      const VF two_sum = hn::Add(sum, sum);
+      const VF l2 = hn::MulAdd(mu, hn::MulSub(mu, vlen, two_sum), sum2);
+      // mu can have some roundoff error. To avoid multiple redundant clusters,
+      // clamp to zero.
+      return hn::ZeroIfNegative(l2);
     }

    private:
+    // Returns precomputed lengths of [first, last + i] and their reciprocals.
+    template <class DF, class VF = hn::Vec<DF>, class MF = hn::Mask<DF>>
+    VF Lengths(DF df, size_t first, size_t last, MF valid, VF& inv_len) const {
+      const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
+      HWY_DASSERT(kMaxLanes + len >= 0);
+      HWY_DASSERT(len <= static_cast<int>(kGroupSize));
+      // last + i are contiguous, hence single loads instead of gather.
+      const VF vlen = hn::LoadU(df, len_ + kMaxLanes + len);
+      inv_len = hn::LoadU(df, inv_len_ + kMaxLanes + len);
+
+      if constexpr (HWY_IS_DEBUG_BUILD) {
+        // Sanity check: no valid lanes are sentinels, all invalid are.
+        const VF sentinel = hn::Set(df, kSentinel);
+        const MF bad = hn::Eq(vlen, sentinel);
+        const MF inv_bad = hn::Eq(inv_len, sentinel);
+        HWY_DASSERT(hn::AllFalse(df, hn::And(valid, bad)));
+        HWY_DASSERT(hn::AllFalse(df, hn::And(valid, inv_bad)));
+        HWY_DASSERT(hn::AllTrue(df, hn::Or(valid, bad)));
+        HWY_DASSERT(hn::AllTrue(df, hn::Or(valid, inv_bad)));
+      }
+
+      return vlen;
+    }
+
     // Float has enough precision for our relatively small kGroupSize (256).
-    float cumsum_[kGroupSize + 1];
-    float cumsum2_[kGroupSize + 1];
-    float inv_len_[kGroupSize + 1];
+    // Element i = sums of [0..i-1].
+    float cumsum_[kGroupSize + 1 + kMaxLanes];
+    float cumsum2_[kGroupSize + 1 + kMaxLanes];
+    float len_[kMaxLanes + kGroupSize + 1 + kMaxLanes];      // = vlen[i]
+    float inv_len_[kMaxLanes + kGroupSize + 1 + kMaxLanes];  // = 1 / vlen[i]
   };

-  // Cost of clustering 0..last, where the rightmost cluster is j..last. This is
-  // called in a loop over j, and we return the vector of costs for a batch of
-  // last = [last, last + N).
-  template <class DF>
-  static HWY_INLINE hn::Vec<DF> ClusterDynProg(
-      DF df, const AlignedMatrix<float>& D, const ClusterCost& cc,
-      const size_t num_clusters, const size_t last, const size_t j) {
+  // Dynamic programming step: returns costs of clustering 0..last+i, where the
+  // rightmost clusters start at `first`. Called for each `idx_cluster`,
+  // `first`, and `last`; vectorized across `last`. `first` may be greater than
+  // `last`. `valid[i]` is `first <= last + i`.
+  template <class DF, class VF = hn::Vec<DF>, class MF = hn::Mask<DF>>
+  static HWY_INLINE VF ClusterDynProg(DF df, const AlignedMatrix<float>& D,
+                                      const ClusterCost& cc,
+                                      const size_t idx_cluster,
+                                      const size_t first, const size_t last,
+                                      const MF valid) {
+    HWY_DASSERT(idx_cluster != 0);
+    HWY_DASSERT(0 != first && first < kGroupSize);
     HWY_DASSERT(last < kGroupSize);
-    HWY_DASSERT(0 != j && j < kGroupSize);
     HWY_DASSERT(last % hn::Lanes(df) == 0);  // Called in steps of N

-    const hn::RebindToSigned<decltype(df)> di;
-    using VF = hn::Vec<decltype(df)>;
-    using VI = hn::Vec<decltype(di)>;
-    using MI = hn::Mask<decltype(di)>;
-
-    const VI vlast = hn::Iota(di, static_cast<int32_t>(last));
-
-    // We have a non-empty rightmost cluster if j <= last <==> j-1 < last.
-    const MI valid = hn::Lt(hn::Set(di, static_cast<int32_t>(j) - 1), vlast);
-    // If not valid, return an arbitrary high cost, which will not be the min.
-    const VF max = hn::Set(df, 1E38f);
-    // Cost of clustering 0..j-1 with one fewer cluster than now.
-    const VF vd = hn::Set(df, D(num_clusters - 1, j - 1));
-    // Eq2: add to that the cost of another cluster from j..last.
-    return hn::MaskedAddOr(max, RebindMask(df, valid), vd, cc(df, j, last));
+    // Cost of clustering 0..first-1 with one fewer cluster than now.
+    const VF prev = hn::Set(df, D(idx_cluster - 1, first - 1));
+    // Eq2: add to that the cost of another cluster from first..last.
+    return hn::Add(prev, cc.SumCosts(df, first, last, valid));
   }

  public:
-  // Clusters `kGroupSize` values in `x`, which need not be sorted already nor
-  // aligned, by choosing and filling `centers` (size `kClusters`, ascending
-  // order, not necessarily equal to one of the `x`). Fills `indices` with the
-  // index of the cluster to which each `x` belongs (16-bit for bit-packing).
-  // `buf` is per-thread.
+  // Clusters `num <= kGroupSize` values in `x`, which need not be sorted
+  // already nor aligned, by choosing and filling `centers` (size `kClusters`,
+  // ascending order, not necessarily equal to one of the `x`). Fills `indices`
+  // with the index of the cluster to which each `x` belongs (16-bit for
+  // bit-packing). `buf` is per-thread.
+  //
   // Returns the number of unused clusters, i.e., the starting index within
   // `centers`; prior centers are zero-initialized.
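Note on the ClusterCost changes above: the constructor now accumulates the prefix sums in double before rounding each entry back to float (the "improve numerics" part of this change), and the tables gain kMaxLanes of sentinel/replicated padding on both ends so the unaligned vector loads in SumCosts/Lengths stay in bounds even when first > last + i (the "fix out of bounds access" part). A minimal scalar sketch of the underlying idea follows; IntervalCost is an illustrative name, not part of the library, and this is not the vectorized code above:

#include <cstddef>
#include <vector>

// Prefix sums of x and x^2 give the L2 cost of any interval in O(1):
//   sum over i in [first, last] of (x[i] - mu)^2 = sum2 + mu * (mu * len - 2 * sum),
// where mu = sum / len. Roundoff can push this slightly negative, so clamp.
struct IntervalCost {
  std::vector<float> cumsum, cumsum2;  // element i = sums of [0..i-1]
  explicit IntervalCost(const std::vector<float>& x)
      : cumsum(x.size() + 1, 0.0f), cumsum2(x.size() + 1, 0.0f) {
    double s = 0.0, s2 = 0.0;  // double accumulation reduces roundoff
    for (size_t i = 0; i < x.size(); ++i) {
      s += x[i];
      s2 += static_cast<double>(x[i]) * x[i];
      cumsum[i + 1] = static_cast<float>(s);
      cumsum2[i + 1] = static_cast<float>(s2);
    }
  }
  float operator()(size_t first, size_t last) const {  // cost of [first, last]
    const float sum = cumsum[last + 1] - cumsum[first];
    const float sum2 = cumsum2[last + 1] - cumsum2[first];
    const float len = static_cast<float>(last - first + 1);
    const float mu = sum / len;
    const float l2 = sum2 + mu * (mu * len - 2.0f * sum);
    return l2 < 0.0f ? 0.0f : l2;  // clamp roundoff, as ZeroIfNegative does
  }
};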
@@ -171,22 +236,36 @@ class NuqClustering {
   // that this is about 5 times as fast as the O(kClusters * kGroupSize) SMAWK
   // as implemented in FAISS, for our kGroupSize of 256.
   template <class DF>
-  static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x,
-                                            ClusterBuf& buf,
+  static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* HWY_RESTRICT x,
+                                            size_t num, ClusterBuf& buf,
                                             float* HWY_RESTRICT centers,
                                             uint16_t* HWY_RESTRICT indices) {
+    HWY_DASSERT(num <= kGroupSize);
     const hn::RebindToSigned<decltype(df)> di;
     using VF = hn::Vec<decltype(df)>;
+    using MF = hn::Mask<decltype(df)>;
     using VI = hn::Vec<decltype(di)>;
-    const VI k1 = hn::Set(di, 1);
     const size_t N = hn::Lanes(df);
     HWY_DASSERT(kGroupSize % N == 0);

     HWY_ALIGN float sorted_and_i[kGroupSize];
-    for (size_t i = 0; i < kGroupSize; ++i) {
+    for (size_t i = 0; i < num; ++i) {
       sorted_and_i[i] = FloatPayload::Set(x[i], i);
     }
+    if (num != kGroupSize) {
+      // Initialize the rest of the group. Use an existing value so we do not
+      // waste a cluster on a sentinel value. Arbitrarily choose the largest.
+      float max = -1E38f;
+      for (size_t i = 0; i < num; ++i) {
+        max = HWY_MAX(max, x[i]);
+      }
+      for (size_t i = num; i < kGroupSize; ++i) {
+        sorted_and_i[i] = FloatPayload::Set(max, i);
+      }
+    }
     hn::VQSortStatic(sorted_and_i, kGroupSize, hwy::SortAscending());
-    ClusterCost cc(sorted_and_i);
+    ClusterCost cc(sorted_and_i);  // ignores payload bits.

     // Reference: https://arxiv.org/abs/1701.07204
     // D[k-1][m] is the lowest cost of clustering x1..m into k clusters.
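For reference, the recurrence these loops implement (Eq2 of the paper cited in the comment above) is D(k, last) = min over first of D(k-1, first-1) + cost(first, last), with T recording the argmin for backtracking. A compact scalar version under the same cost definition; names are illustrative, not the library's API, and it omits the vectorization and the payload/index machinery:

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns the minimal total cost of partitioning n sorted values into
// k_clusters contiguous clusters. cost(first, last) is the cost of one
// cluster spanning [first, last] (e.g. the O(1) prefix-sum cost above).
// T[k][last] is the start of the rightmost cluster in the optimum; cluster
// boundaries are recovered by backtracking through T. Assumes k_clusters >= 1.
template <class CostFn>
double ClusterDP(size_t n, size_t k_clusters, const CostFn& cost,
                 std::vector<std::vector<size_t>>& T) {
  std::vector<std::vector<double>> D(k_clusters, std::vector<double>(n, 0.0));
  T.assign(k_clusters, std::vector<size_t>(n, 0));
  for (size_t last = 0; last < n; ++last) D[0][last] = cost(0, last);
  for (size_t k = 1; k < k_clusters; ++k) {
    for (size_t last = 0; last < n; ++last) {
      D[k][last] = D[k - 1][last];  // "no further split" is a valid candidate
      T[k][last] = T[k - 1][last];
      for (size_t first = 1; first <= last; ++first) {
        const double c = D[k - 1][first - 1] + cost(first, last);  // Eq2
        if (c < D[k][last]) {
          D[k][last] = c;
          T[k][last] = first;
        }
      }
    }
  }
  return D[k_clusters - 1][n - 1];
}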
@@ -194,29 +273,44 @@ class NuqClustering {
     // T[k][m] is the starting index within sorted_and_i[] of the k-th cluster.
     AlignedMatrix<int32_t>& T = buf.t;

-    // Initialize the first rows for a single cluster.
-    for (size_t last = 0; last < kGroupSize; last += N) {
-      hn::Store(cc(df, 0, last), df, &D(0, last));  // Cost of 0..last
-      hn::Store(Zero(di), di, &T(0, last));         // Cluster index = 0
-    }
+    // Fill first row of `D` and `T`: single cluster, iterate over all `last`.
+    {
+      const size_t cluster_idx = 0;
+      const size_t first = 0;
+      const VI vfirst = hn::Set(di, static_cast<int32_t>(first));
+      const MF all_valid = hn::FirstN(df, N);  // first <= last is always true
+      for (size_t last = 0; last < kGroupSize; last += N) {
+        const VF costs = cc.SumCosts(df, first, last, all_valid);
+        hn::Store(costs, df, &D(cluster_idx, last));
+        hn::Store(vfirst, di, &T(cluster_idx, last));
+      }
+    }

-    for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
-      // For each batch starting at `last`, one per lane:
+    for (size_t cluster_idx = 1; cluster_idx < kClusters; ++cluster_idx) {
+      // For vectors of `last + i` with `i < N`:
       for (size_t last = 0; last < kGroupSize; last += N) {
-        VF min = hn::LoadU(df, &D(0, last));
-        VI arg = hn::Zero(di);
-        // For each j (start of rightmost cluster):
-        VI vj = k1;
-        for (size_t j = 1; j < last + N; ++j, vj = hn::Add(vj, k1)) {
-          const VF c = ClusterDynProg(df, D, cc, num_clusters, last, j);
-          // Retain the min cost and the j index that caused it.
-          const auto less = hn::Lt(c, min);
+        const VI vlast = hn::Iota(di, static_cast<int32_t>(last));
+        const VF prev_cost = hn::LoadU(df, &D(cluster_idx - 1, last));
+        VF min = prev_cost;
+        VI arg = hn::LoadU(di, &T(cluster_idx - 1, last));
+        // For each `first` (j), which is the start of the rightmost of at least
+        // two clusters, hence never zero. `first` also continues past `last`
+        // because the last `vlast` lane is `last + N - 1`.
+        for (size_t first = 1; first < last + N; ++first) {
+          const VI vfirst = hn::Set(di, static_cast<int32_t>(first));
+          const MF valid = hn::RebindMask(df, hn::Le(vfirst, vlast));
+          const VF c =
+              ClusterDynProg(df, D, cc, cluster_idx, first, last, valid);
+
+          // Retain the min cost and the `first` that caused it.
+          const MF less = hn::And(valid, hn::Lt(c, min));
           min = hn::IfThenElse(less, c, min);
-          arg = hn::IfThenElse(RebindMask(di, less), vj, arg);
+          arg = hn::IfThenElse(RebindMask(di, less), vfirst, arg);
         }
-        hn::Store(min, df, &D(num_clusters, last));
-        hn::Store(arg, di, &T(num_clusters, last));
+        HWY_DASSERT(hn::AllTrue(df, hn::Le(min, prev_cost)));
+
+        hn::Store(min, df, &D(cluster_idx, last));
+        hn::Store(arg, di, &T(cluster_idx, last));
       }
     }
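The rewritten inner loop drops the old sentinel trick (MaskedAddOr with a 1E38f cost for invalid lanes) and instead folds `valid` into the comparison mask, so invalid lanes can never win the argmin; seeding `min`/`arg` from the previous row also makes "keep the previous clustering" a candidate, which is exactly what the added HWY_DASSERT(min <= prev_cost) checks. A sketch of that masked-argmin idiom, using only ops that appear in the patch; UpdateArgMin is an illustrative name and the usual Highway per-target namespace boilerplate is omitted:

#include "hwy/highway.h"
namespace hn = hwy::HWY_NAMESPACE;

template <class DF, class VF = hn::Vec<DF>, class MF = hn::Mask<DF>,
          class VI = hn::Vec<hn::RebindToSigned<DF>>>
void UpdateArgMin(DF df, VF c, MF valid, VI vfirst, VF& min, VI& arg) {
  const hn::RebindToSigned<DF> di;
  // Fold `valid` into the comparison so invalid lanes never update min/arg.
  const MF less = hn::And(valid, hn::Lt(c, min));
  min = hn::IfThenElse(less, c, min);
  arg = hn::IfThenElse(hn::RebindMask(di, less), vfirst, arg);
}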
@@ -490,8 +584,8 @@ class NuqCodec {
       const float* HWY_RESTRICT g_in = in + g * kGroupSize;
       float* HWY_RESTRICT g_centers = buf.centers.get() + g * kClusters;
       uint16_t* HWY_RESTRICT g_idx = buf.idx.get() + g * kGroupSize;
-      unused_clusters +=
-          NuqClustering::ClusterExactL2(df, g_in, buf, g_centers, g_idx);
+      unused_clusters += NuqClustering::ClusterExactL2(df, g_in, kGroupSize,
+                                                       buf, g_centers, g_idx);
     }

     uint8_t* centers = &out->byte + ofs_groups * kClusters;
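All call sites currently pass num = kGroupSize, so behavior is unchanged; per the commit message this prepares for non-group-multiple sizes. A hypothetical future call for a partial final group might look like the following, where `remaining` is illustrative and not part of this patch:

// Hypothetical: cluster only the `remaining` valid values of the last group;
// ClusterExactL2 pads the tail with the group's max so no cluster is wasted.
const size_t unused = NuqClustering::ClusterExactL2(
    df, g_in, /*num=*/remaining, buf, g_centers, g_idx);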
compression/nuq_test.cc:

@@ -66,8 +66,8 @@ struct TestFlat {
     ClusterBuf buf;
     float centers[kClusters];
     uint16_t indices[kGroupSize];
-    const size_t unused_clusters =
-        NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
+    const size_t unused_clusters = NuqClustering::ClusterExactL2(
+        df, in.get(), kGroupSize, buf, centers, indices);
     HWY_ASSERT(unused_clusters == kClusters - 1);

     for (size_t i = 0; i < unused_clusters; ++i) {
@@ -108,8 +108,8 @@ struct TestPlateaus {
     ClusterBuf buf;
     float centers[kClusters];
     uint16_t indices[kGroupSize];
-    const size_t unused_clusters =
-        NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
+    const size_t unused_clusters = NuqClustering::ClusterExactL2(
+        df, in.get(), kGroupSize, buf, centers, indices);
     HWY_ASSERT(unused_clusters == 0);

     DistortionStats stats;
@@ -155,8 +155,8 @@ struct TestRamp {
     ClusterBuf buf;
     float centers[kClusters];
     uint16_t indices[kGroupSize];
-    const size_t unused_clusters =
-        NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
+    const size_t unused_clusters = NuqClustering::ClusterExactL2(
+        df, in.get(), kGroupSize, buf, centers, indices);
     HWY_ASSERT(unused_clusters == 0);

     DistortionStats stats;
@@ -203,8 +203,8 @@ struct TestNormal {
     double elapsed = hwy::HighestValue<double>();
     for (size_t rep = 0; rep < 100; ++rep) {
       const double t0 = hwy::platform::Now();
-      const size_t unused_clusters =
-          NuqClustering::ClusterExactL2(df, in.get(), buf, centers, indices);
+      const size_t unused_clusters = NuqClustering::ClusterExactL2(
+          df, in.get(), kGroupSize, buf, centers, indices);
       HWY_ASSERT(unused_clusters == 0);
       const double t1 = hwy::platform::Now();
       elapsed = HWY_MIN(elapsed, t1 - t0);