Fix bias lambda bug

Ed Addario 2025-08-20 17:26:37 +01:00
parent 52da4a4f8c
commit 3f0118d602
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 15 additions and 20 deletions


@@ -782,52 +782,47 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             }
             if (rs == 0) { continue; }
-            const size_t got = ggml_quantize_chunk(typ, f32_sample.data(), qbuf.data(), 0, rs, n_per_row, value);
-            (void)got;
+            // Quantize sample rows and dequantize back
+            (void)ggml_quantize_chunk(typ, f32_sample.data(), qbuf.data(), 0, rs, n_per_row, value);
             traits->to_float(qbuf.data(), deq.data(), rs * n_per_row);
-            // Compute error proxy per sampled row
+            // Compute error proxy per sampled slice
+            double slice_err = 0.0;
             for (int64_t s = 0; s < rs; ++s) {
                 const float * xs = f32_sample.data() + s * n_per_row;
                 const float * ys = deq.data() + s * n_per_row;
                 double mse_w = 0.0;
-                double bias = 0.0;
                 double bias_sum = 0.0;
                 if (value) {
                     for (int64_t j = 0; j < n_per_row; ++j) {
                         const float e = ys[j] - xs[j];
                         mse_w += e * e * value[j];
-                        if (activation) {
-                            bias_sum += e * activation[j];
-                        }
+                        if (activation) { bias_sum += e * activation[j]; }
                     }
                 } else {
                     for (int64_t j = 0; j < n_per_row; ++j) {
                         const float e = ys[j] - xs[j];
                         mse_w += e * e;
-                        if (activation) {
-                            bias_sum += e * activation[j];
-                        }
+                        if (activation) { bias_sum += e * activation[j]; }
                     }
                 }
-                if (activation) { bias = std::abs(bias_sum); }
                 // Normalize by n_per_row to get a per-row average scale
                 double row_err = mse_w / std::max<int64_t>(1, n_per_row);
-                if (bias_lambda != 0.0) {
-                    row_err += bias_lambda * (bias / std::max<int64_t>(1, n_per_row));
+                if (activation && bias_lambda != 0.0) {
+                    // bias_sum ~= sum_j ( (w_q - w_fp)[j] * E[a_j] )
+                    const double bias = std::abs(bias_sum) / std::max<int64_t>(1, n_per_row);
+                    row_err += bias_lambda * bias;
                 }
-                total_err += row_err;
+                slice_err += row_err;
             }
-            // Scale for the rows we didn't sample in this expert: multiply by stride-ish factor
+            // Scale the slice contribution by the sampling factor
             const auto scale_rows = (double)rows_per_expert / std::max(1.0, (double)rs);
-            total_err *= scale_rows;
+            total_err += slice_err * scale_rows;
         }
         return std::isfinite(total_err) ? total_err : 1e35;
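
The substantive change in this hunk is the accumulation order: the old code multiplied the running total_err by scale_rows inside the per-expert loop, so each expert's sampling factor also rescaled every previously accumulated expert; the new code accumulates a per-slice slice_err and scales it exactly once before adding. A minimal sketch of the difference, with illustrative values and names that are not part of the patch:

    #include <cstdio>

    int main() {
        // Two "experts", each with a sampled slice error and a sampling scale factor
        const double slice_err[2]  = { 1.0, 2.0 };
        const double scale_rows[2] = { 4.0, 8.0 };

        // Old accumulation: the *= on the second pass also rescales expert 0
        double total_old = 0.0;
        for (int i = 0; i < 2; ++i) {
            total_old += slice_err[i];
            total_old *= scale_rows[i]; // compounds across experts
        }

        // Fixed accumulation: each slice is scaled once, independently
        double total_new = 0.0;
        for (int i = 0; i < 2; ++i) {
            total_new += slice_err[i] * scale_rows[i];
        }

        printf("old: %g, new: %g\n", total_old, total_new); // old: 48, new: 20
        return 0;
    }

The bias term is tightened in the same way: std::abs(bias_sum) is now computed and applied only when activation is non-null, matching the condition under which bias_sum is accumulated.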
@@ -1002,7 +997,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             if (delta_bytes == 0) { continue; }
             double err = (double)cur.error - (double)nxt.error;
-            err = std::max(err, 0.0); // do not penalize due to sampling noise
+            err = std::max(err, 0.0);
             double ratio = err / (double)(delta_bytes * 8ull);
             if (ratio > best.ratio + eps || (std::abs(ratio - best.ratio) <= eps && delta_bytes < best.delta_bytes)) {
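
This hunk sits in the greedy upgrade loop, which ranks candidate type changes by error reduction per extra bit; negative deltas caused by sampling noise are clamped to zero, and near-equal ratios break ties toward the smaller byte increase. A standalone sketch of that scoring rule; the Candidate struct and better_upgrade helper are assumptions for illustration, not names from the patch:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    struct Candidate {
        double   error; // predicted quantization error at this type
        uint64_t bytes; // tensor size at this type
    };

    // Returns true if upgrading cur -> nxt beats the best candidate so far,
    // scoring by error reduction per extra bit spent.
    static bool better_upgrade(const Candidate & cur, const Candidate & nxt,
                               double & best_ratio, uint64_t & best_delta_bytes,
                               double eps = 1e-12) {
        const uint64_t delta_bytes = nxt.bytes - cur.bytes;
        if (delta_bytes == 0) { return false; }
        const double err   = std::max(cur.error - nxt.error, 0.0); // clamp sampling noise
        const double ratio = err / (double)(delta_bytes * 8ull);
        if (ratio > best_ratio + eps ||
            (std::abs(ratio - best_ratio) <= eps && delta_bytes < best_delta_bytes)) {
            best_ratio       = ratio;
            best_delta_bytes = delta_bytes;
            return true;
        }
        return false;
    }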