Refactor trimmed_sum()
This commit is contained in:
parent
b09662f86a
commit
a7ee915e19
|
|
@ -920,26 +920,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
|
||||
// Trimmed sum to avoid outlier rows dominating the results
|
||||
auto trimmed_sum = [&](std::vector<double> & v) -> double {
|
||||
if (v.empty()) { return 0.0; }
|
||||
|
||||
const int64_t n = (int64_t)v.size();
|
||||
if (n < 50) {
|
||||
double s = 0.0;
|
||||
for (const double z : v) { s += z; }
|
||||
|
||||
return s;
|
||||
}
|
||||
if (n == 0) { return 0.0; }
|
||||
if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); }
|
||||
|
||||
int64_t k = (int64_t)std::floor(0.02 * (double)n); // trim 2% each side
|
||||
k = std::max<int64_t>(0, std::min<int64_t>(k, n / 32)); // cap at ~3.125%
|
||||
k = std::clamp<int64_t>(k, 0, n / 32); // cap at ~3.125%
|
||||
std::nth_element(v.begin(), v.begin() + k, v.end());
|
||||
std::nth_element(v.begin() + k, v.begin() + (n - k), v.end());
|
||||
double s = 0.0;
|
||||
for (int64_t i = k; i < n - k; ++i) {
|
||||
s += v[i];
|
||||
}
|
||||
|
||||
return s;
|
||||
return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
|
||||
};
|
||||
|
||||
const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
|
||||
|
|
|
|||
Loading…
Reference in New Issue