Improve error estimation using weighted MSE

Ed Addario 2025-08-20 23:27:20 +01:00
parent b0b33b7ccb
commit 35ad0fc4ad
1 changed file with 36 additions and 24 deletions


@@ -783,14 +783,26 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         f32_offset += rs * n_per_row;
     }
-    traits->to_float(qbuf.data(), deq.data(), f32_sample.size());
+    if (typ == GGML_TYPE_F16) {
+        const auto *const src = (const ggml_fp16_t *)qbuf.data();
+        for (size_t r = 0; r < total_sampled_rows; ++r) {
+            ggml_fp16_to_fp32_row(src + r * n_per_row, deq.data() + r * n_per_row, n_per_row);
+        }
+    } else if (typ == GGML_TYPE_BF16) {
+        const auto *const src = (const ggml_bf16_t *)qbuf.data();
+        for (size_t r = 0; r < total_sampled_rows; ++r) {
+            ggml_bf16_to_fp32_row(src + r * n_per_row, deq.data() + r * n_per_row, n_per_row);
+        }
+    } else {
+        traits->to_float(qbuf.data(), deq.data(), f32_sample.size());
+    }
     double total_err = 0.0;
     size_t sample_offset = 0;
     for (int64_t slice = 0; slice < ne2; ++slice) {
-        const float * value_slice = values_sample.empty() ? nullptr : values_sample.data() + slice * n_per_row;
-        const float * activation_slice = activations_sample.empty() ? nullptr : activations_sample.data() + slice * n_per_row;
+        const float * wv = values_sample.empty() ? nullptr : values_sample.data() + slice * n_per_row;
+        const float * act = activations_sample.empty() ? nullptr : activations_sample.data() + slice * n_per_row;
         const int64_t rs = sample_rows_per_slice[slice];
         double slice_err = 0.0;
@@ -799,37 +811,37 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             const float * ys = deq.data() + sample_offset;
             double mse_w = 0.0;
-            double bias_sum = 0.0;
+            double x2_w = 0.0;
+            double bias_num = 0.0;
+            double bias_den = 0.0;
-            if (value_slice) {
-                for (int64_t j = 0; j < n_per_row; ++j) {
-                    const float e = ys[j] - xs[j];
-                    mse_w += e * e * value_slice[j];
-                    if (activation_slice) { bias_sum += e * activation_slice[j]; }
-                }
-            } else {
-                for (int64_t j = 0; j < n_per_row; ++j) {
-                    const float e = ys[j] - xs[j];
-                    mse_w += e * e;
-                    if (activation_slice) { bias_sum += e * activation_slice[j]; }
+            for (int64_t j = 0; j < n_per_row; ++j) {
+                const double e = ys[j] - xs[j];
+                const double w = wv ? wv[j] : 1.0;
+                mse_w += w * e * e;
+                x2_w += w * xs[j] * xs[j];
+                if (act) {
+                    const double a = act[j];
+                    bias_num += e * a;
+                    bias_den += a * a;
                 }
             }
-            // Normalize by n_per_row to get a per-row average scale
-            double row_err = mse_w / std::max<int64_t>(1, n_per_row);
-            if (activation_slice && bias_lambda != 0.0) {
-                // bias_sum ~= sum_j ( (w_q - w_fp)[j] * E[a_j] )
-                const double bias = std::abs(bias_sum) / std::max<int64_t>(1, n_per_row);
-                row_err += bias_lambda * bias;
+            const double eps = 1e-30;
+            double row_err = mse_w / (x2_w + eps);
+            if (act && bias_lambda != 0.0) {
+                const double bias_norm = bias_num * bias_num / (bias_den + eps);
+                row_err += bias_lambda * bias_norm;
             }
             slice_err += row_err;
             sample_offset += n_per_row;
         }
         // Scale the slice contribution by the sampling factor
-        const double rows_per_expert = (double) nrows;
-        const auto scale_rows = rows_per_expert / std::max(1.0, (double) rs);
+        const auto rows_per_expert = nrows;
+        const double scale_rows = (double)rows_per_expert / std::max(1.0, (double)rs);
         total_err += slice_err * scale_rows;
     }
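
Notes on the change. The first hunk replaces the single `traits->to_float` call with an explicit dispatch: F16 and BF16 samples are converted back to f32 row by row with the dedicated converters, and only the remaining types fall back to the type traits. A minimal standalone sketch of that dispatch, assuming the public ggml API (`ggml_get_type_traits`, `ggml_fp16_to_fp32_row`, `ggml_bf16_to_fp32_row`); the helper name and signature here are illustrative only:

#include <cstddef>
#include <cstdint>
#include "ggml.h"

// Sketch: dequantize total_sampled_rows rows of n_per_row values back to f32.
// F16/BF16 go through the dedicated row converters, which make the row stride
// explicit; every other type converts the whole buffer via its type traits.
static void dequantize_rows(ggml_type typ, const void * qbuf, float * deq,
                            size_t total_sampled_rows, int64_t n_per_row) {
    if (typ == GGML_TYPE_F16) {
        const auto * src = (const ggml_fp16_t *) qbuf;
        for (size_t r = 0; r < total_sampled_rows; ++r) {
            ggml_fp16_to_fp32_row(src + r * n_per_row, deq + r * n_per_row, n_per_row);
        }
    } else if (typ == GGML_TYPE_BF16) {
        const auto * src = (const ggml_bf16_t *) qbuf;
        for (size_t r = 0; r < total_sampled_rows; ++r) {
            ggml_bf16_to_fp32_row(src + r * n_per_row, deq + r * n_per_row, n_per_row);
        }
    } else {
        const auto * traits = ggml_get_type_traits(typ);
        traits->to_float(qbuf, deq, (int64_t) (total_sampled_rows * n_per_row));
    }
}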
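
The second hunk changes the error metric itself. The old code averaged the importance-weighted squared error over `n_per_row` and added the absolute value of an activation-weighted error sum; the new code normalizes the weighted squared error by the weighted energy of the reference row and turns the bias penalty into a squared, energy-normalized projection. Per row:

    row_err = sum_j w_j*e_j^2 / (sum_j w_j*x_j^2 + eps)
            + bias_lambda * (sum_j e_j*a_j)^2 / (sum_j a_j^2 + eps)

with error e_j = y_j - x_j, importance weights w_j (1.0 when absent), and mean activations a_j. A self-contained sketch of the same computation (the free function is illustrative, not the repository's API):

#include <cstdint>

// Sketch of the new per-row error: importance-weighted relative MSE plus an
// optional penalty for error that is systematically aligned with the
// activations. wv and act may be null, matching the diff's nullable slices.
static double row_error(const float * xs, const float * ys,
                        const float * wv, const float * act,
                        int64_t n_per_row, double bias_lambda) {
    double mse_w = 0.0, x2_w = 0.0, bias_num = 0.0, bias_den = 0.0;
    for (int64_t j = 0; j < n_per_row; ++j) {
        const double e = ys[j] - xs[j];
        const double w = wv ? wv[j] : 1.0;
        mse_w += w * e * e;         // weighted squared error
        x2_w  += w * xs[j] * xs[j]; // weighted energy of the reference row
        if (act) {
            const double a = act[j];
            bias_num += e * a;      // projection of the error onto the activations
            bias_den += a * a;
        }
    }
    const double eps = 1e-30;       // guards the divisions on all-zero rows
    double err = mse_w / (x2_w + eps);
    if (act && bias_lambda != 0.0) {
        // squared, energy-normalized bias: large only when the per-column
        // errors consistently point the same way as the activations
        err += bias_lambda * bias_num * bias_num / (bias_den + eps);
    }
    return err;
}

Normalizing by the row's weighted energy instead of by `n_per_row` makes the per-row figure a relative error, so rows and tensors of very different magnitude can be compared on one scale; the squared projection penalizes error that is systematically correlated with the activations, while errors of mixed sign largely cancel in `bias_num`.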
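
Finally, because only `rs` of a slice's `nrows` rows are sampled, the summed slice error is extrapolated by `nrows / rs` before it enters `total_err`; the rewrite keeps `rows_per_expert` as an integer and casts once at the division. For instance, a slice with nrows = 128 sampled at rs = 32 contributes its measured error times 4. As a toy helper (name illustrative):

#include <algorithm>
#include <cstdint>

// Sketch: extrapolate the error measured on rs sampled rows to all nrows
// of the slice, e.g. nrows = 128 sampled at rs = 32 scales the error by 4.
static double scale_slice_error(double slice_err, int64_t nrows, int64_t rs) {
    const double scale_rows = (double) nrows / std::max(1.0, (double) rs);
    return slice_err * scale_rows;
}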