Change error estimate to use normalised weighted MSE
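Row error is now the weighted MSE normalised by the weighted signal energy,
with an optional activation-bias penalty:

    row_err = sum_j w_j*(y_j - x_j)^2 / (sum_j w_j*x_j^2 + eps)
            + bias_lambda * (sum_j e_j*a_j)^2 / (sum_j a_j^2 + eps)

where x is the original row, y is the dequantised row, e_j = y_j - x_j,
w are the imatrix weights and a are the mean activations. In addition, the
quantise/dequantise scratch buffers are allocated once per tensor and reused
across candidate types, imatrix side data is size-checked (and broadcast)
before use, row sampling starts at a random offset within each stride, and
Q8_0 is dropped from the k-quant candidate list.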

Ed Addario 2025-08-21 10:46:37 +01:00
parent 5ef493ea1a
commit 95b2ab2800
1 changed file with 134 additions and 70 deletions


@@ -9,6 +9,7 @@
 #include <cinttypes>
 #include <fstream>
 #include <mutex>
+#include <random>
 #include <regex>
 #include <thread>
 #include <unordered_map>
@@ -661,8 +662,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         GGML_TYPE_Q5_0,
         GGML_TYPE_Q5_1,
         GGML_TYPE_Q5_K,
-        GGML_TYPE_Q6_K,
-        GGML_TYPE_Q8_0
+        GGML_TYPE_Q6_K
     };
     auto name_tn = LLM_TN(model.arch);
@@ -752,103 +752,125 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const ggml_type typ,
         const std::vector<float> & f32_sample,
         const std::vector<int64_t> & sample_rows_per_slice,
-        const std::vector<float> & values_sample,
-        const std::vector<float> & activations_sample) -> double
+        const float * values_sample,
+        const float * activations_sample,
+        std::vector<uint8_t> & qbuf,
+        std::vector<float> & deq) -> double
     {
         const int64_t n_per_row = t->ne[0];
-        const int64_t nrows = t->ne[1];
-        const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
-        const ggml_type_traits * traits = ggml_get_type_traits(typ);
-        if (!traits || !traits->to_float) {
-            // Cannot dequantize candidate -> assign very high error
-            return 1e35f;
-        }
+        const int64_t nrows = t->ne[1];
+        const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
         const size_t total_sampled_rows = f32_sample.size() / n_per_row;
         if (total_sampled_rows == 0) { return 0.0; }
-        const size_t qbuf_size = ggml_row_size(typ, n_per_row) * total_sampled_rows;
-        std::vector<uint8_t> qbuf(qbuf_size);
-        std::vector<float> deq(f32_sample.size());
-        // Quantize all sampled rows at once and dequantize back
-        size_t qbuf_offset = 0;
-        size_t f32_offset = 0;
+        const size_t row_sz = ggml_row_size(typ, n_per_row);
+        const size_t need_q = row_sz * total_sampled_rows;
+        if (qbuf.size() < need_q) { qbuf.resize(need_q); }
+        if (deq.size() < f32_sample.size()) { deq.resize(f32_sample.size()); }
+        // Quantize sampled rows slice-by-slice
+        size_t qoff = 0;
+        size_t foff = 0;
         for (int64_t slice = 0; slice < ne2; ++slice) {
             const int64_t rs = sample_rows_per_slice[slice];
             if (rs == 0) { continue; }
-            const float * value = values_sample.empty() ? nullptr : values_sample.data() + slice * n_per_row;
-            (void)ggml_quantize_chunk(typ, f32_sample.data() + f32_offset, qbuf.data() + qbuf_offset, 0, rs, n_per_row, value);
-            qbuf_offset += ggml_row_size(typ, n_per_row) * rs;
-            f32_offset += rs * n_per_row;
+            const float * value = values_sample ? values_sample + slice * n_per_row : nullptr;
+            (void)ggml_quantize_chunk(typ, f32_sample.data() + foff, qbuf.data() + qoff, 0, rs, n_per_row, value);
+            qoff += row_sz * rs;
+            foff += (size_t)rs * n_per_row;
         }
+        // Dequantize to deq
         if (typ == GGML_TYPE_F16) {
-            const auto * const src = (const ggml_fp16_t *)qbuf.data();
-            for (size_t r = 0; r < total_sampled_rows; ++r) {
-                ggml_fp16_to_fp32_row(src + r * n_per_row, deq.data() + r * n_per_row, n_per_row);
-            }
+            ggml_fp16_to_fp32_row((const ggml_fp16_t *)qbuf.data(), deq.data(), (int)f32_sample.size());
         } else if (typ == GGML_TYPE_BF16) {
-            const auto * const src = (const ggml_bf16_t *)qbuf.data();
-            for (size_t r = 0; r < total_sampled_rows; ++r) {
-                ggml_bf16_to_fp32_row(src + r * n_per_row, deq.data() + r * n_per_row, n_per_row);
-            }
+            ggml_bf16_to_fp32_row((const ggml_bf16_t *)qbuf.data(), deq.data(), (int)f32_sample.size());
         } else {
-            traits->to_float(qbuf.data(), deq.data(), f32_sample.size());
+            const ggml_type_traits * traits = ggml_get_type_traits(typ);
+            if (!traits || !traits->to_float) {
+                // no dequantizer available
+                return 1e35;
+            }
+            traits->to_float(qbuf.data(), deq.data(), (int) f32_sample.size());
         }
         // Compute error
+        size_t off = 0;
         double total_err = 0.0;
-        size_t sample_offset = 0;
+        const double eps = 1e-12;
         for (int64_t slice = 0; slice < ne2; ++slice) {
-            const float * wv = values_sample.empty() ? nullptr : values_sample.data() + slice * n_per_row;
-            const float * act = activations_sample.empty() ? nullptr : activations_sample.data() + slice * n_per_row;
             const int64_t rs = sample_rows_per_slice[slice];
             if (rs == 0) { continue; }
+            const float * wv = values_sample ? values_sample + slice * n_per_row : nullptr;
+            const float * act = activations_sample ? activations_sample + slice * n_per_row : nullptr;
             double slice_err = 0.0;
-            for (int64_t s = 0; s < rs; ++s) {
-                const float * xs = f32_sample.data() + sample_offset;
-                const float * ys = deq.data() + sample_offset;
+            for (int64_t r = 0; r < rs; ++r) {
+                const float * x = f32_sample.data() + off;
+                const float * y = deq.data() + off;
                 double mse_w = 0.0;
                 double x2_w = 0.0;
-                double bias_num = 0.0;
-                double bias_den = 0.0;
+                double bnum = 0.0;
+                double bden = 0.0;
-                for (int64_t j = 0; j < n_per_row; ++j) {
-                    const double e = ys[j] - xs[j];
-                    const double w = wv ? wv[j] : 1.0;
-                    mse_w += w * e * e;
-                    x2_w += w * xs[j] * xs[j];
-                    if (act) {
+                if (wv && act) {
+                    for (int64_t j = 0; j < n_per_row; ++j) {
+                        const double w = wv[j];
+                        const double e = y[j] - x[j];
                         const double a = act[j];
-                        bias_num += e * a;
-                        bias_den += a * a;
+                        mse_w += w * e * e;
+                        x2_w += w * x[j] * x[j];
+                        bnum += e * a;
+                        bden += a * a;
                     }
+                } else if (wv) {
+                    for (int64_t j = 0; j < n_per_row; ++j) {
+                        const double w = wv[j];
+                        const double e = y[j] - x[j];
+                        mse_w += w * e * e;
+                        x2_w += w * x[j] * x[j];
+                    }
+                } else if (act) {
+                    for (int64_t j = 0; j < n_per_row; ++j) {
+                        const double e = y[j] - x[j];
+                        const double a = act[j];
+                        mse_w += e * e;
+                        x2_w += x[j] * x[j];
+                        bnum += e * a;
+                        bden += a * a;
+                    }
+                } else {
+                    for (int64_t j = 0; j < n_per_row; ++j) {
+                        const double e = y[j] - x[j];
+                        mse_w += e * e;
+                        x2_w += x[j] * x[j];
+                    }
                 }
-                const double eps = 1e-30;
                 double row_err = mse_w / (x2_w + eps);
                 if (act && bias_lambda != 0.0) {
-                    const double bias_norm = bias_num * bias_num / (bias_den + eps);
-                    row_err += bias_lambda * bias_norm;
+                    row_err += bias_lambda * (bnum * bnum) / (bden + eps);
                 }
                 slice_err += row_err;
-                sample_offset += n_per_row;
+                off += (size_t)n_per_row;
             }
-            const auto rows_per_expert = nrows;
-            const double scale_rows = (double)rows_per_expert / std::max(1.0, (double)rs);
+            // Scale back up to the full number of rows in this slice
+            const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
             total_err += slice_err * scale_rows;
         }
         return std::isfinite(total_err) ? total_err : 1e35;
     };
 
     std::vector<tensor_info> all;
     all.reserve(tensors.size());
@@ -887,38 +909,70 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const int64_t n_per_row = t->ne[0];
         const int64_t nrows_total = t->ne[1];
         const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
-        const int64_t rows_per_expert = nrows_total;
-        const int64_t sample_rows_max = std::max<int64_t>(1, std::min<int64_t>(rows_per_expert, sample_rows_per_expert));
-        const int64_t stride = std::max<int64_t>(1, rows_per_expert / sample_rows_max);
+        const int64_t sample_rows_max = std::max<int64_t>(1, std::min<int64_t>(nrows_total, sample_rows_per_expert));
+        const int64_t stride = std::max<int64_t>(1, nrows_total / sample_rows_max);
         std::vector<float> f32_sample;
         std::vector<float> values_sample;
         std::vector<float> activations_sample;
         std::vector<int64_t> sample_rows_per_slice(ne2);
+        std::mt19937 rng(std::random_device{}());
         for (int64_t slice = 0; slice < ne2; ++slice) {
             int64_t current_sampled_rows = 0;
-            for (int64_t r = 0; r < rows_per_expert && current_sampled_rows < sample_rows_max; r += stride) {
-                const float * src_row = f32_data + slice * (n_per_row * rows_per_expert) + r * n_per_row;
+            int64_t offset = 0;
+            if (stride > 1) {
+                std::uniform_int_distribution<int64_t> dist(0, stride - 1);
+                offset = dist(rng);
+            }
+            for (int64_t r = offset; r < nrows_total && current_sampled_rows < sample_rows_max; r += stride) {
+                const float * src_row = f32_data + slice * (n_per_row * nrows_total) + r * n_per_row;
                 f32_sample.insert(f32_sample.end(), src_row, src_row + n_per_row);
                 current_sampled_rows++;
             }
             sample_rows_per_slice[slice] = current_sampled_rows;
         }
+        auto copy_or_broadcast = [&](const float * src, size_t src_sz, std::vector<float> & dst) {
+            const size_t want = (size_t)ne2 * (size_t)n_per_row;
+            dst.clear();
+            if (!src || src_sz == 0) { return; }
+            if (src_sz == want) {
+                dst.resize(want);
+                std::memcpy(dst.data(), src, want * sizeof(float));
+            } else if (src_sz == (size_t)n_per_row) {
+                dst.resize(want);
+                for (int64_t s = 0; s < ne2; ++s) {
+                    std::memcpy(dst.data() + s * n_per_row, src, n_per_row * sizeof(float));
+                }
+            } else {
+                // Size mismatch: safer to skip using it for this tensor
+                LLAMA_LOG_WARN("%s: side data size mismatch for %s: got %zu, expected %zu or %zu; ignoring\n",
+                    __func__, name.c_str(), src_sz, (size_t)n_per_row, want);
+            }
+        };
         if (values_all) {
-            values_sample.resize(ne2 * n_per_row);
-            std::memcpy(values_sample.data(), values_all, ne2 * n_per_row * sizeof(float));
+            // Get the size from the map (not just the raw pointer)
+            auto itv = values_data->find(remap_imatrix(name, mapped));
+            const size_t sz = itv == values_data->end() ? 0 : itv->second.size();
+            copy_or_broadcast(values_all, sz, values_sample);
         }
         if (activations_all) {
-            activations_sample.resize(ne2 * n_per_row);
-            std::memcpy(activations_sample.data(), activations_all, ne2 * n_per_row * sizeof(float));
+            auto ita = activations_data->find(remap_imatrix(name, mapped));
+            const size_t sz = ita == activations_data->end() ? 0 : ita->second.size();
+            copy_or_broadcast(activations_all, sz, activations_sample);
         }
         tensor_info info;
         info.w = tw;
         info.n_elements = nelem;
+        // Prepare scratch buffers sized for the largest candidate row size
+        size_t total_sampled_rows = f32_sample.size() / n_per_row;
+        // Build list of candidate types first (compatible ones)
         std::vector<ggml_type> quant_candidates;
         if (is_iq(params->ftype)) {
             quant_candidates.assign(std::begin(iq_candidates), std::end(iq_candidates));
@@ -926,18 +980,28 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             quant_candidates.assign(std::begin(k_candidates), std::end(k_candidates));
         }
-        // Build per-tensor candidate list
+        // Compute maximum row size among compatible candidates (to size qbuf once)
+        size_t max_row_sz = 0;
         std::vector<ggml_type> compatible_candidates;
         compatible_candidates.reserve(quant_candidates.size());
         for (ggml_type ts_type : quant_candidates) {
             if (is_iq(ts_type) && !values_all) { continue; }
             ggml_type tt = make_compatible(t, ts_type);
             if (!is_compatible(t, tt)) { continue; }
             compatible_candidates.push_back(tt);
+            max_row_sz = std::max(max_row_sz, ggml_row_size(tt, n_per_row));
         }
-        // Compute bpw and bytes
+        std::vector<uint8_t> qbuf(max_row_sz * total_sampled_rows);
+        std::vector<float> deq(f32_sample.size());
+        // Now evaluate candidates
         for (ggml_type tt : compatible_candidates) {
             auto bpw = (float)tensor_bpw(t, tt);
             size_t bytes = total_bytes(t, tt);
             // Estimate error using the pre-sampled data
-            auto err = (float)estimate_error(t, tt, f32_sample, sample_rows_per_slice, values_sample, activations_sample);
+            const float * vals_ptr = values_sample.empty() ? nullptr : values_sample.data();
+            const float * acts_ptr = activations_sample.empty() ? nullptr : activations_sample.data();
+            float err = (float)estimate_error(t, tt, f32_sample, sample_rows_per_slice, vals_ptr, acts_ptr, qbuf, deq);
             info.candidate.push_back(candidate_types{ tt, bpw, bytes, err });
         }
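
For reference, a minimal standalone sketch of the per-row metric introduced by this commit (a hypothetical row_error helper, not part of the patch; it mirrors the inner loop above without the ggml plumbing):

    #include <cmath>
    #include <cstddef>

    // Normalised weighted MSE with an optional activation-bias penalty.
    // x = original row, y = dequantised row, w = per-column importance
    // weights (may be null), a = mean activations (may be null),
    // lambda = bias penalty strength.
    static double row_error(const float * x, const float * y, const float * w,
                            const float * a, size_t n, double lambda) {
        const double eps = 1e-12;
        double mse_w = 0.0, x2_w = 0.0, bnum = 0.0, bden = 0.0;
        for (size_t j = 0; j < n; ++j) {
            const double e  = (double)y[j] - (double)x[j];
            const double wj = w ? w[j] : 1.0;
            mse_w += wj * e * e;                       // weighted squared error
            x2_w  += wj * (double)x[j] * (double)x[j]; // weighted signal energy
            if (a) {
                bnum += e * a[j];    // correlation of error with activations
                bden += a[j] * a[j];
            }
        }
        double err = mse_w / (x2_w + eps); // scale-invariant relative error
        if (a && lambda != 0.0) {
            err += lambda * (bnum * bnum) / (bden + eps);
        }
        return std::isfinite(err) ? err : 1e35;
    }

Normalising by the weighted signal energy makes rows with very different magnitudes comparable, so summing (and rescaling) row errors across slices ranks candidate types by relative rather than absolute distortion.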