Add precise_lambda()

Ed Addario 2025-08-28 16:06:42 +01:00
parent 8df1d00ae4
commit 66aff8fa1e
1 changed file with 102 additions and 0 deletions


@@ -921,6 +921,108 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
        // Clamp to a reasonable range
        return (float)std::clamp(scale, 0.5, 2.0);
    };
    // Returns an adaptive lambda for this tensor using a small probe set.
    // bias_lambda adjusts the trade-off between systematic bias (introduced by blockwise scaling) and MSE:
    // a larger value favours quantisation types that produce smaller bias even if the MSE is slightly larger.
    auto precise_lambda = [&](const ggml_tensor * t,
                              const std::vector<float> & f32_sample,
                              const std::vector<int64_t> & sample_rows_per_slice,
                              const float * values,
                              const float * activations,
                              const std::vector<ggml_type> & compatible_candidates) -> float
    {
        // No activations => no projection term
        if (!activations) { return 0.0f; }
        // Pick a tiny probe set, spread around mid-range types
        std::vector<ggml_type> probes;
        probes.reserve(3);
        auto push_if = [&](const ggml_type tiny) {
            if (std::find(compatible_candidates.begin(), compatible_candidates.end(), tiny) != compatible_candidates.end()) {
                probes.push_back(tiny);
            }
        };
        // Prefer family-consistent probes; fall back to whatever exists
        push_if(GGML_TYPE_Q4_K);
        push_if(GGML_TYPE_Q3_K);
        push_if(GGML_TYPE_Q5_K);
        if (probes.empty() && !compatible_candidates.empty()) {
            probes.push_back(compatible_candidates[compatible_candidates.size() / 2]);
        }
        if (probes.size() == 1 && compatible_candidates.size() >= 2) {
            probes.push_back(compatible_candidates.front());
        }
        if (probes.empty()) { return 0.0f; }
        // Scratch buffers (reused across probes)
        const int64_t n_per_row = t->ne[0];
        const size_t total_sampled_rows = f32_sample.size() / n_per_row;
        size_t max_row_sz = 0;
        for (auto pt : probes) {
            max_row_sz = std::max(max_row_sz, ggml_row_size(pt, n_per_row));
        }
        std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
        std::vector<float> dequantized_buffer(f32_sample.size());
        std::vector<double> ratios;
        ratios.reserve(probes.size());
        for (const auto pt : probes) {
            // Error at lambda = 0 => pure weighted MSE part
            const double err0 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 0.0f);
            // Error at lambda = 1 => weighted MSE + projection penalty
            const double err1 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 1.0f);
            const double p = std::max(0.0, err1 - err0); // projection term contribution
            const double m = std::max(0.0, err0);        // MSE term contribution
            if (p > epsilon && std::isfinite(m) && std::isfinite(p)) {
                ratios.push_back(m / p);
            }
        }
        if (ratios.empty()) { return 0.0f; }
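        // Use the median ratio across probes: nth_element partially sorts so the
        // middle element is the median, which is robust to a single outlier probe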
        std::nth_element(ratios.begin(), ratios.begin() + ratios.size() / 2, ratios.end());
        double lambda = ratios[ratios.size() / 2];
        // Apply the activations' directional scale
        const float scale = directional_scale(values, activations, n_per_row);
        lambda *= scale;
        // Clamp to a safe range
        lambda = std::clamp(lambda, 0.0, 8.0);
        return (float)lambda;
    };
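    // In effect, precise_lambda calibrates the two error terms against each other:
    // assuming estimate_error() is affine in lambda, i.e. err(lambda) = MSE + lambda * P,
    // the difference err1 - err0 isolates the projection term P, and choosing
    // lambda = m / p makes lambda * P comparable in magnitude to the MSE.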
    // Cheap alternative: derive lambda from activation statistics alone, with no quantisation probes
    auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row) {
        if (!activations) { return 0.0f; }
        double s  = 0.0;
        double s2 = 0.0;
        for (int64_t j = 0; j < n_per_row; ++j) {
            const double w   = values ? std::max(0.0f, values[j]) : 1.0;
            const double aw  = std::sqrt(w) * activations[j];
            const double aw2 = aw * aw;
            s  += aw2;
            s2 += aw2 * aw2;
        }
        if (s2 <= 0.0) { return 0.0f; }
        const auto d = (double)n_per_row;
        // p = s^2 / (d * s2) in (0,1] measures how evenly the weighted activation
        // energy is spread across the row; map it to a decreasing base in [0,1]
        double base = 1.0 - s * s / (d * s2 + epsilon);
        base = std::clamp(base, 0.0, 1.0);
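        // Sanity check on the two extremes: if every column carries equal energy
        // aw2 = c, then s = d*c and s2 = d*c^2, so s^2 / (d * s2) = 1 and base = 0
        // (lambda = 0, pure MSE). If one column carries all the energy, then
        // s^2 / (d * s2) = 1/d and base ~= 1, pushing lambda towards its maximum.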
        // Apply the activations' directional scale and clamp to a safe range
        const double scale = directional_scale(values, activations, n_per_row);
        const double lambda = std::clamp(base * scale, 0.0, 1.0) * 8.0;
        return (float)lambda;
    };
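    // Hypothetical usage sketch (the caller sits outside this hunk): a per-tensor
    // lambda could be obtained as
    //     const float bias_lambda = use_probes
    //         ? precise_lambda(t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates)
    //         : fast_lambda(values, activations, t->ne[0]);
    // and passed as the final argument of estimate_error() when scoring candidate
    // types, in place of the fixed 0.0f / 1.0f used by the probes above.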
    std::vector<tensor_info> all;
    all.reserve(tensors.size());
    for (const auto * tw : tensors) {