Precompute error denominator in estimate_error()

Ed Addario 2025-08-21 16:25:31 +01:00
parent 887490c5ec
commit 9e11f82e8f
1 changed file with 121 additions and 33 deletions


@@ -598,8 +598,8 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float *
// Returns per-tensor overrides of quantization types to meet target BPW with the lowest ppl
// sample_rows_per_expert: Larger values will result in more accurate error estimates, but will take longer to compute
// bias_lambda: Affects the weight of the bias term in the MSE error function. 0.0 means no bias, 1.0 means equal weight
// for bias and error, 2.0 means twice as much weight for bias
// bias_lambda: Affects the weight of the bias term in the weighted MSE error function. 0.0 means no bias (standard MSE),
// 1.0 means equal weight for bias and error, 2.0 means twice as much weight for bias
static std::unordered_map<std::string, ggml_type> target_bpw_type(
llama_model_loader & ml,
std::vector<no_init<uint8_t>> & buffer,
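For reference, the per-row error these comments describe can be sketched as a standalone function. This is an illustrative reconstruction, not code from this file: the name row_error_sketch and the standalone shape are assumptions, with w = importance weights, a = activations, x = the reference row, and y = the dequantized row.

#include <cstdint>

// Sketch: weighted MSE normalized by weighted signal energy, plus bias_lambda
// times the squared projection of the error onto the activations.
static double row_error_sketch(const float * x, const float * y,
                               const float * w, const float * a,
                               int64_t n, double bias_lambda) {
    const double eps = 1e-12;
    double mse = 0.0, x2 = 0.0, bnum = 0.0, bden = 0.0;
    for (int64_t j = 0; j < n; ++j) {
        const double wj = w ? w[j] : 1.0;           // unweighted falls back to 1
        const double e  = (double) y[j] - (double) x[j];
        mse += wj * e * e;                          // weighted squared error
        x2  += wj * (double) x[j] * (double) x[j];  // weighted signal energy
        if (a) {
            bnum += wj * e * a[j];                  // error projected onto activations
            bden += wj * a[j] * a[j];
        }
    }
    double numer = mse;
    if (a && bias_lambda != 0.0) {
        numer += bias_lambda * bnum * bnum / (bden + eps);
    }
    return numer / (x2 + eps);
}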
@@ -658,7 +658,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
GGML_TYPE_IQ3_S,
GGML_TYPE_IQ4_XS,
GGML_TYPE_IQ4_NL,
// Add higher-precision fallbacks for IQ mixes to improve ppl if bpw budget allows it
// TODO: add higher-precision fallbacks for IQ mixes to improve ppl if bpw budget allows it?
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
GGML_TYPE_Q5_K,
@@ -770,7 +770,68 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
if (qbuf.size() < need_q) { qbuf.resize(need_q); }
if (deq.size() < nels) { deq.resize(nels); }
// Quantize sampled rows slice-by-slice
// Precompute denominators:
// - x2_per_row: sum_j w[j]*x[j]^2 if w present else sum_j x[j]^2
// - bden_per_slice: sum_j w[j]*a[j]^2 if w & a present; sum_j a[j]^2 if only a present; 0 otherwise
std::vector<double> x2_per_row(total_sampled_rows, 0.0);
std::vector<double> bden_per_slice(ne2, 0.0);
const bool has_w = (values_sample != nullptr);
const bool has_a = (activations_sample != nullptr);
// Precompute bden per slice (depends only on w,a)
if (has_a) {
for (int64_t s = 0; s < ne2; ++s) {
const float * wv = has_w ? values_sample + s * n_per_row : nullptr;
const float * act = activations_sample + s * n_per_row;
double bden = 0.0;
if (has_w) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double a = act[j];
bden += (double) wv[j] * a * a;
}
} else {
for (int64_t j = 0; j < n_per_row; ++j) {
const double a = act[j];
bden += a * a;
}
}
bden_per_slice[s] = bden;
}
}
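// bden is per-slice rather than per-row because each slice shares one
// activation vector: every sampled row in slice s projects onto the same
// act = activations_sample + s * n_per_row.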
// Precompute x2 per sampled row
{
size_t off = 0;
size_t row_idx = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = sample_rows_per_slice[s];
if (rs == 0) { continue; }
const float * wv = has_w ? values_sample + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r, ++row_idx) {
const float * x = f32_sample.data() + off;
double x2 = 0.0;
if (has_w) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = wv[j];
const double xx = x[j];
x2 += w * xx * xx;
}
} else {
for (int64_t j = 0; j < n_per_row; ++j) {
const double xx = x[j];
x2 += xx * xx;
}
}
x2_per_row[row_idx] = x2;
off += (size_t)n_per_row;
}
}
}
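// Both denominators depend only on x, w and a, never on the candidate
// quantization type, so the two passes above run once per tensor instead of
// once per candidate type trialled (e.g. once instead of ~20 times).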
// Quantize sampled rows slice-by-slice into qbuf
size_t qoff = 0;
size_t foff = 0;
for (int64_t slice = 0; slice < ne2; ++slice) {
@@ -784,43 +845,50 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
foff += (size_t)rs * (size_t)n_per_row;
}
// Dequantize into deq
if (typ == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((const ggml_fp16_t *)qbuf.data(), deq.data(), (int)nels);
} else if (typ == GGML_TYPE_BF16) {
ggml_bf16_to_fp32_row((const ggml_bf16_t *)qbuf.data(), deq.data(), (int)nels);
} else {
// Dequantize into deq (row-wise if needed to avoid int overflow)
{
const ggml_type_traits * traits = ggml_get_type_traits(typ);
if (!traits || !traits->to_float) {
LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(typ));
return 1e35;
}
if (typ == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((const ggml_fp16_t *)qbuf.data(), deq.data(), (int)nels);
} else if (typ == GGML_TYPE_BF16) {
ggml_bf16_to_fp32_row((const ggml_bf16_t *)qbuf.data(), deq.data(), (int)nels);
} else {
if (!traits || !traits->to_float) {
LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(typ));
return 1e35;
}
traits->to_float(qbuf.data(), deq.data(), (int) nels);
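// nels can exceed INT_MAX for large tensors, so convert one row per call:
// done counts floats already converted, and done / n_per_row * row_sz is the
// matching byte offset into qbuf (row_sz being the quantized row size in bytes).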
size_t done = 0;
while (done < nels) {
const size_t chunk = std::min((size_t)n_per_row, nels - done);
traits->to_float(qbuf.data() + done / n_per_row * row_sz, deq.data() + done, (int)chunk);
done += chunk;
}
}
}
// Compute error
const double eps = 1e-12;
size_t off = 0;
size_t row_idx = 0;
double total_err = 0.0;
for (int64_t slice = 0; slice < ne2; ++slice) {
const int64_t rs = sample_rows_per_slice[slice];
if (rs == 0) { continue; }
const float * wv = values_sample ? values_sample + slice * n_per_row : nullptr;
const float * act = activations_sample ? activations_sample + slice * n_per_row : nullptr;
const float * wv = has_w ? values_sample + slice * n_per_row : nullptr;
const float * act = has_a ? activations_sample + slice * n_per_row : nullptr;
const double bden = has_a ? bden_per_slice[slice] : 0.0;
double slice_err = 0.0;
for (int64_t r = 0; r < rs; ++r) {
for (int64_t r = 0; r < rs; ++r, ++row_idx) {
const float * x = f32_sample.data() + off;
const float * y = deq.data() + off;
double mse_w = 0.0;
double x2_w = 0.0;
double bnum = 0.0;
double bden = 0.0;
if (wv && act) {
for (int64_t j = 0; j < n_per_row; ++j) {
@@ -828,52 +896,49 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const double e = y[j] - x[j];
const double a = act[j];
mse_w += w * e * e;
x2_w += w * x[j] * x[j];
bnum += w * e * a; // weighted bias
bden += w * a * a; // weighted norm
bnum += w * e * a;
}
} else if (wv) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = wv[j];
const double e = y[j] - x[j];
mse_w += w * e * e;
x2_w += w * x[j] * x[j];
}
} else if (act) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double e = y[j] - x[j];
const double a = act[j];
mse_w += e * e;
x2_w += x[j] * x[j];
bnum += e * a;
bden += a * a;
}
} else {
for (int64_t j = 0; j < n_per_row; ++j) {
const double e = y[j] - x[j];
mse_w += e * e;
x2_w += x[j] * x[j];
}
}
double row_err = mse_w / (x2_w + eps);
// corrected normalization: divide the full numerator by x2
double numer = mse_w;
if (act && bias_lambda != 0.0) {
// penalize squared projection of error onto activations
row_err += bias_lambda * (bnum * bnum) / (bden + eps);
const double proj = bnum * bnum / (bden + eps);
numer += bias_lambda * proj;
}
const double denom = x2_per_row[row_idx] + eps;
const double row_err = numer / denom;
slice_err += row_err;
off += (size_t)n_per_row;
}
// scale to full rows in this slice (nrows)
// scale to full rows (nrows)
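// (e.g. nrows = 4096 with rs = 32 sampled rows gives scale_rows = 128)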
const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
total_err += slice_err * scale_rows;
}
return std::isfinite(total_err) ? total_err : 1e35;
};
std::vector<tensor_info> all;
all.reserve(tensors.size());
@@ -1067,6 +1132,29 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
info.candidate.push_back(candidate_types{ t->type, bpw, ggml_nbytes(t), 0.0 });
}
// Remove dominated candidates: if A has >= bytes and >= error than B (and > in at least one), drop A.
{
std::vector<candidate_types> pruned;
pruned.reserve(info.candidate.size());
// Sort by bytes asc, error asc
std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types &a, const candidate_types &b) {
if (a.bytes != b.bytes) { return a.bytes < b.bytes; }
return a.error < b.error;
});
double best_err = std::numeric_limits<double>::infinity();
for (const auto & c : info.candidate) {
// Keep a candidate only if it strictly improves on the best error seen among
// cheaper candidates; anything else is dominated under the rule above.
if (c.error < best_err) {
pruned.push_back(c);
best_err = c.error;
}
}
info.candidate.swap(pruned);
}
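// Worked example: with candidates (bytes, error) = (100, 0.5), (120, 0.6),
// (150, 0.4), the middle one is dropped (more bytes and higher error than the
// first) while (150, 0.4) survives, since no cheaper candidate beats its error.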
std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types & a, const candidate_types & b) {
if (a.bpw != b.bpw) { return a.bpw < b.bpw; }
if (a.error != b.error) { return a.error < b.error; }