Fix trimming logic

This commit is contained in:
Ed Addario 2025-10-06 21:40:37 +01:00
parent 84ada44894
commit 044fa783c7
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 16 additions and 11 deletions

View File

@ -849,8 +849,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
}; };
auto delete_bpw_state = [&] { auto delete_bpw_state = [&] {
std::ifstream ifs(checkpoint_file);
if (ifs.good()) {
LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str()); LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str());
std::remove(checkpoint_file.c_str()); std::remove(checkpoint_file.c_str());
}
}; };
auto check_signal_handler = [&](const std::vector<tensor_info> & all_vec) { auto check_signal_handler = [&](const std::vector<tensor_info> & all_vec) {
@ -988,14 +992,16 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
} }
// Compute error per slice with trimmed aggregation // Compute error per slice with trimmed aggregation
auto trimmed_sum = [](std::vector<double> & v) -> double { auto trimmed_mean = [](std::vector<double> & v) -> double {
const int64_t n = (int64_t)v.size(); const int64_t n = (int64_t)v.size();
if (n == 0) { return 0.0; } if (n == 0) { return 0.0; }
if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); } // use all samples for small datasets double sum = std::accumulate(v.begin(), v.end(), 0.0);
if (n < 50) { return sum / (double)n; } // too few elements to trim
int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 2.5% from each tail of the distribution int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 5% (2.5% each side)
std::sort(v.begin(), v.end()); std::sort(v.begin(), v.end());
return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0); const auto num = (double)(n - 2 * k);
sum = std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
return sum / std::max(1.0, num);
}; };
size_t off = 0; size_t off = 0;
@ -1044,9 +1050,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
off += (size_t)n_per_row; off += (size_t)n_per_row;
} }
const double scale_rows = (double)nrows / std::max(1.0, (double)rs); const double slice_mse = trimmed_mean(row_mse_norm) * (double)nrows;
const double slice_mse = trimmed_sum(row_mse_norm) * scale_rows; const double slice_proj = a ? trimmed_mean(row_proj_norm) * (double)nrows : 0.0;
const double slice_proj = a ? trimmed_sum(row_proj_norm) * scale_rows : 0.0;
total_mse += slice_mse; total_mse += slice_mse;
total_proj += slice_proj; total_proj += slice_proj;