mirror of https://github.com/google/gemma.cpp.git
Fix underflow in NUQ ClusterCost()
PiperOrigin-RevId: 628137904
This commit is contained in:
parent
9e0ac5de34
commit
e8f59bb411
|
|
@@ -104,9 +104,8 @@ class NuqClustering {
|
|||
// Callers are responsible for ignoring lanes where last < first.
|
||||
HWY_DASSERT(first < kGroupSize);
|
||||
HWY_DASSERT(last < kGroupSize);
|
||||
-const size_t len = last - first + 1;
|
||||
-const hn::Vec<DF> vlen =
|
||||
-hn::Iota(df, static_cast<float>(static_cast<int>(len)));
|
||||
+const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
|
||||
+const hn::Vec<DF> vlen = hn::Iota(df, static_cast<float>(len));
|
||||
|
||||
const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
|
||||
const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
|
||||
|
|
@@ -204,7 +203,7 @@ class NuqClustering {
|
|||
for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
|
||||
// For each batch starting at `last`, one per lane:
|
||||
for (size_t last = 0; last < kGroupSize; last += N) {
|
||||
-VF min = cc(df, 0, last);
|
||||
+VF min = hn::LoadU(df, &D(0, last));
|
||||
VI arg = hn::Zero(di);
|
||||
// For each j (start of rightmost cluster):
|
||||
VI vj = k1;
|
||||
|
|
|
|||
Loading…
Reference in New Issue