Fix underflow in NUQ ClusterCost()

PiperOrigin-RevId: 628137904
This commit is contained in:
Paul Chang 2024-04-25 11:28:09 -07:00 committed by Copybara-Service
parent 9e0ac5de34
commit e8f59bb411
1 changed file with 3 additions and 4 deletions

View File

@@ -104,9 +104,8 @@ class NuqClustering {
// Callers are responsible for ignoring lanes where last < first.
HWY_DASSERT(first < kGroupSize);
HWY_DASSERT(last < kGroupSize);
const size_t len = last - first + 1;
const hn::Vec<DF> vlen =
hn::Iota(df, static_cast<float>(static_cast<int>(len)));
const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
const hn::Vec<DF> vlen = hn::Iota(df, static_cast<float>(len));
const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
@@ -204,7 +203,7 @@ class NuqClustering {
for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
// For each batch starting at `last`, one per lane:
for (size_t last = 0; last < kGroupSize; last += N) {
VF min = cc(df, 0, last);
VF min = hn::LoadU(df, &D(0, last));
VI arg = hn::Zero(di);
// For each j (start of rightmost cluster):
VI vj = k1;