Expected Output Error MSE

This commit is contained in:
Ed Addario 2026-03-01 09:22:15 +00:00
parent 06d3b50b03
commit 6773bd59ad
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 37 additions and 130 deletions

View File

@ -588,7 +588,7 @@ static void signal_handler(int) {
bpw_stop.store(true, std::memory_order_relaxed);
}
// Returns tensor type overrides that meet a global bpw target
// Returns tensor type overrides that meet a global file size or bpw target
static std::unordered_map<std::string, ggml_type> target_bpw_type(
llama_model_loader & ml,
const llama_model & model,
@ -640,7 +640,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
size_t bytes = 0;
double error = 0.0;
double mse = 0.0;
double proj = 0.0;
double wce = 0.0;
};
@ -906,13 +905,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
struct quant_error {
double error = INFINITE;
double mse = 0.0;
double proj = 0.0;
double wce = 0.0;
};
// Pre-calculated stats for MSE
struct mse_cache {
std::vector<double> bias_denominator;
std::vector<double> row_sq_norm;
};
@ -931,8 +928,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const float * activations_sample,
std::vector<uint8_t> & quantized_buffer,
std::vector<float> & dequantized_buffer,
float tensor_bias,
const float * slice_bias,
const wce_cache * ref_wce = nullptr,
const mse_cache * ref_mse = nullptr
) -> quant_error
@ -958,44 +953,34 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const bool do_wce = valid_wce && has_acts && has_vals;
// Sampled stats for MSE
std::vector<double> local_bias_denom;
std::vector<double> local_row_sq_norm;
const std::vector<double> * ptr_bias_denom = nullptr;
const std::vector<double> * ptr_row_sq_norm = nullptr;
// Setup reference stats pointers for MSE
if (!do_wce) {
if (ref_mse) {
ptr_bias_denom = & ref_mse->bias_denominator;
ptr_row_sq_norm = & ref_mse->row_sq_norm;
} else {
local_bias_denom.assign(ne2, 0.0);
if (has_acts) {
for (int64_t s = 0; s < ne2; ++s) {
const float * v = has_vals ? values_sample + s * n_per_row : nullptr;
const float * a = activations_sample + s * n_per_row;
double denom = 0.0;
if (v) {
for (int64_t j = 0; j < n_per_row; ++j) { denom += std::max(0.0f, v[j]) * a[j] * a[j]; }
} else {
for (int64_t j = 0; j < n_per_row; ++j) { denom += a[j] * a[j]; }
}
local_bias_denom[s] = denom;
}
}
ptr_bias_denom = & local_bias_denom;
local_row_sq_norm.reserve(sample_rows);
size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
const float * v = has_vals ? values_sample + s * n_per_row : nullptr;
const float * val = has_vals ? values_sample + s * n_per_row : nullptr;
const float * act = has_acts ? activations_sample + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r) {
const float * x = f32_sample.data() + off;
double sum = 0.0;
if (v) {
for (int64_t j = 0; j < n_per_row; ++j) { sum += std::max(0.0f, v[j]) * x[j] * x[j]; }
double bias_sum = 0.0;
if (val && act) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, val[j]);
const double act_j = act[j];
sum += w * x[j] * x[j];
bias_sum += act_j * x[j];
}
sum += bias_sum * bias_sum;
} else if (val) {
for (int64_t j = 0; j < n_per_row; ++j) { sum += std::max(0.0f, val[j]) * x[j] * x[j]; }
} else {
for (int64_t j = 0; j < n_per_row; ++j) { sum += x[j] * x[j]; }
}
@ -1115,12 +1100,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
return qe;
}
// Weighted Mean Squared Error (MSE) - Default
// Expected Output Error MSE (EOE-MSE)
size_t off = 0;
size_t row_idx = 0;
double total_wmse = 0.0;
double total_proj = 0.0;
double total_bias = 0.0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
@ -1128,12 +1111,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const float * val = has_vals ? values_sample + s * n_per_row : nullptr;
const float * act = has_acts ? activations_sample + s * n_per_row : nullptr;
const double denom_bias = has_acts ? (* ptr_bias_denom)[s] : 0.0;
std::vector<double> slice_mse_norm;
slice_mse_norm.reserve(rs);
std::vector<double> slice_proj_norm;
if (act) { slice_proj_norm.reserve(rs); }
for (int64_t r = 0; r < rs; ++r, ++row_idx) {
const float * x = f32_sample.data() + off;
@ -1144,23 +1124,18 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
if (val && act) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, val[j]);
const double a = act[j];
const double e = (double)y[j] - (double)x[j];
const double we = w * e;
w_err += we * e;
bias_num += we * act[j];
w_err += w * e * e;
bias_num += a * e;
}
w_err += bias_num * bias_num;
} else if (val) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, val[j]);
const double e = (double)y[j] - (double)x[j];
w_err += w * e * e;
}
} else if (act) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double e = (double)y[j] - (double)x[j];
w_err += e * e;
bias_num += e * act[j];
}
} else {
for (int64_t j = 0; j < n_per_row; ++j) {
const double e = (double)y[j] - (double)x[j];
@ -1172,61 +1147,20 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const double m_norm = rsn > EPSILON ? w_err / rsn : 0.0;
slice_mse_norm.push_back(std::isfinite(m_norm) ? m_norm : INFINITE);
if (act) {
double p_norm = 0.0;
if (denom_bias > 0.0) {
const double proj = bias_num * bias_num / (denom_bias + EPSILON);
p_norm = std::isfinite(proj) ? proj : 0.0;
}
slice_proj_norm.push_back(p_norm);
}
off += (size_t)n_per_row;
}
const int64_t nrows = t->ne[1];
const double slice_mean_mse = trimmed_mean(slice_mse_norm) * (double)nrows;
const double slice_mean_proj = act ? trimmed_mean(slice_proj_norm) * (double)nrows : 0.0;
total_wmse += slice_mean_mse;
total_proj += slice_mean_proj;
const double lambda = slice_bias ? (double)std::max(0.0f, slice_bias[s]) : (double)tensor_bias;
total_bias += lambda * slice_mean_proj;
}
qe.mse = total_wmse;
qe.proj = total_proj;
qe.error = total_wmse + total_bias;
qe.error = total_wmse;
return qe;
};
// Lambda per slice or 0.0 if no activations
auto estimate_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) -> std::vector<float> {
if (!activations) { return {}; }
const int64_t ns = std::max<int64_t>(1, ne2);
std::vector<float> lambdas(ns, 0.0f);
for (int64_t s = 0; s < ns; ++s) {
const float * v = values ? values + s * n_per_row : nullptr;
const float * a = activations + s * n_per_row;
double s1 = 0.0;
double s2 = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = v ? std::max(0.0f, v[j]) : 1.0;
const double aw2 = w * a[j] * a[j];
s1 += aw2;
s2 += aw2 * aw2;
}
if (s1 > 0.0) {
const double c = std::max(0.0, s2 / (s1 * s1 + EPSILON) - 1.0 / (double)n_per_row);
lambdas[s] = (float)std::clamp(12.0 * (c / (c + 1.0)), 0.0, 16.0);
}
}
return lambdas;
};
std::unordered_map<std::string, type_choice> bpw_data;
if (params->state_file && !checkpoint_file.empty()) { bpw_data = load_state(); } // ToDo: rethink this condition
@ -1336,7 +1270,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
auto [val_ptr, val_sz] = get_side_data(values_data);
auto [act_ptr, act_sz] = get_side_data(activations_data);
// Cache WCE stats once per tensor to avoid repeated map lookups/regex inside compute_quant_error
// Cache WCE stats per tensor
std::vector<float> val_storage;
std::vector<float> act_storage;
const float * val_vec_ptr = nullptr;
@ -1390,36 +1324,31 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
} else {
// Precompute MSE reference stats
ref_mse.row_sq_norm.reserve(total_rows_sampled);
ref_mse.bias_denominator.assign(ne2, 0.0);
const bool has_acts = act_vec_ptr != nullptr;
const bool has_vals = val_vec_ptr != nullptr;
if (has_acts) {
for (int64_t s = 0; s < ne2; ++s) {
const float * v = has_vals ? val_vec_ptr + s * n_per_row : nullptr;
const float * a = act_vec_ptr + s * n_per_row;
double denom = 0.0;
if (v) {
for (int64_t j = 0; j < n_per_row; ++j) { denom += std::max(0.0f, v[j]) * a[j] * a[j]; }
} else {
for (int64_t j = 0; j < n_per_row; ++j) { denom += a[j] * a[j]; }
}
ref_mse.bias_denominator[s] = denom;
}
}
size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
const float * v = has_vals ? val_vec_ptr + s * n_per_row : nullptr;
const float * val = has_vals ? val_vec_ptr + s * n_per_row : nullptr;
const float * act = has_acts ? act_vec_ptr + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r) {
const float * x = f32_sample.data() + off;
double sum = 0.0;
if (v) {
for (int64_t j = 0; j < n_per_row; ++j) { sum += std::max(0.0f, v[j]) * x[j] * x[j]; }
}
else {
double bias_sum = 0.0;
if (val && act) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, val[j]);
const double act_j = act[j];
sum += w * x[j] * x[j];
bias_sum += act_j * x[j];
}
sum += bias_sum * bias_sum;
} else if (val) {
for (int64_t j = 0; j < n_per_row; ++j) {
sum += std::max(0.0f, val[j]) * x[j] * x[j];
}
} else {
for (int64_t j = 0; j < n_per_row; ++j) { sum += x[j] * x[j]; }
}
@ -1446,14 +1375,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
std::sort(valid_types.begin(), valid_types.end());
valid_types.erase(std::unique(valid_types.begin(), valid_types.end()), valid_types.end());
float tensor_lambda = 0.0f;
std::vector<float> slice_lambdas = estimate_lambda(val_vec_ptr, act_vec_ptr, n_per_row, ne2);
if (!slice_lambdas.empty()) {
double sum = 0;
for(float l : slice_lambdas) { sum += l; }
tensor_lambda = (float)(sum / slice_lambdas.size());
}
// Evaluate candidates
std::vector<type_scores> evaluations;
evaluations.reserve(valid_types.size());
@ -1487,8 +1408,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
act_vec_ptr,
q_buf,
dq_buf,
tensor_lambda,
slice_lambdas.empty() ? nullptr : slice_lambdas.data(),
ptr_ref_wce,
ptr_ref_mse
);
@ -1499,7 +1418,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
candidate.bytes = tensor_bytes(tensor, vt);
candidate.error = qe.error * scaling_factor;
candidate.mse = qe.mse;
candidate.proj = qe.proj;
candidate.wce = qe.wce;
evaluations.push_back(candidate);
}
@ -1508,17 +1426,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
ch.w = tw;
ch.n_elements = ggml_nelements(tensor);
bool bias_needed = false;
if (!valid_wce && !slice_lambdas.empty()) {
double best_mse = INFINITE;
double max_rel_bias = 0.0;
for (const auto& c : evaluations) {
if (c.bytes == 0) { continue; }
best_mse = std::min(best_mse, c.mse);
if (c.mse > EPSILON) { max_rel_bias = std::max(max_rel_bias, std::max(0.0, c.error - c.mse) / c.mse); }
}
bias_needed = max_rel_bias >= 0.5;
}
for (const auto & ev : evaluations) {
if (ev.bytes == 0) { continue; }