Add experimental entropy-modulated weighted cosine error (WCE)

This commit is contained in:
Ed Addario 2026-01-21 18:28:37 +00:00
parent 3ba6798d45
commit 1c23a6fbd2
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 278 additions and 75 deletions

View File

@ -631,6 +631,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const std::map<int, std::string> & mapped,
const std::unordered_map<std::string, std::vector<float>> * values_data,
const std::unordered_map<std::string, std::vector<float>> * activations_data,
const std::unordered_map<std::string, std::vector<float>> * statistics_data,
const llama_model_quantize_params * params,
int nthread
) {
@ -651,14 +652,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
}
} signal_guard;
// Error and bias projection per GGML_TYPE per tensor
struct candidate_types {
// GGML_TYPE scores
struct type_scores {
ggml_type type = GGML_TYPE_COUNT;
float bpw = 0.0f;
size_t bytes = 0;
double error = 0.0;
double mse = 0.0;
double proj = 0.0;
double wce = 0.0;
};
// Tensor quantization type choice
@ -694,11 +696,19 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
#endif
};
constexpr double epsilon = 1e-12;
constexpr double infinity = std::numeric_limits<double>::infinity();
constexpr uint32_t file_magic = 0x4d534531; // MSE1
constexpr uint64_t arbitrary_magic = 0xeabada55cafed00d;
constexpr double EPSILON = 1e-12;
constexpr double INFINITE = std::numeric_limits<double>::infinity();
constexpr uint32_t MSE_MAGIC = 0x4d534531; // MSE1
constexpr uint32_t WCE_MAGIC = 0x57434531; // WCE1
constexpr uint64_t HASH_MAGIC = 0xeabada55cafed00d;
const char * func = __func__;
const bool wce = params->use_wce;
const bool valid_wce = wce && activations_data && statistics_data != nullptr;
const uint32_t file_magic = valid_wce ? WCE_MAGIC : MSE_MAGIC;
if (wce && !valid_wce) {
LLAMA_LOG_WARN("%s: WCE optimization requested but no activation or statistics data provided; using default MSE optimization.\n", func);
}
// Tensor size in bytes for a given type
auto tensor_bytes = [](const ggml_tensor * gt, const ggml_type gq) -> size_t {
@ -908,8 +918,28 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
}
};
// Quality metrics
struct quant_error {
double error = INFINITE;
double mse = 0.0;
double proj = 0.0;
double wce = 0.0;
};
// Pre-calculated stats for MSE
struct mse_cache {
std::vector<double> bias_denominator;
std::vector<double> row_sq_norm;
};
// Pre-calculated stats for WCE
struct wce_cache {
std::vector<double> row_sq_norm;
};
// Estimate error for a given type using a sampled subset of rows
auto estimate_error = [&](const ggml_tensor * t,
auto compute_quant_error = [&](
const ggml_tensor * t,
const ggml_type quant_type,
const std::vector<float> & f32_sample,
const std::vector<int64_t> & rows_sample,
@ -917,89 +947,79 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const float * activations_sample,
std::vector<uint8_t> & quantized_buffer,
std::vector<float> & dequantized_buffer,
float tensor_bias_lambda,
const float * slice_bias_lambda,
double * out_mse = nullptr,
double * out_proj = nullptr) -> double
float tensor_bias,
const float * slice_bias,
const wce_cache * ref_wce = nullptr,
const mse_cache * ref_mse = nullptr
) -> quant_error
{
const int64_t n_per_row = t->ne[0];
const int64_t nrows = t->ne[1];
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
const size_t sample_elems = f32_sample.size();
const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0;
const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0;
quant_error qe;
if (sample_rows == 0) {
if (out_mse) { *out_mse = 0.0; }
if (out_proj) { *out_proj = 0.0; }
return 0.0;
}
size_t expected_rows = 0;
for (int64_t s = 0; s < ne2; ++s) {
expected_rows += (size_t)rows_sample[s];
}
if (expected_rows != sample_rows) {
if (out_mse) { *out_mse = infinity; }
if (out_proj) { *out_proj = 0.0; }
return infinity;
qe.error = 0.0;
return qe;
}
const size_t row_sz = ggml_row_size(quant_type, n_per_row);
const size_t buf_sz = row_sz * sample_rows;
if (quantized_buffer.size() < buf_sz) { quantized_buffer.resize(buf_sz); }
if (quantized_buffer.size() < row_sz * sample_rows) { quantized_buffer.resize(row_sz * sample_rows); }
if (dequantized_buffer.size() < sample_elems) { dequantized_buffer.resize(sample_elems); }
const bool has_values = values_sample != nullptr;
const bool has_activations = activations_sample != nullptr;
const bool has_vals = values_sample != nullptr;
const bool has_acts = activations_sample != nullptr;
const bool do_wce = valid_wce && has_acts && has_vals;
// Bias denominators per slice
std::vector<double> bias_denom(ne2, 0.0);
if (has_activations) {
for (int64_t s = 0; s < ne2; ++s) {
const float * v = has_values ? values_sample + s * n_per_row : nullptr;
const float * a = activations_sample + s * n_per_row;
double denom = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = v ? std::max(0.0f, v[j]) : 1.0;
const double aj = a[j];
denom += w * aj * aj;
}
// Sampled stats for MSE
std::vector<double> local_bias_denom;
std::vector<double> local_row_sq_norm;
const std::vector<double> * ptr_bias_denom = nullptr;
const std::vector<double> * ptr_row_sq_norm = nullptr;
bias_denom[s] = denom;
}
}
// Row squared norms (weighted if values present)
std::vector<double> row_sq_norm(sample_rows, 0.0);
{
size_t off = 0;
size_t ridx = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
if (rs == 0) { continue; }
const float * v = has_values ? values_sample + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r, ++ridx) {
const float * x = f32_sample.data() + off;
double sum = 0.0;
if (v) {
// Setup reference stats pointers for MSE
if (!do_wce) {
if (ref_mse) {
ptr_bias_denom = & ref_mse->bias_denominator;
ptr_row_sq_norm = & ref_mse->row_sq_norm;
} else {
local_bias_denom.assign(ne2, 0.0);
if (has_acts) {
for (int64_t s = 0; s < ne2; ++s) {
const float * v = has_vals ? values_sample + s * n_per_row : nullptr;
const float * a = activations_sample + s * n_per_row;
double denom = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, v[j]);
const double xx = x[j];
sum += w * xx * xx;
}
} else {
for (int64_t j = 0; j < n_per_row; ++j) {
const double xx = x[j];
sum += xx * xx;
const double w = v ? std::max(0.0f, v[j]) : 1.0;
const double aj = a[j];
denom += w * aj * aj;
}
local_bias_denom[s] = denom;
}
row_sq_norm[ridx] = sum;
off += (size_t)n_per_row;
}
ptr_bias_denom = & local_bias_denom;
local_row_sq_norm.reserve(sample_rows);
size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
const float * v = has_vals ? values_sample + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r) {
const float * x = f32_sample.data() + off;
double sum = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
double xx = x[j];
sum += (v ? std::max(0.0f, v[j]) : 1.0) * xx * xx;
}
local_row_sq_norm.push_back(sum);
off += (size_t)n_per_row;
}
}
ptr_row_sq_norm = & local_row_sq_norm;
}
}
@ -1039,6 +1059,105 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
return std::accumulate(v.begin() + k, v.end() - k, 0.0) / std::max(1.0, (double)(n - 2 * k));
};
// Compute Error Metrics: Entropy-Modulated Weighted Cosine Error (WCE) - Experimental
if (do_wce) {
float h_norm = 1.0f;
if (statistics_data) {
const std::string name = ggml_get_name(t);
const std::string key = remap_imatrix(name, mapped);
if (auto it = statistics_data->find(key); it != statistics_data->end() && !it->second.empty()) {
h_norm = it->second.size() > 3 ? it->second[1] : 1.0f;
}
}
double total_cos_error = 0.0;
size_t off = 0;
size_t sample_idx = 0;
const std::vector<double> * cached_norm_x = ref_wce && !ref_wce->row_sq_norm.empty() ? & ref_wce->row_sq_norm : nullptr;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
if (rs == 0) { continue; }
const float * v = values_sample + s * n_per_row;
double slice_sum = 0.0;
for (int64_t r = 0; r < rs; ++r, ++sample_idx) {
const float * wx = f32_sample.data() + off;
const float * wy = dequantized_buffer.data() + off;
double dot = 0.0;
double ny = 0.0;
double nx = 0.0;
const bool calc_nx = !cached_norm_x;
// SIMD-friendly loops
if (v) {
if (calc_nx) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, v[j]);
const double xj = wx[j];
const double yj = wy[j];
const double yw = yj * w;
dot += xj * yw;
ny += yj * yw;
nx += xj * xj * w;
}
} else {
nx = (* cached_norm_x)[sample_idx];
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, v[j]);
const double yj = wy[j];
const double yw = yj * w;
dot += (double) wx[j] * yw;
ny += yj * yw;
}
}
} else {
if (calc_nx) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double xj = wx[j];
const double yj = wy[j];
dot += xj * yj;
ny += yj * yj;
nx += xj * xj;
}
} else {
nx = (* cached_norm_x)[sample_idx];
for (int64_t j = 0; j < n_per_row; ++j) {
const double xj = wx[j];
const double yj = wy[j];
dot += xj * yj;
ny += yj * yj;
}
}
}
// Cosine Distance
double cos_sim;
const double norm_prod = nx * ny;
if (norm_prod <= EPSILON) { cos_sim = nx <= EPSILON && ny <= EPSILON ? 1.0 : 0.0; }
else { cos_sim = dot / std::sqrt(norm_prod); }
if (cos_sim > 1.0) { cos_sim = 1.0; }
else if (cos_sim < -1.0) { cos_sim = -1.0; }
slice_sum += 1.0 - cos_sim;
off += (size_t) n_per_row;
}
const double nrows = t->ne[1];
total_cos_error += slice_sum / (double) rs * (double) nrows;
}
const double penalty = 2.0 - std::clamp((double) h_norm, 0.0, 1.0);
qe.wce = total_cos_error * penalty;
qe.error = qe.wce;
return qe;
}
// Compute Error Metrics: Weighted MSE Optimization - Default
size_t off = 0;
size_t row_idx = 0;
@ -1258,6 +1377,71 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
prepare_broadcast(val_ptr, val_sz, val_vec);
prepare_broadcast(act_ptr, act_sz, act_vec);
// Precompute WCE reference stats (row_sq_norm) to avoid recalculation per candidate
wce_cache ref_wce;
mse_cache ref_mse;
size_t total_rows_sampled = 0;
for (int64_t r : rows_sample) { total_rows_sampled += r; }
if (valid_wce && !val_vec.empty() && !act_vec.empty()) {
ref_wce.row_sq_norm.reserve(total_rows_sampled);
size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
if (rs == 0) { continue; }
const float * v = val_vec.data() + s * n_per_row;
for (int64_t r = 0; r < rs; ++r) {
const float * wx = f32_sample.data() + off;
double norm_x = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = v ? std::max(0.0f, v[j]) : 1.0;
norm_x += (double)wx[j] * wx[j] * w;
}
ref_wce.row_sq_norm.push_back(norm_x);
off += n_per_row;
}
}
} else {
// Precompute MSE reference stats (row_sq_norm and bias_denominator) to avoid recalculation per candidate
ref_mse.row_sq_norm.reserve(total_rows_sampled);
ref_mse.bias_denominator.assign(ne2, 0.0);
const bool has_acts = !act_vec.empty();
const bool has_vals = !val_vec.empty();
// Bias Denominators
if (has_acts) {
for (int64_t s = 0; s < ne2; ++s) {
const float * v = has_vals ? val_vec.data() + s * n_per_row : nullptr;
const float * a = act_vec.data() + s * n_per_row;
double denom = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = v ? std::max(0.0f, v[j]) : 1.0;
const double aj = a[j];
denom += w * aj * aj;
}
ref_mse.bias_denominator[s] = denom;
}
}
// Row Squared Norms
size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
const float * v = has_vals ? val_vec.data() + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r) {
const float * x = f32_sample.data() + off;
double sum = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
double xx = x[j];
sum += (v ? std::max(0.0f, v[j]) : 1.0) * xx * xx;
}
ref_mse.row_sq_norm.push_back(sum);
off += (size_t)n_per_row;
}
}
}
// Build candidates
@ -1328,6 +1512,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
ch.w = tw;
ch.n_elements = ggml_nelements(tensor);
bool bias_needed = false;
if (!valid_wce && !slice_lambdas.empty()) {
// Determine if bias correction is required
double best_mse = INFINITE;
double max_rel_bias = 0.0;
@ -1731,6 +1916,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
const std::unordered_map<std::string, std::vector<float>> * values_data = nullptr;
const std::unordered_map<std::string, std::vector<float>> * activations_data = nullptr;
const std::unordered_map<std::string, std::vector<float>> * statistics_data = nullptr;
if (params->imatrix) {
values_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
if (values_data) {
@ -1761,6 +1947,12 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
}
}
if (params->statistics) {
statistics_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->statistics);
if (statistics_data) {
LLAMA_LOG_INFO(" and %d statistics", int(statistics_data->size()));
}
}
LLAMA_LOG_INFO("\n");
gguf_context_ptr ctx_out { gguf_init_empty() };
@ -1899,11 +2091,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
} else {
LLAMA_LOG_INFO("%s: imatrix does not have activations, process may be less accurate\n", __func__);
}
if (params->statistics) {
LLAMA_LOG_INFO("%s: imatrix has statistics\n", __func__);
}
if (params->ignore_tensor_importance) {
LLAMA_LOG_INFO("%s: distributing budget equitably across all tensors\n", __func__);
} else {
LLAMA_LOG_INFO("%s: assigning more budget to important tensors\n", __func__);
}
if (params->use_wce) {
LLAMA_LOG_INFO("%s: using experimental Entropy-Modulated Weighted Cosine Error (WCE) approximation optimization\n", __func__);
} else {
LLAMA_LOG_INFO("%s: using weighted Mean Squared Error (MSE) optimization\n", __func__);
}
if (params->target_size >= 0) {
LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve file size %.2f MiB\n",
__func__, (double)params->target_size / 1024.0 / 1024.0);
@ -1911,7 +2111,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw);
}
bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, params, nthread);
// get quantization type overrides targeting a given bits per weight budget
bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, statistics_data, params, nthread);
} else {
LLAMA_LOG_WARN("%s: --target-bpw/--target-size require an imatrix but none was provided, ignoring\n", __func__);
}
@ -2170,6 +2371,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
/*.keep_split =*/ false,
/*.imatrix =*/ nullptr,
/*.activations =*/ nullptr,
/*.statistics =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.tensor_type =*/ nullptr,
/*.prune_layers =*/ nullptr,
@ -2177,7 +2379,8 @@ llama_model_quantize_params llama_model_quantize_default_params() {
/*.target_size =*/ -1,
/*.save_state =*/ false,
/*.state_file =*/ nullptr,
/*.ignore_tensor_importance =*/ false
/*.ignore_tensor_importance =*/ false,
/*.use_wce =*/ false
};
return result;