mirror of https://github.com/google/gemma.cpp.git
Fix underflow in NUQ ClusterCost()
PiperOrigin-RevId: 628137904
This commit is contained in:
parent
9e0ac5de34
commit
e8f59bb411
|
|
@@ -104,9 +104,8 @@ class NuqClustering {
|
||||||
// Callers are responsible for ignoring lanes where last < first.
|
// Callers are responsible for ignoring lanes where last < first.
|
||||||
HWY_DASSERT(first < kGroupSize);
|
HWY_DASSERT(first < kGroupSize);
|
||||||
HWY_DASSERT(last < kGroupSize);
|
HWY_DASSERT(last < kGroupSize);
|
||||||
const size_t len = last - first + 1;
|
const int len = static_cast<int>(last) - static_cast<int>(first) + 1;
|
||||||
const hn::Vec<DF> vlen =
|
const hn::Vec<DF> vlen = hn::Iota(df, static_cast<float>(len));
|
||||||
hn::Iota(df, static_cast<float>(static_cast<int>(len)));
|
|
||||||
|
|
||||||
const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
|
const hn::Vec<DF> u_lo = hn::Set(df, cumsum_[first]);
|
||||||
const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
|
const hn::Vec<DF> u_lo2 = hn::Set(df, cumsum2_[first]);
|
||||||
|
|
@@ -204,7 +203,7 @@ class NuqClustering {
|
||||||
for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
|
for (size_t num_clusters = 1; num_clusters < kClusters; ++num_clusters) {
|
||||||
// For each batch starting at `last`, one per lane:
|
// For each batch starting at `last`, one per lane:
|
||||||
for (size_t last = 0; last < kGroupSize; last += N) {
|
for (size_t last = 0; last < kGroupSize; last += N) {
|
||||||
VF min = cc(df, 0, last);
|
VF min = hn::LoadU(df, &D(0, last));
|
||||||
VI arg = hn::Zero(di);
|
VI arg = hn::Zero(di);
|
||||||
// For each j (start of rightmost cluster):
|
// For each j (start of rightmost cluster):
|
||||||
VI vj = k1;
|
VI vj = k1;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue