Improve error estimation using weighted MSE
parent b0b33b7ccb
commit 35ad0fc4ad
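The change, read from the diff below: the per-row error estimate moves from a plain averaged MSE with an absolute bias term to a weighted MSE normalized by the weighted signal energy, plus a squared, activation-projected bias penalty. With x the reference row (`xs`), y the dequantized row (`ys`), w the importance weights (`wv`, defaulting to 1), a the mean activations (`act`), λ = `bias_lambda` and ε a small constant (1e-30 in the code), the new per-row error is approximately

    row_err = Σ_j w_j (y_j − x_j)² / (Σ_j w_j x_j² + ε)
            + λ · (Σ_j (y_j − x_j) a_j)² / (Σ_j a_j² + ε)

which replaces the previous Σ_j v_j (y_j − x_j)² / n_per_row + λ · |Σ_j (y_j − x_j) a_j| / n_per_row.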
@@ -783,14 +783,26 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             f32_offset += rs * n_per_row;
         }
 
-        traits->to_float(qbuf.data(), deq.data(), f32_sample.size());
+        if (typ == GGML_TYPE_F16) {
+            const auto *const src = (const ggml_fp16_t *)qbuf.data();
+            for (size_t r = 0; r < total_sampled_rows; ++r) {
+                ggml_fp16_to_fp32_row(src + r * n_per_row, deq.data() + r * n_per_row, n_per_row);
+            }
+        } else if (typ == GGML_TYPE_BF16) {
+            const auto *const src = (const ggml_bf16_t *)qbuf.data();
+            for (size_t r = 0; r < total_sampled_rows; ++r) {
+                ggml_bf16_to_fp32_row(src + r * n_per_row, deq.data() + r * n_per_row, n_per_row);
+            }
+        } else {
+            traits->to_float(qbuf.data(), deq.data(), f32_sample.size());
+        }
 
         double total_err = 0.0;
         size_t sample_offset = 0;
 
         for (int64_t slice = 0; slice < ne2; ++slice) {
-            const float * value_slice = values_sample.empty() ? nullptr : values_sample.data() + slice * n_per_row;
-            const float * activation_slice = activations_sample.empty() ? nullptr : activations_sample.data() + slice * n_per_row;
+            const float * wv = values_sample.empty() ? nullptr : values_sample.data() + slice * n_per_row;
+            const float * act = activations_sample.empty() ? nullptr : activations_sample.data() + slice * n_per_row;
             const int64_t rs = sample_rows_per_slice[slice];
 
             double slice_err = 0.0;
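The hunk above also changes how the sampled buffer is dequantized: F16 and BF16 data are now expanded row by row with ggml's dedicated converters, while all other types keep going through the generic to_float trait. A minimal sketch of that dispatch as a standalone helper; the name dequantize_rows_to_f32 is illustrative only, and traits is assumed to come from ggml_get_type_traits(typ):

    #include <cstdint>
    #include "ggml.h"

    // Sketch only: mirrors the dequantization dispatch in the hunk above.
    static void dequantize_rows_to_f32(ggml_type typ, const ggml_type_traits * traits,
                                       const void * qbuf, float * dst,
                                       size_t total_sampled_rows, int64_t n_per_row) {
        if (typ == GGML_TYPE_F16) {
            // F16: convert each sampled row with the dedicated helper
            const auto * src = (const ggml_fp16_t *) qbuf;
            for (size_t r = 0; r < total_sampled_rows; ++r) {
                ggml_fp16_to_fp32_row(src + r * n_per_row, dst + r * n_per_row, n_per_row);
            }
        } else if (typ == GGML_TYPE_BF16) {
            // BF16: same, with the bf16 converter
            const auto * src = (const ggml_bf16_t *) qbuf;
            for (size_t r = 0; r < total_sampled_rows; ++r) {
                ggml_bf16_to_fp32_row(src + r * n_per_row, dst + r * n_per_row, n_per_row);
            }
        } else {
            // quantized types: generic trait converts the whole buffer at once
            traits->to_float(qbuf, dst, (int64_t) total_sampled_rows * n_per_row);
        }
    }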
@@ -799,37 +811,37 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 const float * ys = deq.data() + sample_offset;
 
                 double mse_w = 0.0;
-                double bias_sum = 0.0;
+                double x2_w = 0.0;
+                double bias_num = 0.0;
+                double bias_den = 0.0;
 
-                if (value_slice) {
-                    for (int64_t j = 0; j < n_per_row; ++j) {
-                        const float e = ys[j] - xs[j];
-                        mse_w += e * e * value_slice[j];
-                        if (activation_slice) { bias_sum += e * activation_slice[j]; }
-                    }
-                } else {
-                    for (int64_t j = 0; j < n_per_row; ++j) {
-                        const float e = ys[j] - xs[j];
-                        mse_w += e * e;
-                        if (activation_slice) { bias_sum += e * activation_slice[j]; }
+                for (int64_t j = 0; j < n_per_row; ++j) {
+                    const double e = ys[j] - xs[j];
+                    const double w = wv ? wv[j] : 1.0;
+                    mse_w += w * e * e;
+                    x2_w += w * xs[j] * xs[j];
+
+                    if (act) {
+                        const double a = act[j];
+                        bias_num += e * a;
+                        bias_den += a * a;
                     }
                 }
 
-                // Normalize by n_per_row to get a per-row average scale
-                double row_err = mse_w / std::max<int64_t>(1, n_per_row);
-                if (activation_slice && bias_lambda != 0.0) {
-                    // bias_sum ~= sum_j ( (w_q - w_fp)[j] * E[a_j] )
-                    const double bias = std::abs(bias_sum) / std::max<int64_t>(1, n_per_row);
-                    row_err += bias_lambda * bias;
+                const double eps = 1e-30;
+                double row_err = mse_w / (x2_w + eps);
+
+                if (act && bias_lambda != 0.0) {
+                    const double bias_norm = bias_num * bias_num / (bias_den + eps);
+                    row_err += bias_lambda * bias_norm;
                 }
 
                 slice_err += row_err;
                 sample_offset += n_per_row;
             }
 
             // Scale the slice contribution by the sampling factor
-            const double rows_per_expert = (double) nrows;
-            const auto scale_rows = rows_per_expert / std::max(1.0, (double) rs);
+            const auto rows_per_expert = nrows;
+            const double scale_rows = (double)rows_per_expert / std::max(1.0, (double)rs);
             total_err += slice_err * scale_rows;
         }
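For reference, a minimal standalone sketch of the new per-row error term, factored out of the inner loop above. The free-function form and the name weighted_row_error are illustrative only, not part of the patch; x, y, values and activations correspond to xs, ys, wv and act, and the last two may be null:

    #include <cstdint>

    // Sketch only: weighted relative MSE plus optional activation-bias penalty,
    // as computed per sampled row by the new code path.
    static double weighted_row_error(const float * x, const float * y,
                                     const float * values, const float * activations,
                                     int64_t n_per_row, double bias_lambda) {
        double mse_w    = 0.0;  // sum_j w_j * (y_j - x_j)^2
        double x2_w     = 0.0;  // sum_j w_j * x_j^2
        double bias_num = 0.0;  // sum_j (y_j - x_j) * a_j
        double bias_den = 0.0;  // sum_j a_j^2

        for (int64_t j = 0; j < n_per_row; ++j) {
            const double e = y[j] - x[j];
            const double w = values ? values[j] : 1.0;
            mse_w += w * e * e;
            x2_w  += w * x[j] * x[j];
            if (activations) {
                const double a = activations[j];
                bias_num += e * a;
                bias_den += a * a;
            }
        }

        const double eps = 1e-30;               // avoids division by zero on empty/zero rows
        double row_err = mse_w / (x2_w + eps);  // error relative to the weighted signal energy

        if (activations && bias_lambda != 0.0) {
            // squared systematic error projected onto the mean activations
            row_err += bias_lambda * bias_num * bias_num / (bias_den + eps);
        }
        return row_err;
    }

Normalizing by x2_w makes the error relative to each row's weighted signal energy rather than to its length, which is presumably why the per-element division by n_per_row was dropped.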