Improve pareto efficient candidate selection

This commit is contained in:
Ed Addario 2025-08-22 09:14:14 +01:00
parent 47cdbe2155
commit 01c927fb94
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 14 additions and 35 deletions

View File

@ -1106,56 +1106,35 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
info.candidate.push_back(candidate_types{ t->type, bpw, ggml_nbytes(t), 0.0 });
}
// Keep only the Paretooptimal candidates: if A has >= bytes and >= error than B, drop A.
// Keep only the paretooptimal candidates: if A has >= bytes and >= error than B, drop A.
{
std::vector<candidate_types> pruned;
pruned.reserve(info.candidate.size());
// Sort by bytes asc, error asc
std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types &a, const candidate_types &b) {
// Sort by bytes ascending, error ascending
std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types & a, const candidate_types & b) {
if (a.bytes != b.bytes) { return a.bytes < b.bytes; }
return a.error < b.error;
});
double best_err = std::numeric_limits<double>::infinity();
size_t last_bytes = std::numeric_limits<size_t>::max();
for (const auto &c : info.candidate) {
if (c.error < best_err || c.bytes > last_bytes) {
pruned.push_back(c);
best_err = std::min(best_err, (double)c.error);
for (const auto & c : info.candidate) {
// Only keep the best error seen so far at strictly larger byte sizes
if (c.bytes != last_bytes) {
// first time we see this byte size
last_bytes = c.bytes;
if (c.error < best_err) {
pruned.push_back(c);
best_err = c.error;
}
} else {
// same bytes: we already sorted by error; skip
}
}
info.candidate.swap(pruned);
}
// Collapse candidates with identical storage size (bytes)
{
std::vector<candidate_types> unique;
unique.reserve(info.candidate.size());
// Sort by bpw asc, error asc, bytes asc
std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types & a, const candidate_types & b) {
if (a.bpw != b.bpw) { return a.bpw < b.bpw; }
if (a.error != b.error) { return a.error < b.error; }
return a.bytes < b.bytes;
});
for (size_t i = 0; i < info.candidate.size();) {
size_t j = i + 1;
candidate_types best = info.candidate[i];
// group same-byte entries, keep the one with the lowest error
while (j < info.candidate.size() && info.candidate[j].bytes == info.candidate[i].bytes) {
if (info.candidate[j].error < best.error) {
best = info.candidate[j];
}
++j;
}
unique.push_back(best);
i = j;
}
info.candidate.swap(unique);
}
// Initialize choice at the smallest bpw candidate
info.choice = 0;
info.min_bpw = info.candidate.front().bpw;