Calculate bpw over all tensors

This commit is contained in:
Ed Addario 2025-09-27 17:28:39 +01:00
parent 3d75b14c0f
commit e49e241d37
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 32 additions and 13 deletions

View File

@ -1219,6 +1219,18 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
if (all.empty()) { return {}; }
// Compute total elements across all tensors and bytes for non-quantizable tensors
size_t nq_elements = 0;
size_t nq_bytes = 0;
for (const auto & it : ml.weights_map) {
const ggml_tensor * tensor = it.second.tensor;
const std::string name = it.first;
nq_elements += (size_t)ggml_nelements(tensor);
if (!is_quantizable(name, model.arch, params)) {
nq_bytes += ggml_nbytes(tensor);
}
}
auto total_bytes = [&]() -> size_t {
size_t tb = 0;
for (const auto & ti : all) {
@ -1228,19 +1240,20 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
return tb;
};
size_t total_elems = 0;
size_t q_elements = 0;
size_t min_bytes = 0;
size_t max_bytes = 0;
for (const auto & ti : all) {
total_elems += (size_t)ti.n_elements;
q_elements += (size_t)ti.n_elements;
min_bytes += ti.candidate.front().bytes; // smallest candidate per tensor
max_bytes += ti.candidate.back().bytes; // largest candidate per tensor
}
if (total_elems == 0) { return {}; }
if (q_elements == 0) { return {}; }
const double target_bpw = params->target_bpw;
size_t budget_bytes = std::llround(target_bpw * (double)total_elems / 8.0); // convert bpw to bytes
size_t target_total_bytes = std::llround(target_bpw * (double)nq_elements / 8.0);
size_t budget_bytes = target_total_bytes >= nq_bytes ? target_total_bytes - nq_bytes : min_bytes;
auto emit_overrides = [&]() -> std::unordered_map<std::string, ggml_type> {
std::unordered_map<std::string, ggml_type> overrides;
@ -1374,29 +1387,35 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
int best_i = -1;
int best_j = -1;
double best_ratio = -1.0;
size_t best_delta = 0;
double best_gain = -1.0;
for (int i = 0; i < (int)all.size(); ++i) {
const auto &ti = all[i];
int j = ti.choice + 1;
// skip same-bytes entries
while (j < (int)ti.candidate.size() && ti.candidate[j].bytes == ti.candidate[ti.choice].bytes) { ++j; }
if (j >= (int)ti.candidate.size()) { continue; }
size_t delta = ti.candidate[j].bytes - ti.candidate[ti.choice].bytes;
if (cur_bytes + delta > budget_bytes) { continue; }
size_t delta_bytes = ti.candidate[j].bytes - ti.candidate[ti.choice].bytes;
if (cur_bytes + delta_bytes > budget_bytes) { continue; }
double err_gain = std::max(0.0, ti.candidate[ti.choice].error - ti.candidate[j].error);
double ratio = err_gain / (double)(delta * 8); // error reduction per bit
if (ratio > best_ratio + epsilon || (std::abs(ratio - best_ratio) <= epsilon && delta < best_delta)) {
if (err_gain < epsilon) { continue; } // no real improvement
double ratio = err_gain / (double)delta_bytes; // error reduction per byte
// For tie-breaking, prioritize the largest absolute error improvement.
if (ratio > best_ratio + epsilon || (std::abs(ratio - best_ratio) <= epsilon && err_gain > best_gain)) {
best_ratio = ratio;
best_delta = delta;
best_gain = err_gain;
best_i = i;
best_j = j;
}
}
if (best_i < 0) { break; }
if (best_i < 0) { break; } // no more upgrades within budget found
size_t upgrade_cost = all[best_i].candidate[best_j].bytes - all[best_i].candidate[all[best_i].choice].bytes;
all[best_i].choice = best_j;
cur_bytes += best_delta;
cur_bytes += upgrade_cost;
}
}