diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 4c0ec3063a..08e1c97185 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -885,9 +885,8 @@ static std::unordered_map target_bpw_type( if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); } int64_t k = (int64_t) std::floor(0.02 * (double)n); // trim 2% on each side - k = std::clamp(k, 0, n / 32); // but no more than ~3% - std::nth_element(v.begin(), v.begin() + k, v.end()); - std::nth_element(v.begin() + k, v.begin() + (n - k), v.end()); + k = std::clamp(k, 0, std::min(n / 32, n / 2 - 1)); // but no more than ~3% or n/2 if small + std::sort(v.begin(), v.end()); return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0); };