Fix trimming logic
This commit is contained in:
parent
84ada44894
commit
044fa783c7
|
|
@ -849,8 +849,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
};
|
};
|
||||||
|
|
||||||
auto delete_bpw_state = [&] {
|
auto delete_bpw_state = [&] {
|
||||||
|
std::ifstream ifs(checkpoint_file);
|
||||||
|
if (ifs.good()) {
|
||||||
LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str());
|
LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str());
|
||||||
std::remove(checkpoint_file.c_str());
|
std::remove(checkpoint_file.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
auto check_signal_handler = [&](const std::vector<tensor_info> & all_vec) {
|
auto check_signal_handler = [&](const std::vector<tensor_info> & all_vec) {
|
||||||
|
|
@ -988,14 +992,16 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute error per slice with trimmed aggregation
|
// Compute error per slice with trimmed aggregation
|
||||||
auto trimmed_sum = [](std::vector<double> & v) -> double {
|
auto trimmed_mean = [](std::vector<double> & v) -> double {
|
||||||
const int64_t n = (int64_t)v.size();
|
const int64_t n = (int64_t)v.size();
|
||||||
if (n == 0) { return 0.0; }
|
if (n == 0) { return 0.0; }
|
||||||
if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); } // use all samples for small datasets
|
double sum = std::accumulate(v.begin(), v.end(), 0.0);
|
||||||
|
if (n < 50) { return sum / (double)n; } // too few elements to trim
|
||||||
int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 2.5% from each tail of the distribution
|
int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 5% (2.5% each side)
|
||||||
std::sort(v.begin(), v.end());
|
std::sort(v.begin(), v.end());
|
||||||
return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
|
const auto num = (double)(n - 2 * k);
|
||||||
|
sum = std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
|
||||||
|
return sum / std::max(1.0, num);
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t off = 0;
|
size_t off = 0;
|
||||||
|
|
@ -1044,9 +1050,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
off += (size_t)n_per_row;
|
off += (size_t)n_per_row;
|
||||||
}
|
}
|
||||||
|
|
||||||
const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
|
const double slice_mse = trimmed_mean(row_mse_norm) * (double)nrows;
|
||||||
const double slice_mse = trimmed_sum(row_mse_norm) * scale_rows;
|
const double slice_proj = a ? trimmed_mean(row_proj_norm) * (double)nrows : 0.0;
|
||||||
const double slice_proj = a ? trimmed_sum(row_proj_norm) * scale_rows : 0.0;
|
|
||||||
|
|
||||||
total_mse += slice_mse;
|
total_mse += slice_mse;
|
||||||
total_proj += slice_proj;
|
total_proj += slice_proj;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue