From 5b0d3f6d5ad46596e0f30c967c00e2dc2b93d8da Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Sat, 11 Oct 2025 10:04:48 +0100 Subject: [PATCH] Automatically determine if bias error is significant --- src/llama-quant.cpp | 52 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 4ad5124d1a..07a88f0fd6 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -637,6 +637,8 @@ static std::unordered_map target_bpw_type( float bpw; size_t bytes; double error; + double mse = 0.0; + double proj = 0.0; }; struct tensor_info { @@ -1340,9 +1342,11 @@ static std::unordered_map target_bpw_type( const ggml_type tensor_types = compatible_candidates[i]; const auto bpw = (float)tensor_bpw(tensor, tensor_types); const size_t bytes = tensor_bytes(tensor, tensor_types); + double mse = 0.0; + double proj = 0.0; const auto err = estimate_error(tensor, tensor_types, f32_sample, rows_sample, values, activations, - tl_quantized_buffer, tl_dequantized_buffer, tensor_lambda, slice_lambda); - eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err }; + tl_quantized_buffer, tl_dequantized_buffer, tensor_lambda, slice_lambda, &mse, &proj); + eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err, mse, proj }; } }); } @@ -1354,8 +1358,48 @@ static std::unordered_map target_bpw_type( check_signal_handler(all); } - for (auto &c : eval_candidates) { - if (c.bytes > 0) { info.candidate.push_back(c); } + // Check if biasing is needed + bool bias_needed = false; + if (!lambdas.empty()) { + int min_mse = -1; + int min_bias = -1; + { + double best_mse = std::numeric_limits::infinity(); + double best_err = std::numeric_limits::infinity(); + for (int i = 0; i < (int)eval_candidates.size(); ++i) { + const auto & c = eval_candidates[i]; + if (c.bytes == 0) { continue; } + if (c.mse < best_mse) { + best_mse = c.mse; + min_mse = i; + } + if (c.error < best_err) { + best_err = c.error; + min_bias = i; + } + } + } + + if (min_mse != min_bias) { + bias_needed = true; + } else { + double max_rel_bias = 0.0; + for (const auto & c : eval_candidates) { + if (c.bytes == 0) { continue; } + const double mse = std::max(c.mse, epsilon); + const double bias_term = std::max(0.0, c.error - c.mse); + const double rel = bias_term / mse; + max_rel_bias = std::max(rel, max_rel_bias); + } + + bias_needed = max_rel_bias >= 0.5; // >= 50% of MSE? + } + } + + for (auto & c : eval_candidates) { + if (c.bytes == 0) { continue; } + const double final_err = bias_needed ? c.error : c.mse; + info.candidate.push_back(candidate_types{ c.type, c.bpw, c.bytes, final_err, c.mse, c.proj }); } if (info.candidate.empty()) {