From 044fa783c7e5e87bddf667fbe7396628e827b455 Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Mon, 6 Oct 2025 21:40:37 +0100 Subject: [PATCH] Fix trimming logic --- src/llama-quant.cpp | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index eb5c9124b5..aeb1542607 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -849,8 +849,12 @@ static std::unordered_map target_bpw_type( }; auto delete_bpw_state = [&] { - LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str()); - std::remove(checkpoint_file.c_str()); + std::ifstream ifs(checkpoint_file); + if (ifs.good()) { + LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str()); + std::remove(checkpoint_file.c_str()); + } + }; auto check_signal_handler = [&](const std::vector & all_vec) { @@ -988,14 +992,16 @@ static std::unordered_map target_bpw_type( } // Compute error per slice with trimmed aggregation - auto trimmed_sum = [](std::vector & v) -> double { + auto trimmed_mean = [](std::vector & v) -> double { const int64_t n = (int64_t)v.size(); if (n == 0) { return 0.0; } - if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); } // use all samples for small datasets - - int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 2.5% from each tail of the distribution + double sum = std::accumulate(v.begin(), v.end(), 0.0); + if (n < 50) { return sum / (double)n; } // too few elements to trim + int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 5% (2.5% each side) std::sort(v.begin(), v.end()); - return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0); + const auto num = (double)(n - 2 * k); + sum = std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0); + return sum / std::max(1.0, num); }; size_t off = 0; @@ -1028,7 +1034,7 @@ static std::unordered_map target_bpw_type( } const double denom_x = row_sq_norm[ridx]; - const double m_norm = w_mse / (denom_x + epsilon); + const double m_norm = w_mse / (denom_x + epsilon); row_mse_norm.push_back(std::isfinite(m_norm) ? m_norm : infinity); if (a) { @@ -1044,9 +1050,8 @@ static std::unordered_map target_bpw_type( off += (size_t)n_per_row; } - const double scale_rows = (double)nrows / std::max(1.0, (double)rs); - const double slice_mse = trimmed_sum(row_mse_norm) * scale_rows; - const double slice_proj = a ? trimmed_sum(row_proj_norm) * scale_rows : 0.0; + const double slice_mse = trimmed_mean(row_mse_norm) * (double)nrows; + const double slice_proj = a ? trimmed_mean(row_proj_norm) * (double)nrows : 0.0; total_mse += slice_mse; total_proj += slice_proj;