Minor general refactoring

This commit is contained in:
Ed Addario 2025-09-21 16:45:09 +01:00
parent 0d5f18303e
commit 814f6b66be
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 5 additions and 10 deletions

View File

@ -860,7 +860,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const int64_t n = (int64_t)v.size(); const int64_t n = (int64_t)v.size();
if (n == 0) { return 0.0; } if (n == 0) { return 0.0; }
if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); } if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); }
int64_t k = (int64_t) std::floor(0.02 * (double) n); // trim 2% on each side
int64_t k = (int64_t) std::floor(0.02 * (double)n); // trim 2% on each side
k = std::clamp<int64_t>(k, 0, n / 32); // but no more than ~3% k = std::clamp<int64_t>(k, 0, n / 32); // but no more than ~3%
std::nth_element(v.begin(), v.begin() + k, v.end()); std::nth_element(v.begin(), v.begin() + k, v.end());
std::nth_element(v.begin() + k, v.begin() + (n - k), v.end()); std::nth_element(v.begin() + k, v.begin() + (n - k), v.end());
@ -1190,7 +1191,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
// Pareto by bytes -> error // Pareto by bytes -> error
std::vector<candidate_types> pareto; std::vector<candidate_types> pareto;
pareto.reserve(candidates.size()); pareto.reserve(candidates.size());
double best_err = std::numeric_limits<double>::infinity(); double best_err = infinity;
size_t last_b = std::numeric_limits<size_t>::max(); size_t last_b = std::numeric_limits<size_t>::max();
for (const auto & c : candidates) { for (const auto & c : candidates) {
if (c.bytes != last_b) { if (c.bytes != last_b) {
@ -1273,12 +1274,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
if (budget_bytes <= min_bytes) { if (budget_bytes <= min_bytes) {
for (auto & ti : all) { ti.choice = 0; } for (auto & ti : all) { ti.choice = 0; }
return emit_overrides(); return emit_overrides();
} }
if (budget_bytes >= max_bytes) { if (budget_bytes >= max_bytes) {
for (auto & ti : all) { ti.choice = (int) ti.candidate.size() - 1; } for (auto & ti : all) { ti.choice = (int) ti.candidate.size() - 1; }
return emit_overrides(); return emit_overrides();
} }
@ -1327,14 +1326,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
int expand = 0; int expand = 0;
while (true) { while (true) {
lagrange_penalty(mu_hi, choice_hi, bytes_hi, err_hi); lagrange_penalty(mu_hi, choice_hi, bytes_hi, err_hi);
if (bytes_hi <= budget_bytes) { if (bytes_hi <= budget_bytes) { break; }
break;
}
mu_hi *= 2.0; mu_hi *= 2.0;
if (++expand > 60) { if (++expand > 60) { break; } // safety cap
break;
}
} }
} }