Add experimental entropy-modulated weighted cosine error (WCE)
This commit is contained in:
parent 3ba6798d45
commit 1c23a6fbd2
@@ -631,6 +631,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const std::map<int, std::string> & mapped,
const std::unordered_map<std::string, std::vector<float>> * values_data,
const std::unordered_map<std::string, std::vector<float>> * activations_data,
const std::unordered_map<std::string, std::vector<float>> * statistics_data,
const llama_model_quantize_params * params,
int nthread
) {
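The hunk above threads a third imatrix-derived map, `statistics_data`, through `target_bpw_type` alongside `values_data` and `activations_data`; all three share the tensor-name to float-vector layout. A minimal sketch of a lookup against such a map (the helper name and fallback behaviour are illustrative, not part of the commit):

```cpp
// Hypothetical helper illustrating the statistics lookup shape; not part of
// the commit. Assumes the same name -> float-vector layout as values_data.
#include <string>
#include <unordered_map>
#include <vector>

static float lookup_stat(const std::unordered_map<std::string, std::vector<float>> * statistics_data,
                         const std::string & key, size_t idx, float fallback) {
    if (!statistics_data) { return fallback; }          // no statistics supplied
    const auto it = statistics_data->find(key);
    if (it == statistics_data->end() || it->second.size() <= idx) { return fallback; }
    return it->second[idx];
}
```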
@@ -651,14 +652,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
}
} signal_guard;

// Error and bias projection per GGML_TYPE per tensor
struct candidate_types {
// GGML_TYPE scores
struct type_scores {
ggml_type type = GGML_TYPE_COUNT;
float bpw = 0.0f;
size_t bytes = 0;
double error = 0.0;
double mse = 0.0;
double proj = 0.0;
double wce = 0.0;
};

// Tensor quantization type choice
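`type_scores` gains a `wce` field so each candidate `ggml_type` records the cosine-based metric next to the existing `mse` and `proj`, whichever of them feeds `error`. A hedged sketch of one way per-type scores like these can be compared, purely for illustration (the commit's actual selection logic is not shown in this hunk):

```cpp
// Illustrative only: pick the smallest-footprint candidate whose error is
// within `tol` of the best error seen. Field names mirror type_scores above.
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

struct score { size_t bytes; double error; };

static int pick_candidate(const std::vector<score> & cands, double tol) {
    double best_err = std::numeric_limits<double>::infinity();
    for (const auto & c : cands) { best_err = std::min(best_err, c.error); }
    int    pick       = -1;
    size_t pick_bytes = std::numeric_limits<size_t>::max();
    for (int i = 0; i < (int) cands.size(); ++i) {
        if (cands[i].error <= best_err + tol && cands[i].bytes < pick_bytes) {
            pick       = i;
            pick_bytes = cands[i].bytes;
        }
    }
    return pick;
}
```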
@@ -694,11 +696,19 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
#endif
};

constexpr double epsilon = 1e-12;
constexpr double infinity = std::numeric_limits<double>::infinity();
constexpr uint32_t file_magic = 0x4d534531; // MSE1
constexpr uint64_t arbitrary_magic = 0xeabada55cafed00d;
constexpr double EPSILON = 1e-12;
constexpr double INFINITE = std::numeric_limits<double>::infinity();
constexpr uint32_t MSE_MAGIC = 0x4d534531; // MSE1
constexpr uint32_t WCE_MAGIC = 0x57434531; // WCE1
constexpr uint64_t HASH_MAGIC = 0xeabada55cafed00d;
const char * func = __func__;
const bool wce = params->use_wce;
const bool valid_wce = wce && activations_data && statistics_data != nullptr;
const uint32_t file_magic = valid_wce ? WCE_MAGIC : MSE_MAGIC;

if (wce && !valid_wce) {
LLAMA_LOG_WARN("%s: WCE optimization requested but no activation or statistics data provided; using default MSE optimization.\n", func);
}

// Tensor size in bytes for a given type
auto tensor_bytes = [](const ggml_tensor * gt, const ggml_type gq) -> size_t {
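The lowercase constants become `EPSILON`/`INFINITE`/`MSE_MAGIC`/`HASH_MAGIC`, a `WCE_MAGIC` tag is added, and `file_magic` is now chosen per run from `valid_wce`, so saved error state is stamped with the metric that produced it and a WCE run cannot silently reuse MSE-derived state. A small sketch of that guard, assuming a simple binary state file with the magic up front:

```cpp
// Sketch only: reject a cached state file whose magic does not match the
// active metric. The file layout beyond the magic is assumed, not shown here.
#include <cstdint>
#include <cstdio>

constexpr uint32_t MSE_MAGIC = 0x4d534531; // "MSE1"
constexpr uint32_t WCE_MAGIC = 0x57434531; // "WCE1"

static bool state_matches_metric(FILE * f, bool use_wce) {
    uint32_t magic = 0;
    if (fread(&magic, sizeof(magic), 1, f) != 1) { return false; }
    return magic == (use_wce ? WCE_MAGIC : MSE_MAGIC);
}
```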
@@ -908,8 +918,28 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
}
};

// Quality metrics
struct quant_error {
double error = INFINITE;
double mse = 0.0;
double proj = 0.0;
double wce = 0.0;
};

// Pre-calculated stats for MSE
struct mse_cache {
std::vector<double> bias_denominator;
std::vector<double> row_sq_norm;
};

// Pre-calculated stats for WCE
struct wce_cache {
std::vector<double> row_sq_norm;
};

// Estimate error for a given type using a sampled subset of rows
auto estimate_error = [&](const ggml_tensor * t,
auto compute_quant_error = [&](
const ggml_tensor * t,
const ggml_type quant_type,
const std::vector<float> & f32_sample,
const std::vector<int64_t> & rows_sample,
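`estimate_error` becomes `compute_quant_error`, returning a `quant_error` struct instead of a bare double plus out-parameters, and the new `mse_cache`/`wce_cache` structs hold per-row weighted squared norms (plus per-slice bias denominators for MSE) that depend only on the f32 sample and the imatrix values, not on the candidate type. A condensed sketch of that shared row statistic, mirroring the weighting used in the hunks below (`max(0, v[j])`, implicit weight 1 without values):

```cpp
// Condensed sketch of the cached statistic: weighted squared norm per
// sampled row. Mirrors the loops in the diff; the buffer is row-major with
// n_per_row elements per row.
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<double> weighted_row_sq_norms(const float * x, const float * v,
                                                 int64_t rows, int64_t n_per_row) {
    std::vector<double> out;
    out.reserve((size_t) rows);
    for (int64_t r = 0; r < rows; ++r) {
        double sum = 0.0;
        for (int64_t j = 0; j < n_per_row; ++j) {
            const double w  = v ? std::max(0.0f, v[j]) : 1.0; // clamp importance at zero
            const double xx = x[r * n_per_row + j];
            sum += w * xx * xx;
        }
        out.push_back(sum);
    }
    return out;
}
```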
@@ -917,89 +947,79 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const float * activations_sample,
std::vector<uint8_t> & quantized_buffer,
std::vector<float> & dequantized_buffer,
float tensor_bias_lambda,
const float * slice_bias_lambda,
double * out_mse = nullptr,
double * out_proj = nullptr) -> double
float tensor_bias,
const float * slice_bias,
const wce_cache * ref_wce = nullptr,
const mse_cache * ref_mse = nullptr
) -> quant_error
{
const int64_t n_per_row = t->ne[0];
const int64_t nrows = t->ne[1];
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
const size_t sample_elems = f32_sample.size();
const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0;
const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0;

quant_error qe;
if (sample_rows == 0) {
if (out_mse) { *out_mse = 0.0; }
if (out_proj) { *out_proj = 0.0; }
return 0.0;
}

size_t expected_rows = 0;
for (int64_t s = 0; s < ne2; ++s) {
expected_rows += (size_t)rows_sample[s];
}

if (expected_rows != sample_rows) {
if (out_mse) { *out_mse = infinity; }
if (out_proj) { *out_proj = 0.0; }
return infinity;
qe.error = 0.0;
return qe;
}

const size_t row_sz = ggml_row_size(quant_type, n_per_row);
const size_t buf_sz = row_sz * sample_rows;

if (quantized_buffer.size() < buf_sz) { quantized_buffer.resize(buf_sz); }
if (quantized_buffer.size() < row_sz * sample_rows) { quantized_buffer.resize(row_sz * sample_rows); }
if (dequantized_buffer.size() < sample_elems) { dequantized_buffer.resize(sample_elems); }

const bool has_values = values_sample != nullptr;
const bool has_activations = activations_sample != nullptr;
const bool has_vals = values_sample != nullptr;
const bool has_acts = activations_sample != nullptr;
const bool do_wce = valid_wce && has_acts && has_vals;

// Bias denominators per slice
std::vector<double> bias_denom(ne2, 0.0);
if (has_activations) {
for (int64_t s = 0; s < ne2; ++s) {
const float * v = has_values ? values_sample + s * n_per_row : nullptr;
const float * a = activations_sample + s * n_per_row;
double denom = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = v ? std::max(0.0f, v[j]) : 1.0;
const double aj = a[j];
denom += w * aj * aj;
}
// Sampled stats for MSE
std::vector<double> local_bias_denom;
std::vector<double> local_row_sq_norm;
const std::vector<double> * ptr_bias_denom = nullptr;
const std::vector<double> * ptr_row_sq_norm = nullptr;

bias_denom[s] = denom;
}
}

// Row squared norms (weighted if values present)
std::vector<double> row_sq_norm(sample_rows, 0.0);
{
size_t off = 0;
size_t ridx = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
if (rs == 0) { continue; }

const float * v = has_values ? values_sample + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r, ++ridx) {
const float * x = f32_sample.data() + off;
double sum = 0.0;
if (v) {
// Setup reference stats pointers for MSE
if (!do_wce) {
if (ref_mse) {
ptr_bias_denom = & ref_mse->bias_denominator;
ptr_row_sq_norm = & ref_mse->row_sq_norm;
} else {
local_bias_denom.assign(ne2, 0.0);
if (has_acts) {
for (int64_t s = 0; s < ne2; ++s) {
const float * v = has_vals ? values_sample + s * n_per_row : nullptr;
const float * a = activations_sample + s * n_per_row;
double denom = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, v[j]);
const double xx = x[j];
sum += w * xx * xx;
}
} else {
for (int64_t j = 0; j < n_per_row; ++j) {
const double xx = x[j];
sum += xx * xx;
const double w = v ? std::max(0.0f, v[j]) : 1.0;
const double aj = a[j];
denom += w * aj * aj;
}

local_bias_denom[s] = denom;
}

row_sq_norm[ridx] = sum;
off += (size_t)n_per_row;
}

ptr_bias_denom = & local_bias_denom;
local_row_sq_norm.reserve(sample_rows);
size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
const float * v = has_vals ? values_sample + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r) {
const float * x = f32_sample.data() + off;
double sum = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
double xx = x[j];
sum += (v ? std::max(0.0f, v[j]) : 1.0) * xx * xx;
}

local_row_sq_norm.push_back(sum);
off += (size_t)n_per_row;
}
}

ptr_row_sq_norm = & local_row_sq_norm;
}
}

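Inside the lambda, the scratch buffers are sized from `ggml_row_size` and grown but never shrunk, so repeated candidate evaluations reuse the same allocations. Restated as a standalone sketch:

```cpp
// Buffer sizing used above: one quantized row occupies
// ggml_row_size(type, n_per_row) bytes, so a sample of `sample_rows` rows
// needs row_sz * sample_rows bytes; buffers only ever grow across candidates.
#include <cstdint>
#include <vector>
#include "ggml.h"

static void ensure_buffers(std::vector<uint8_t> & qbuf, std::vector<float> & dqbuf,
                           ggml_type type, int64_t n_per_row, size_t sample_rows,
                           size_t sample_elems) {
    const size_t row_sz = ggml_row_size(type, n_per_row);
    if (qbuf.size() < row_sz * sample_rows) { qbuf.resize(row_sz * sample_rows); }
    if (dqbuf.size() < sample_elems)        { dqbuf.resize(sample_elems); }
}
```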
@@ -1039,6 +1059,105 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
return std::accumulate(v.begin() + k, v.end() - k, 0.0) / std::max(1.0, (double)(n - 2 * k));
};

// Compute Error Metrics: Entropy-Modulated Weighted Cosine Error (WCE) - Experimental
if (do_wce) {
float h_norm = 1.0f;
if (statistics_data) {
const std::string name = ggml_get_name(t);
const std::string key = remap_imatrix(name, mapped);
if (auto it = statistics_data->find(key); it != statistics_data->end() && !it->second.empty()) {
h_norm = it->second.size() > 3 ? it->second[1] : 1.0f;
}
}

double total_cos_error = 0.0;
size_t off = 0;
size_t sample_idx = 0;

const std::vector<double> * cached_norm_x = ref_wce && !ref_wce->row_sq_norm.empty() ? & ref_wce->row_sq_norm : nullptr;

for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
if (rs == 0) { continue; }

const float * v = values_sample + s * n_per_row;
double slice_sum = 0.0;

for (int64_t r = 0; r < rs; ++r, ++sample_idx) {
const float * wx = f32_sample.data() + off;
const float * wy = dequantized_buffer.data() + off;

double dot = 0.0;
double ny = 0.0;
double nx = 0.0;
const bool calc_nx = !cached_norm_x;

// SIMD-friendly loops
if (v) {
if (calc_nx) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, v[j]);
const double xj = wx[j];
const double yj = wy[j];
const double yw = yj * w;
dot += xj * yw;
ny += yj * yw;
nx += xj * xj * w;
}
} else {
nx = (* cached_norm_x)[sample_idx];
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = std::max(0.0f, v[j]);
const double yj = wy[j];
const double yw = yj * w;
dot += (double) wx[j] * yw;
ny += yj * yw;
}
}
} else {
if (calc_nx) {
for (int64_t j = 0; j < n_per_row; ++j) {
const double xj = wx[j];
const double yj = wy[j];
dot += xj * yj;
ny += yj * yj;
nx += xj * xj;
}
} else {
nx = (* cached_norm_x)[sample_idx];
for (int64_t j = 0; j < n_per_row; ++j) {
const double xj = wx[j];
const double yj = wy[j];
dot += xj * yj;
ny += yj * yj;
}
}
}

// Cosine Distance
double cos_sim;
const double norm_prod = nx * ny;

if (norm_prod <= EPSILON) { cos_sim = nx <= EPSILON && ny <= EPSILON ? 1.0 : 0.0; }
else { cos_sim = dot / std::sqrt(norm_prod); }

if (cos_sim > 1.0) { cos_sim = 1.0; }
else if (cos_sim < -1.0) { cos_sim = -1.0; }

slice_sum += 1.0 - cos_sim;
off += (size_t) n_per_row;
}

const double nrows = t->ne[1];
total_cos_error += slice_sum / (double) rs * (double) nrows;
}

const double penalty = 2.0 - std::clamp((double) h_norm, 0.0, 1.0);
qe.wce = total_cos_error * penalty;
qe.error = qe.wce;
return qe;
}

// Compute Error Metrics: Weighted MSE Optimization - Default
size_t off = 0;
size_t row_idx = 0;
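The WCE branch scores each sampled row by weighted cosine distance, 1 − (Σ w·x·y) / sqrt((Σ w·x²)(Σ w·y²)) with w = max(0, v), averages per slice, scales by the tensor's full row count, and finally multiplies the total by an entropy penalty 2 − clamp(h_norm, 0, 1): a tensor with maximal normalized entropy (h_norm = 1) keeps its raw cosine error, while h_norm = 0 doubles it. h_norm is read from the statistics vector at index 1 when more than three entries are present, defaulting to 1.0 (no penalty). The four loop variants above only exist to hoist the `v`/`cached_norm_x` branches out of the inner loop. A standalone restatement of the per-row metric:

```cpp
// Restatement of the WCE row metric above: weighted cosine distance between
// the original row x and the dequantized row y. Degenerate zero-norm cases
// pin cos_sim to 1 (both zero) or 0 (only one zero), as in the diff. The
// entropy penalty scales the per-tensor total, not this per-row value.
#include <algorithm>
#include <cmath>
#include <cstdint>

static double wce_row_distance(const float * x, const float * y,
                               const float * v, int64_t n) {
    constexpr double EPSILON = 1e-12;
    double dot = 0.0, nx = 0.0, ny = 0.0;
    for (int64_t j = 0; j < n; ++j) {
        const double w  = v ? std::max(0.0f, v[j]) : 1.0;
        const double xj = x[j];
        const double yj = y[j];
        dot += w * xj * yj;
        nx  += w * xj * xj;
        ny  += w * yj * yj;
    }
    const double norm_prod = nx * ny;
    double cos_sim;
    if (norm_prod <= EPSILON) {
        cos_sim = (nx <= EPSILON && ny <= EPSILON) ? 1.0 : 0.0;
    } else {
        cos_sim = std::clamp(dot / std::sqrt(norm_prod), -1.0, 1.0);
    }
    return 1.0 - cos_sim;
}
```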
@@ -1258,6 +1377,71 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
prepare_broadcast(val_ptr, val_sz, val_vec);
prepare_broadcast(act_ptr, act_sz, act_vec);

// Precompute WCE reference stats (row_sq_norm) to avoid recalculation per candidate
wce_cache ref_wce;
mse_cache ref_mse;
size_t total_rows_sampled = 0;
for (int64_t r : rows_sample) { total_rows_sampled += r; }

if (valid_wce && !val_vec.empty() && !act_vec.empty()) {
ref_wce.row_sq_norm.reserve(total_rows_sampled);

size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
if (rs == 0) { continue; }
const float * v = val_vec.data() + s * n_per_row;

for (int64_t r = 0; r < rs; ++r) {
const float * wx = f32_sample.data() + off;
double norm_x = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = v ? std::max(0.0f, v[j]) : 1.0;
norm_x += (double)wx[j] * wx[j] * w;
}
ref_wce.row_sq_norm.push_back(norm_x);
off += n_per_row;
}
}
} else {
// Precompute MSE reference stats (row_sq_norm and bias_denominator) to avoid recalculation per candidate
ref_mse.row_sq_norm.reserve(total_rows_sampled);
ref_mse.bias_denominator.assign(ne2, 0.0);
const bool has_acts = !act_vec.empty();
const bool has_vals = !val_vec.empty();

// Bias Denominators
if (has_acts) {
for (int64_t s = 0; s < ne2; ++s) {
const float * v = has_vals ? val_vec.data() + s * n_per_row : nullptr;
const float * a = act_vec.data() + s * n_per_row;
double denom = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = v ? std::max(0.0f, v[j]) : 1.0;
const double aj = a[j];
denom += w * aj * aj;
}
ref_mse.bias_denominator[s] = denom;
}
}

// Row Squared Norms
size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
const float * v = has_vals ? val_vec.data() + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r) {
const float * x = f32_sample.data() + off;
double sum = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
double xx = x[j];
sum += (v ? std::max(0.0f, v[j]) : 1.0) * xx * xx;
}
ref_mse.row_sq_norm.push_back(sum);
off += (size_t)n_per_row;
}
}
}

// Build candidates
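This block runs once per tensor, ahead of the candidate loop: the WCE path fills `ref_wce.row_sq_norm`, the MSE path fills `ref_mse.row_sq_norm` and `ref_mse.bias_denominator`, and `compute_quant_error` then takes these by pointer instead of recomputing them for every candidate type. A sketch of the per-slice bias denominator (per slice s: Σ_j max(0, v[j]) · a[j]², weight 1 when no values are present):

```cpp
// Sketch of the per-slice bias denominator cached in mse_cache above.
// vals/acts are the tensor's importance values and activations, laid out
// as ne2 slices of n_per_row floats each.
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<double> bias_denominators(const float * vals, const float * acts,
                                             int64_t ne2, int64_t n_per_row) {
    std::vector<double> denom((size_t) ne2, 0.0);
    for (int64_t s = 0; s < ne2; ++s) {
        const float * v = vals ? vals + s * n_per_row : nullptr;
        const float * a = acts + s * n_per_row;
        double d = 0.0;
        for (int64_t j = 0; j < n_per_row; ++j) {
            const double w  = v ? std::max(0.0f, v[j]) : 1.0;
            const double aj = a[j];
            d += w * aj * aj;
        }
        denom[(size_t) s] = d;
    }
    return denom;
}
```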
@@ -1328,6 +1512,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
ch.w = tw;
ch.n_elements = ggml_nelements(tensor);
bool bias_needed = false;
if (!valid_wce && !slice_lambdas.empty()) {
// Determine if bias correction is required
double best_mse = INFINITE;
double max_rel_bias = 0.0;
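The new `!valid_wce` guard keeps the bias-correction search on the MSE path only; under WCE it is skipped and `bias_needed` stays false. Restated as a predicate, for clarity:

```cpp
// Restates the gate above; illustrative only. Bias correction is searched
// only on the MSE path, and only when per-slice lambdas are available.
static bool bias_search_enabled(bool valid_wce, bool have_slice_lambdas) {
    return !valid_wce && have_slice_lambdas;
}
```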
@@ -1731,6 +1916,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
const std::unordered_map<std::string, std::vector<float>> * values_data = nullptr;
const std::unordered_map<std::string, std::vector<float>> * activations_data = nullptr;
const std::unordered_map<std::string, std::vector<float>> * statistics_data = nullptr;
if (params->imatrix) {
values_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
if (values_data) {
@@ -1761,6 +1947,12 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
}
}
if (params->statistics) {
statistics_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->statistics);
if (statistics_data) {
LLAMA_LOG_INFO(" and %d statistics", int(statistics_data->size()));
}
}
LLAMA_LOG_INFO("\n");

gguf_context_ptr ctx_out { gguf_init_empty() };
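`params->statistics` follows the same pattern as `params->imatrix`: a type-erased pointer cast back to the concrete map type at the consumer, which is only sound because producer and consumer agree on the exact type. A minimal sketch of the round trip (the stand-in struct and values are illustrative; the diff reads index 1 of a tensor's statistics vector as h_norm when it has more than three entries):

```cpp
// Illustrative round trip for the type-erased statistics pointer;
// quantize_params_like stands in for the real llama_model_quantize_params.
#include <string>
#include <unordered_map>
#include <vector>

using stats_map = std::unordered_map<std::string, std::vector<float>>;

struct quantize_params_like { const void * statistics = nullptr; };

int main() {
    stats_map stats = { { "blk.0.attn_q.weight", { 0.0f, 0.87f, 0.0f, 0.0f } } };
    quantize_params_like p;
    p.statistics = &stats; // producer fills the opaque field
    const auto * statistics_data = static_cast<const stats_map *>(p.statistics);
    return statistics_data->count("blk.0.attn_q.weight") ? 0 : 1;
}
```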
@@ -1899,11 +2091,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
} else {
LLAMA_LOG_INFO("%s: imatrix does not have activations, process may be less accurate\n", __func__);
}
if (params->statistics) {
LLAMA_LOG_INFO("%s: imatrix has statistics\n", __func__);
}
if (params->ignore_tensor_importance) {
LLAMA_LOG_INFO("%s: distributing budget equitably across all tensors\n", __func__);
} else {
LLAMA_LOG_INFO("%s: assigning more budget to important tensors\n", __func__);
}
if (params->use_wce) {
LLAMA_LOG_INFO("%s: using experimental Entropy-Modulated Weighted Cosine Error (WCE) approximation optimization\n", __func__);
} else {
LLAMA_LOG_INFO("%s: using weighted Mean Squared Error (MSE) optimization\n", __func__);
}
if (params->target_size >= 0) {
LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve file size %.2f MiB\n",
__func__, (double)params->target_size / 1024.0 / 1024.0);
@@ -1911,7 +2111,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw);
}

bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, params, nthread);
// get quantization type overrides targeting a given bits per weight budget
bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, statistics_data, params, nthread);
} else {
LLAMA_LOG_WARN("%s: --target-bpw/--target-size require an imatrix but none was provided, ignoring\n", __func__);
}
@@ -2170,6 +2371,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
/*.keep_split =*/ false,
/*.imatrix =*/ nullptr,
/*.activations =*/ nullptr,
/*.statistics =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.tensor_type =*/ nullptr,
/*.prune_layers =*/ nullptr,
@@ -2177,7 +2379,8 @@ llama_model_quantize_params llama_model_quantize_default_params() {
/*.target_size =*/ -1,
/*.save_state =*/ false,
/*.state_file =*/ nullptr,
/*.ignore_tensor_importance =*/ false
/*.ignore_tensor_importance =*/ false,
/*.use_wce =*/ false
};

return result;
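Finally, the new `use_wce` default lands in `llama_model_quantize_default_params()`. A hedged usage sketch of opting in (field values are placeholders; the three maps must come from imatrix tooling with the layout above):

```cpp
// Usage sketch only; assumes llama.h is included and the maps were produced
// by imatrix tooling. If activations or statistics are missing, valid_wce is
// false and the run falls back to weighted MSE with a warning.
llama_model_quantize_params qp = llama_model_quantize_default_params();
qp.target_bpw = 4.25f;                 // bpw budget driving target_bpw_type
qp.use_wce    = true;                  // opt in to the experimental WCE objective
// qp.imatrix     = &values_map;       // name -> values
// qp.activations = &activations_map;  // name -> activations (required for WCE)
// qp.statistics  = &statistics_map;   // name -> statistics (required for WCE)
```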