Automatically determine if bias error is significant

This commit is contained in:
Ed Addario 2025-10-11 10:04:48 +01:00
parent c93131cef6
commit 5b0d3f6d5a
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 48 additions and 4 deletions

View File

@ -637,6 +637,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
float bpw;
size_t bytes;
double error;
double mse = 0.0;
double proj = 0.0;
};
struct tensor_info {
@ -1340,9 +1342,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const ggml_type tensor_types = compatible_candidates[i];
const auto bpw = (float)tensor_bpw(tensor, tensor_types);
const size_t bytes = tensor_bytes(tensor, tensor_types);
double mse = 0.0;
double proj = 0.0;
const auto err = estimate_error(tensor, tensor_types, f32_sample, rows_sample, values, activations,
tl_quantized_buffer, tl_dequantized_buffer, tensor_lambda, slice_lambda);
eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err };
tl_quantized_buffer, tl_dequantized_buffer, tensor_lambda, slice_lambda, &mse, &proj);
eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err, mse, proj };
}
});
}
@ -1354,8 +1358,48 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
check_signal_handler(all);
}
for (auto &c : eval_candidates) {
if (c.bytes > 0) { info.candidate.push_back(c); }
// Check if biasing is needed
bool bias_needed = false;
if (!lambdas.empty()) {
int min_mse = -1;
int min_bias = -1;
{
double best_mse = std::numeric_limits<double>::infinity();
double best_err = std::numeric_limits<double>::infinity();
for (int i = 0; i < (int)eval_candidates.size(); ++i) {
const auto & c = eval_candidates[i];
if (c.bytes == 0) { continue; }
if (c.mse < best_mse) {
best_mse = c.mse;
min_mse = i;
}
if (c.error < best_err) {
best_err = c.error;
min_bias = i;
}
}
}
if (min_mse != min_bias) {
bias_needed = true;
} else {
double max_rel_bias = 0.0;
for (const auto & c : eval_candidates) {
if (c.bytes == 0) { continue; }
const double mse = std::max(c.mse, epsilon);
const double bias_term = std::max(0.0, c.error - c.mse);
const double rel = bias_term / mse;
max_rel_bias = std::max(rel, max_rel_bias);
}
bias_needed = max_rel_bias >= 0.5; // >= 50% of MSE?
}
}
for (auto & c : eval_candidates) {
if (c.bytes == 0) { continue; }
const double final_err = bias_needed ? c.error : c.mse;
info.candidate.push_back(candidate_types{ c.type, c.bpw, c.bytes, final_err, c.mse, c.proj });
}
if (info.candidate.empty()) {