diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index a369d50ffe..955e6c12fe 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -1257,6 +1257,40 @@ static std::unordered_map target_bpw_type(
             info.candidate.swap(pruned);
         }
 
+        // Enforce convexity in (bytes, error) curve: drop candidates until the slopes
+        // (error change per byte) between consecutive kept candidates never decrease by
+        // more than epsilon.
+        // NOTE(review): presumably info.candidate is ordered by ascending bytes here - confirm
+        {
+            const auto & c = info.candidate;
+            if (c.size() >= 3) { // fewer than three points: there is no middle point to drop
+                // must be the same vector type as info.candidate for the swap() below
+                std::vector<candidate_types> convex;
+                convex.reserve(c.size());
+
+                // error change per byte from a to b; +infinity when b.bytes is not above a.bytes
+                auto slope = [](const candidate_types & a, const candidate_types & b) -> double {
+                    const double dx = (double)b.bytes - (double)a.bytes;
+                    if (dx <= 0.0) { return infinity; }
+
+                    return ((double)b.error - (double)a.error) / dx;
+                };
+
+                for (const auto & p : c) {
+                    // pop the last kept point while the slope into p (s2) falls below the
+                    // slope into that point (s1) by more than epsilon
+                    while (convex.size() >= 2) {
+                        const double s1 = slope(convex[convex.size() - 2], convex[convex.size() - 1]);
+                        const double s2 = slope(convex[convex.size() - 1], p);
+                        if (s2 + epsilon < s1) { convex.pop_back(); }
+                        else { break; }
+                    }
+                    convex.push_back(p);
+                }
+                info.candidate.swap(convex);
+            }
+        }
+
         // Initialize choice at the smallest bpw candidate
         info.choice = 0;
         info.min_bpw = info.candidate.front().bpw;