From 9a1656eb975fa9f1024a8de029e22a762e49719b Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Sun, 21 Sep 2025 16:21:35 +0100 Subject: [PATCH] Refactor pareto optimise and convexify --- src/llama-quant.cpp | 86 ++++++++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 44 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index b3e4b3cbf7..751a26c63a 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -1179,55 +1179,53 @@ static std::unordered_map target_bpw_type( } // Keep only the pareto‑optimal candidates and enforce convexity in (bytes, error) curve - { - auto & candidates = info.candidate; - if (!candidates.empty()) { - std::sort(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) { - if (a.bytes != b.bytes) { return a.bytes < b.bytes; } + auto pareto_convex = [](std::vector & candidates) { + if (candidates.empty()) return; - return a.error < b.error; - }); + std::sort(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) { + if (a.bytes != b.bytes) { return a.bytes < b.bytes; } + return a.error < b.error; + }); - std::vector pareto; - pareto.reserve(candidates.size()); - double best_err = infinity; - size_t last_bytes = std::numeric_limits::max(); - for (const auto & c : candidates) { - if (c.bytes != last_bytes) { - last_bytes = c.bytes; - if (c.error < best_err) { - best_err = c.error; - pareto.push_back(c); - } + // Pareto by bytes -> error + std::vector pareto; + pareto.reserve(candidates.size()); + double best_err = std::numeric_limits::infinity(); + size_t last_b = std::numeric_limits::max(); + for (const auto & c : candidates) { + if (c.bytes != last_b) { + last_b = c.bytes; + if (c.error < best_err) { + best_err = c.error; + pareto.push_back(c); } } - - candidates.swap(pareto); - - if (candidates.size() >= 3) { - std::vector hull; - hull.reserve(candidates.size()); - auto slope = [](const candidate_types & a, const candidate_types & b) { - const double dx = b.bytes - a.bytes; - - return dx <= 0.0 ? infinity : (b.error - a.error) / dx; - }; - - for (const auto & p : candidates) { - while (hull.size() >= 2) { - double s1 = slope(hull[hull.size() - 2], hull[hull.size() - 1]); - double s2 = slope(hull[hull.size() - 1], p); - if (s2 + epsilon < s1) { hull.pop_back(); } - else { break; } - } - - hull.push_back(p); - } - - candidates.swap(hull); - } } - } + + candidates.swap(pareto); + if (candidates.size() < 3) { return; } // need at least 3 points to do convex hull + + // Convex hull (lower envelope) + auto slope = [](const candidate_types & a, const candidate_types & b) { + const double dx = b.bytes - a.bytes; + return dx <= 0.0 ? infinity : (b.error - a.error) / dx; + }; + + std::vector hull; hull.reserve(candidates.size()); + for (const auto & p : candidates) { + while (hull.size() >= 2) { + const double s1 = slope(hull[hull.size() - 2], hull[hull.size() - 1]); + const double s2 = slope(hull[hull.size() - 1], p); + if (s2 + epsilon < s1) hull.pop_back(); + else { break; } + } + + hull.push_back(p); + } + candidates.swap(hull); + }; + + pareto_convex(info.candidate); // Initialize choice at the smallest bpw candidate info.choice = 0;