From 3f0118d6029450955c43cd84109bdfc36a8cecd3 Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Wed, 20 Aug 2025 17:26:37 +0100 Subject: [PATCH] Fix bias lambda bug --- src/llama-quant.cpp | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 786adfe547..44cf9e30e3 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -782,52 +782,47 @@ static std::unordered_map target_bpw_type( } if (rs == 0) { continue; } - const size_t got = ggml_quantize_chunk(typ, f32_sample.data(), qbuf.data(), 0, rs, n_per_row, value); - (void)got; - + // Quantize sample rows and dequantize back + (void)ggml_quantize_chunk(typ, f32_sample.data(), qbuf.data(), 0, rs, n_per_row, value); traits->to_float(qbuf.data(), deq.data(), rs * n_per_row); - // Compute error proxy per sampled row + // Compute error proxy per sampled slice + double slice_err = 0.0; for (int64_t s = 0; s < rs; ++s) { const float * xs = f32_sample.data() + s * n_per_row; const float * ys = deq.data() + s * n_per_row; double mse_w = 0.0; - double bias = 0.0; double bias_sum = 0.0; if (value) { for (int64_t j = 0; j < n_per_row; ++j) { const float e = ys[j] - xs[j]; mse_w += e * e * value[j]; - if (activation) { - bias_sum += e * activation[j]; - } + if (activation) { bias_sum += e * activation[j]; } } } else { for (int64_t j = 0; j < n_per_row; ++j) { const float e = ys[j] - xs[j]; mse_w += e * e; - if (activation) { - bias_sum += e * activation[j]; - } + if (activation) { bias_sum += e * activation[j]; } } } - if (activation) { bias = std::abs(bias_sum); } - // Normalize by n_per_row to get a per-row average scale double row_err = mse_w / std::max(1, n_per_row); - if (bias_lambda != 0.0) { - row_err += bias_lambda * (bias / std::max(1, n_per_row)); + if (activation && bias_lambda != 0.0) { + // bias_sum ~= sum_j ( (w_q - w_fp)[j] * E[a_j] ) + const double bias = std::abs(bias_sum) / std::max(1, n_per_row); + row_err += bias_lambda * bias; } - total_err += row_err; + slice_err += row_err; } - // Scale for the rows we didn't sample in this expert: multiply by stride-ish factor - const auto scale_rows = (double)rows_per_expert / std::max(1.0, (double)rs); - total_err *= scale_rows; + // Scale the slice contribution by the sampling factor + const auto scale_rows = (double)rows_per_expert / std::max(1.0, (double)rs); + total_err += slice_err * scale_rows; } return std::isfinite(total_err) ? total_err : 1e35; @@ -1002,7 +997,7 @@ static std::unordered_map target_bpw_type( if (delta_bytes == 0) { continue; } double err = (double)cur.error - (double)nxt.error; - err = std::max(err, 0.0); // do not penalize due to sampling noise + err = std::max(err, 0.0); double ratio = err / (double)(delta_bytes * 8ull); if (ratio > best.ratio + eps || (std::abs(ratio - best.ratio) <= eps && delta_bytes < best.delta_bytes)) {