From 66aff8fa1ee1d34c7faaa0ff658a730a9554ef36 Mon Sep 17 00:00:00 2001
From: Ed Addario
Date: Thu, 28 Aug 2025 16:06:42 +0100
Subject: [PATCH] Add precise_lambda()

---
 src/llama-quant.cpp | 102 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 102 insertions(+)

diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index a9621eab8e..662760fbe9 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -921,6 +921,108 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         // Clamp to a reasonable range
         return (float)std::clamp(scale, 0.5, 2.0);
     };
+
+    // Returns an adaptive lambda for this tensor using a small probe set.
+    // bias_lambda adjusts the trade-off between systematic bias (introduced by block-wise scaling) and MSE:
+    // a larger value favours quantisation types that produce a smaller bias, even if the MSE is slightly larger
+    auto precise_lambda = [&](const ggml_tensor * t,
+                              const std::vector<float> & f32_sample,
+                              const std::vector<int64_t> & sample_rows_per_slice,
+                              const float * values,
+                              const float * activations,
+                              const std::vector<ggml_type> & compatible_candidates) -> float
+    {
+        // No activations => no projection term
+        if (!activations) { return 0.0f; }
+
+        // Pick a tiny probe set, trying to spread around mid-range types
+        std::vector<ggml_type> probes;
+        probes.reserve(3);
+        auto push_if = [&](const ggml_type tiny) {
+            if (std::find(compatible_candidates.begin(), compatible_candidates.end(), tiny) != compatible_candidates.end()) {
+                probes.push_back(tiny);
+            }
+        };
+
+        // Prefer family-consistent probes; fall back to whatever exists
+        push_if(GGML_TYPE_Q4_K);
+        push_if(GGML_TYPE_Q3_K);
+        push_if(GGML_TYPE_Q5_K);
+        if (probes.empty() && !compatible_candidates.empty()) {
+            probes.push_back(compatible_candidates[compatible_candidates.size() / 2]);
+        }
+        if (probes.size() == 1 && compatible_candidates.size() >= 2) {
+            probes.push_back(compatible_candidates.front());
+        }
+        if (probes.empty()) { return 0.0f; }
+
+        // Scratch buffers (reused across probes)
+        const int64_t n_per_row = t->ne[0];
+        const size_t total_sampled_rows = f32_sample.size() / n_per_row;
+        size_t max_row_sz = 0;
+        for (auto pt : probes) {
+            max_row_sz = std::max(max_row_sz, ggml_row_size(pt, n_per_row));
+        }
+        std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
+        std::vector<float> dequantized_buffer(f32_sample.size());
+
+        std::vector<double> ratios;
+        ratios.reserve(probes.size());
+
+        for (const auto pt : probes) {
+            // Error at lambda=0 => pure weighted MSE part
+            const double err0 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 0.0f);
+            // Error at lambda=1 => weighted MSE + projection penalty
+            const double err1 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 1.0f);
+
+            const double p = std::max(0.0, err1 - err0); // projection term contribution
+            const double m = std::max(0.0, err0);        // MSE term contribution
+            if (p > epsilon && std::isfinite(m) && std::isfinite(p)) {
+                ratios.push_back(m / p);
+            }
+        }
+
+        if (ratios.empty()) { return 0.0f; }
+
+        // Median ratio across probes
+        std::nth_element(ratios.begin(), ratios.begin() + ratios.size() / 2, ratios.end());
+        double lambda = ratios[ratios.size() / 2];
+
+        // Activations directional scale
+        const float scale = directional_scale(values, activations, n_per_row);
+        lambda *= scale;
+
+        // Clamp to a safe range
+        lambda = std::clamp(lambda, 0.0, 8.0);
+        return (float)lambda;
+    };
+
+    // Cheap fallback: estimates lambda from the activation statistics alone, without probe quantisations
+    auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row) {
+        if (!activations) { return 0.0f; }
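+
+        // With s = sum_j aw_j^2 and s2 = sum_j aw_j^4 (accumulated below), the quantity
+        // s^2 / (n * s2) is the normalised participation ratio of the activation-weighted
+        // energy: ~1 when the energy is spread evenly across the row, ~1/n when it is
+        // concentrated in a few channels. base = 1 - s^2 / (n * s2) therefore grows with
+        // concentration, driving lambda up when a few channels dominate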
+        double s  = 0.0;
+        double s2 = 0.0;
+        for (int64_t j = 0; j < n_per_row; ++j) {
+            const double w   = values ? std::max(0.0f, values[j]) : 1.0;
+            const double aw  = std::sqrt(w) * activations[j];
+            const double aw2 = aw * aw;
+            s  += aw2;
+            s2 += aw2 * aw2;
+        }
+        if (s2 <= 0.0) { return 0.0f; }
+        const auto d = (double)n_per_row;
+        //const double p = s * s / (d * s2 + epsilon);
+        //const double lambda = 8.0 * std::clamp(1.0 - p, 0.0, 1.0);
+        // Map p in (0,1] to lambda in [0,8] decreasing
+        double base = 1.0 - s * s / (d * s2 + epsilon);
+        base = std::clamp(base, 0.0, 1.0);
+
+        // Activations directional scale
+        const double scale = directional_scale(values, activations, n_per_row);
+        // Clamp to a safe range
+        const double lambda = std::clamp(base * scale, 0.0, 1.0) * 8.0;
+
+        return (float)lambda;
+    };
     std::vector all;
     all.reserve(tensors.size());
     for (const auto * tw : tensors) {