From a7ee915e19d9acd7a1187ba7d8d772d3a52a8f0d Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Sun, 21 Sep 2025 16:20:06 +0100 Subject: [PATCH] Refactor trimmed_sum() --- src/llama-quant.cpp | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 63779ded48..67de29df87 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -920,26 +920,15 @@ static std::unordered_map target_bpw_type( // Trimmed sum to avoid outlier rows dominating the results auto trimmed_sum = [&](std::vector & v) -> double { - if (v.empty()) { return 0.0; } - const int64_t n = (int64_t)v.size(); - if (n < 50) { - double s = 0.0; - for (const double z : v) { s += z; } - - return s; - } + if (n == 0) { return 0.0; } + if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); } int64_t k = (int64_t)std::floor(0.02 * (double)n); // trim 2% each side - k = std::max(0, std::min(k, n / 32)); // cap at ~3.125% + k = std::clamp(k, 0, n / 32); // cap at ~3.125% std::nth_element(v.begin(), v.begin() + k, v.end()); std::nth_element(v.begin() + k, v.begin() + (n - k), v.end()); - double s = 0.0; - for (int64_t i = k; i < n - k; ++i) { - s += v[i]; - } - - return s; + return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0); }; const double scale_rows = (double)nrows / std::max(1.0, (double)rs);