diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 790003b5c9..b3f10856d6 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -631,6 +631,7 @@ static std::unordered_map target_bpw_type( const std::map & mapped, const std::unordered_map> * values_data, const std::unordered_map> * activations_data, + const std::unordered_map> * statistics_data, const llama_model_quantize_params * params, int nthread ) { @@ -651,14 +652,15 @@ static std::unordered_map target_bpw_type( } } signal_guard; - // Error and bias projection per GGML_TYPE per tensor - struct candidate_types { + // GGML_TYPE scores + struct type_scores { ggml_type type = GGML_TYPE_COUNT; float bpw = 0.0f; size_t bytes = 0; double error = 0.0; double mse = 0.0; double proj = 0.0; + double wce = 0.0; }; // Tensor quantization type choice @@ -694,11 +696,19 @@ static std::unordered_map target_bpw_type( #endif }; - constexpr double epsilon = 1e-12; - constexpr double infinity = std::numeric_limits::infinity(); - constexpr uint32_t file_magic = 0x4d534531; // MSE1 - constexpr uint64_t arbitrary_magic = 0xeabada55cafed00d; + constexpr double EPSILON = 1e-12; + constexpr double INFINITE = std::numeric_limits::infinity(); + constexpr uint32_t MSE_MAGIC = 0x4d534531; // MSE1 + constexpr uint32_t WCE_MAGIC = 0x57434531; // WCE1 + constexpr uint64_t HASH_MAGIC = 0xeabada55cafed00d; const char * func = __func__; + const bool wce = params->use_wce; + const bool valid_wce = wce && activations_data && statistics_data != nullptr; + const uint32_t file_magic = valid_wce ? WCE_MAGIC : MSE_MAGIC; + + if (wce && !valid_wce) { + LLAMA_LOG_WARN("%s: WCE optimization requested but no activation or statistics data provided; using default MSE optimization.\n", func); + } // Tensor size in bytes for a given type auto tensor_bytes = [](const ggml_tensor * gt, const ggml_type gq) -> size_t { @@ -908,8 +918,28 @@ static std::unordered_map target_bpw_type( } }; + // Quality metrics + struct quant_error { + double error = INFINITE; + double mse = 0.0; + double proj = 0.0; + double wce = 0.0; + }; + + // Pre-calculated stats for MSE + struct mse_cache { + std::vector bias_denominator; + std::vector row_sq_norm; + }; + + // Pre-calculated stats for WCE + struct wce_cache { + std::vector row_sq_norm; + }; + // Estimate error for a given type using a sampled subset of rows - auto estimate_error = [&](const ggml_tensor * t, + auto compute_quant_error = [&]( + const ggml_tensor * t, const ggml_type quant_type, const std::vector & f32_sample, const std::vector & rows_sample, @@ -917,89 +947,79 @@ static std::unordered_map target_bpw_type( const float * activations_sample, std::vector & quantized_buffer, std::vector & dequantized_buffer, - float tensor_bias_lambda, - const float * slice_bias_lambda, - double * out_mse = nullptr, - double * out_proj = nullptr) -> double + float tensor_bias, + const float * slice_bias, + const wce_cache * ref_wce = nullptr, + const mse_cache * ref_mse = nullptr + ) -> quant_error { const int64_t n_per_row = t->ne[0]; - const int64_t nrows = t->ne[1]; const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1; const size_t sample_elems = f32_sample.size(); - const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0; + const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0; + quant_error qe; if (sample_rows == 0) { - if (out_mse) { *out_mse = 0.0; } - if (out_proj) { *out_proj = 0.0; } - return 0.0; - } - - size_t expected_rows = 0; - for (int64_t s = 0; s < ne2; ++s) { - expected_rows += (size_t)rows_sample[s]; - } - - if (expected_rows != sample_rows) { - if (out_mse) { *out_mse = infinity; } - if (out_proj) { *out_proj = 0.0; } - return infinity; + qe.error = 0.0; + return qe; } const size_t row_sz = ggml_row_size(quant_type, n_per_row); - const size_t buf_sz = row_sz * sample_rows; - - if (quantized_buffer.size() < buf_sz) { quantized_buffer.resize(buf_sz); } + if (quantized_buffer.size() < row_sz * sample_rows) { quantized_buffer.resize(row_sz * sample_rows); } if (dequantized_buffer.size() < sample_elems) { dequantized_buffer.resize(sample_elems); } - const bool has_values = values_sample != nullptr; - const bool has_activations = activations_sample != nullptr; + const bool has_vals = values_sample != nullptr; + const bool has_acts = activations_sample != nullptr; + const bool do_wce = valid_wce && has_acts && has_vals; - // Bias denominators per slice - std::vector bias_denom(ne2, 0.0); - if (has_activations) { - for (int64_t s = 0; s < ne2; ++s) { - const float * v = has_values ? values_sample + s * n_per_row : nullptr; - const float * a = activations_sample + s * n_per_row; - double denom = 0.0; - for (int64_t j = 0; j < n_per_row; ++j) { - const double w = v ? std::max(0.0f, v[j]) : 1.0; - const double aj = a[j]; - denom += w * aj * aj; - } + // Sampled stats for MSE + std::vector local_bias_denom; + std::vector local_row_sq_norm; + const std::vector * ptr_bias_denom = nullptr; + const std::vector * ptr_row_sq_norm = nullptr; - bias_denom[s] = denom; - } - } - - // Row squared norms (weighted if values present) - std::vector row_sq_norm(sample_rows, 0.0); - { - size_t off = 0; - size_t ridx = 0; - for (int64_t s = 0; s < ne2; ++s) { - const int64_t rs = rows_sample[s]; - if (rs == 0) { continue; } - - const float * v = has_values ? values_sample + s * n_per_row : nullptr; - for (int64_t r = 0; r < rs; ++r, ++ridx) { - const float * x = f32_sample.data() + off; - double sum = 0.0; - if (v) { + // Setup reference stats pointers for MSE + if (!do_wce) { + if (ref_mse) { + ptr_bias_denom = & ref_mse->bias_denominator; + ptr_row_sq_norm = & ref_mse->row_sq_norm; + } else { + local_bias_denom.assign(ne2, 0.0); + if (has_acts) { + for (int64_t s = 0; s < ne2; ++s) { + const float * v = has_vals ? values_sample + s * n_per_row : nullptr; + const float * a = activations_sample + s * n_per_row; + double denom = 0.0; for (int64_t j = 0; j < n_per_row; ++j) { - const double w = std::max(0.0f, v[j]); - const double xx = x[j]; - sum += w * xx * xx; - } - } else { - for (int64_t j = 0; j < n_per_row; ++j) { - const double xx = x[j]; - sum += xx * xx; + const double w = v ? std::max(0.0f, v[j]) : 1.0; + const double aj = a[j]; + denom += w * aj * aj; } + + local_bias_denom[s] = denom; } - - row_sq_norm[ridx] = sum; - off += (size_t)n_per_row; } + + ptr_bias_denom = & local_bias_denom; + local_row_sq_norm.reserve(sample_rows); + size_t off = 0; + for (int64_t s = 0; s < ne2; ++s) { + const int64_t rs = rows_sample[s]; + const float * v = has_vals ? values_sample + s * n_per_row : nullptr; + for (int64_t r = 0; r < rs; ++r) { + const float * x = f32_sample.data() + off; + double sum = 0.0; + for (int64_t j = 0; j < n_per_row; ++j) { + double xx = x[j]; + sum += (v ? std::max(0.0f, v[j]) : 1.0) * xx * xx; + } + + local_row_sq_norm.push_back(sum); + off += (size_t)n_per_row; + } + } + + ptr_row_sq_norm = & local_row_sq_norm; } } @@ -1039,6 +1059,105 @@ static std::unordered_map target_bpw_type( return std::accumulate(v.begin() + k, v.end() - k, 0.0) / std::max(1.0, (double)(n - 2 * k)); }; + // Compute Error Metrics: Entropy-Modulated Weighted Cosine Error (WCE) - Experimental + if (do_wce) { + float h_norm = 1.0f; + if (statistics_data) { + const std::string name = ggml_get_name(t); + const std::string key = remap_imatrix(name, mapped); + if (auto it = statistics_data->find(key); it != statistics_data->end() && !it->second.empty()) { + h_norm = it->second.size() > 3 ? it->second[1] : 1.0f; + } + } + + double total_cos_error = 0.0; + size_t off = 0; + size_t sample_idx = 0; + + const std::vector * cached_norm_x = ref_wce && !ref_wce->row_sq_norm.empty() ? & ref_wce->row_sq_norm : nullptr; + + for (int64_t s = 0; s < ne2; ++s) { + const int64_t rs = rows_sample[s]; + if (rs == 0) { continue; } + + const float * v = values_sample + s * n_per_row; + double slice_sum = 0.0; + + for (int64_t r = 0; r < rs; ++r, ++sample_idx) { + const float * wx = f32_sample.data() + off; + const float * wy = dequantized_buffer.data() + off; + + double dot = 0.0; + double ny = 0.0; + double nx = 0.0; + const bool calc_nx = !cached_norm_x; + + // SIMD-friendly loops + if (v) { + if (calc_nx) { + for (int64_t j = 0; j < n_per_row; ++j) { + const double w = std::max(0.0f, v[j]); + const double xj = wx[j]; + const double yj = wy[j]; + const double yw = yj * w; + dot += xj * yw; + ny += yj * yw; + nx += xj * xj * w; + } + } else { + nx = (* cached_norm_x)[sample_idx]; + for (int64_t j = 0; j < n_per_row; ++j) { + const double w = std::max(0.0f, v[j]); + const double yj = wy[j]; + const double yw = yj * w; + dot += (double) wx[j] * yw; + ny += yj * yw; + } + } + } else { + if (calc_nx) { + for (int64_t j = 0; j < n_per_row; ++j) { + const double xj = wx[j]; + const double yj = wy[j]; + dot += xj * yj; + ny += yj * yj; + nx += xj * xj; + } + } else { + nx = (* cached_norm_x)[sample_idx]; + for (int64_t j = 0; j < n_per_row; ++j) { + const double xj = wx[j]; + const double yj = wy[j]; + dot += xj * yj; + ny += yj * yj; + } + } + } + + // Cosine Distance + double cos_sim; + const double norm_prod = nx * ny; + + if (norm_prod <= EPSILON) { cos_sim = nx <= EPSILON && ny <= EPSILON ? 1.0 : 0.0; } + else { cos_sim = dot / std::sqrt(norm_prod); } + + if (cos_sim > 1.0) { cos_sim = 1.0; } + else if (cos_sim < -1.0) { cos_sim = -1.0; } + + slice_sum += 1.0 - cos_sim; + off += (size_t) n_per_row; + } + + const double nrows = t->ne[1]; + total_cos_error += slice_sum / (double) rs * (double) nrows; + } + + const double penalty = 2.0 - std::clamp((double) h_norm, 0.0, 1.0); + qe.wce = total_cos_error * penalty; + qe.error = qe.wce; + return qe; + } + // Compute Error Metrics: Weighted MSE Optimization - Default size_t off = 0; size_t row_idx = 0; @@ -1258,6 +1377,71 @@ static std::unordered_map target_bpw_type( prepare_broadcast(val_ptr, val_sz, val_vec); prepare_broadcast(act_ptr, act_sz, act_vec); + + // Precompute WCE reference stats (row_sq_norm) to avoid recalculation per candidate + wce_cache ref_wce; + mse_cache ref_mse; + size_t total_rows_sampled = 0; + for (int64_t r : rows_sample) { total_rows_sampled += r; } + + if (valid_wce && !val_vec.empty() && !act_vec.empty()) { + ref_wce.row_sq_norm.reserve(total_rows_sampled); + + size_t off = 0; + for (int64_t s = 0; s < ne2; ++s) { + const int64_t rs = rows_sample[s]; + if (rs == 0) { continue; } + const float * v = val_vec.data() + s * n_per_row; + + for (int64_t r = 0; r < rs; ++r) { + const float * wx = f32_sample.data() + off; + double norm_x = 0.0; + for (int64_t j = 0; j < n_per_row; ++j) { + const double w = v ? std::max(0.0f, v[j]) : 1.0; + norm_x += (double)wx[j] * wx[j] * w; + } + ref_wce.row_sq_norm.push_back(norm_x); + off += n_per_row; + } + } + } else { + // Precompute MSE reference stats (row_sq_norm and bias_denominator) to avoid recalculation per candidate + ref_mse.row_sq_norm.reserve(total_rows_sampled); + ref_mse.bias_denominator.assign(ne2, 0.0); + const bool has_acts = !act_vec.empty(); + const bool has_vals = !val_vec.empty(); + + // Bias Denominators + if (has_acts) { + for (int64_t s = 0; s < ne2; ++s) { + const float * v = has_vals ? val_vec.data() + s * n_per_row : nullptr; + const float * a = act_vec.data() + s * n_per_row; + double denom = 0.0; + for (int64_t j = 0; j < n_per_row; ++j) { + const double w = v ? std::max(0.0f, v[j]) : 1.0; + const double aj = a[j]; + denom += w * aj * aj; + } + ref_mse.bias_denominator[s] = denom; + } + } + + // Row Squared Norms + size_t off = 0; + for (int64_t s = 0; s < ne2; ++s) { + const int64_t rs = rows_sample[s]; + const float * v = has_vals ? val_vec.data() + s * n_per_row : nullptr; + for (int64_t r = 0; r < rs; ++r) { + const float * x = f32_sample.data() + off; + double sum = 0.0; + for (int64_t j = 0; j < n_per_row; ++j) { + double xx = x[j]; + sum += (v ? std::max(0.0f, v[j]) : 1.0) * xx * xx; + } + ref_mse.row_sq_norm.push_back(sum); + off += (size_t)n_per_row; + } + } } // Build candidates @@ -1328,6 +1512,7 @@ static std::unordered_map target_bpw_type( ch.w = tw; ch.n_elements = ggml_nelements(tensor); bool bias_needed = false; + if (!valid_wce && !slice_lambdas.empty()) { // Determine if bias correction is required double best_mse = INFINITE; double max_rel_bias = 0.0; @@ -1731,6 +1916,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } const std::unordered_map> * values_data = nullptr; const std::unordered_map> * activations_data = nullptr; + const std::unordered_map> * statistics_data = nullptr; if (params->imatrix) { values_data = static_cast>*>(params->imatrix); if (values_data) { @@ -1761,6 +1947,12 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } } + if (params->statistics) { + statistics_data = static_cast>*>(params->statistics); + if (statistics_data) { + LLAMA_LOG_INFO(" and %d statistics", int(statistics_data->size())); + } + } LLAMA_LOG_INFO("\n"); gguf_context_ptr ctx_out { gguf_init_empty() }; @@ -1899,11 +2091,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } else { LLAMA_LOG_INFO("%s: imatrix does not have activations, process may be less accurate\n", __func__); } + if (params->statistics) { + LLAMA_LOG_INFO("%s: imatrix has statistics\n", __func__); + } if (params->ignore_tensor_importance) { LLAMA_LOG_INFO("%s: distributing budget equitably across all tensors\n", __func__); } else { LLAMA_LOG_INFO("%s: assigning more budget to important tensors\n", __func__); } + if (params->use_wce) { + LLAMA_LOG_INFO("%s: using experimental Entropy-Modulated Weighted Cosine Error (WCE) approximation optimization\n", __func__); + } else { + LLAMA_LOG_INFO("%s: using weighted Mean Squared Error (MSE) optimization\n", __func__); + } if (params->target_size >= 0) { LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve file size %.2f MiB\n", __func__, (double)params->target_size / 1024.0 / 1024.0); @@ -1911,7 +2111,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw); } - bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, params, nthread); + // get quantization type overrides targeting a given bits per weight budget + bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, statistics_data, params, nthread); } else { LLAMA_LOG_WARN("%s: --target-bpw/--target-size require an imatrix but none was provided, ignoring\n", __func__); } @@ -2170,6 +2371,7 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.keep_split =*/ false, /*.imatrix =*/ nullptr, /*.activations =*/ nullptr, + /*.statistics =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.tensor_type =*/ nullptr, /*.prune_layers =*/ nullptr, @@ -2177,7 +2379,8 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.target_size =*/ -1, /*.save_state =*/ false, /*.state_file =*/ nullptr, - /*.ignore_tensor_importance =*/ false + /*.ignore_tensor_importance =*/ false, + /*.use_wce =*/ false }; return result;