diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index cae908803b..1677b242d9 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -725,7 +725,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const float * activations_sample,
         std::vector<uint8_t> & quantized_buffer,
         std::vector<float> & dequantized_buffer,
-        float bias_lambda) -> double
+        float bias_lambda,
+        double * out_mse = nullptr,
+        double * out_proj = nullptr) -> double
     {
         const int64_t n_per_row = t->ne[0];
         const int64_t nrows = t->ne[1];
@@ -733,13 +735,23 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
 
         const size_t sample_element_count = f32_sample.size();
         const size_t sample_row_count = n_per_row > 0 ? sample_element_count / (size_t)n_per_row : 0;
 
-        if (sample_row_count == 0) { return 0.0; }
+        if (sample_row_count == 0) {
+            if (out_mse) { *out_mse = 0.0; }
+            if (out_proj) { *out_proj = 0.0; }
+
+            return 0.0;
+        }
 
         size_t expected_rows = 0;
         for (int64_t s = 0; s < ne2; ++s) { expected_rows += (size_t)sample_rows_per_slice[s]; }
 
-        if (expected_rows != sample_row_count) { return infinity; }
+        if (expected_rows != sample_row_count) {
+            if (out_mse) { *out_mse = infinity; }
+            if (out_proj) { *out_proj = 0.0; }
+
+            return infinity;
+        }
 
         const size_t row_sz = ggml_row_size(quant_type, n_per_row);
         const size_t buffer_sz = row_sz * sample_row_count;
@@ -750,7 +762,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const bool has_values = values_sample != nullptr;
         const bool has_activations = activations_sample != nullptr;
 
-        // Bias denominators per slice (only needed if we have activations)
+        // Bias denominators per slice
         std::vector<double> bias_denominator_per_slice(ne2, 0.0);
         if (has_activations) {
             for (int64_t s = 0; s < ne2; ++s) {
@@ -815,7 +827,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         // quantized_buffer -> dequantized_buffer
         {
             const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
-            const bool is_fp16 = quant_type == GGML_TYPE_F16;
             const bool is_bf16 = quant_type == GGML_TYPE_BF16;
 
             if (!is_fp16 && !is_bf16 && traits && traits->to_float) {
@@ -825,12 +836,19 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                     uint8_t * src = quantized_buffer.data() + r * row_sz;
                     float * dst = dequantized_buffer.data() + r * (size_t) n_per_row;
                     if (is_fp16) {
-                        ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, (int)n_per_row);
-                    } else if (is_bf16) {
-                        ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, (int)n_per_row);
-                    } else {
-                        if (!traits || !traits->to_float) { return infinity; }
-                        traits->to_float(src, dst, (int)n_per_row);
+                        ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, (int) n_per_row);
+                    }
+                    else if (is_bf16) {
+                        ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, (int) n_per_row);
+                    }
+                    else {
+                        if (!traits || !traits->to_float) {
+                            if (out_mse) { *out_mse = infinity; }
+                            if (out_proj) { *out_proj = 0.0; }
+
+                            return infinity;
+                        }
+                        traits->to_float(src, dst, (int) n_per_row);
                     }
                 }
             }
@@ -839,8 +857,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         // Compute error
        size_t offset = 0;
         size_t row_idx = 0;
-        double total_err = 0.0;
-
+        double total_mse = 0.0;
+        double total_proj = 0.0;
         for (int64_t slice = 0; slice < ne2; ++slice) {
             const int64_t rs = sample_rows_per_slice[slice];
             if (rs == 0) { continue; }
@@ -848,7 +866,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             const float * values = has_values ? values_sample + slice * n_per_row : nullptr;
             const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr;
             const double bias_denom = has_activations ? bias_denominator_per_slice[slice] : 0.0;
-            double slice_err = 0.0;
+            std::vector<double> row_mse_norm;
+            std::vector<double> row_proj_norm;
+            row_mse_norm.reserve(rs);
+            if (activations) { row_proj_norm.reserve(rs); }
+
             for (int64_t r = 0; r < rs; ++r, ++row_idx) {
                 const float * x = f32_sample.data() + offset;
                 const float * y = dequantized_buffer.data() + offset;
@@ -868,13 +890,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                         const double e = y[j] - x[j];
                         weighted_mse += w * e * e;
                     }
-                } else if (activations) {
-                    for (int64_t j = 0; j < n_per_row; ++j) {
-                        const double e = y[j] - x[j];
-                        const double a = activations[j];
-                        weighted_mse += e * e;
-                        bias_num += e * a;
-                    }
                 } else {
                     for (int64_t j = 0; j < n_per_row; ++j) {
                         const double e = y[j] - x[j];
@@ -882,28 +897,64 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                     }
                 }
 
-                double err_num = weighted_mse;
-                if (activations && bias_lambda != 0.0f) {
+                const double denom_x = row_sq_norm[row_idx];
+                double m_norm = weighted_mse / (denom_x + epsilon);
+                row_mse_norm.push_back(std::isfinite(m_norm) ? m_norm : infinity);
+
+                if (activations) {
+                    double p_norm = 0.0;
                     if (bias_denom > 0.0) {
                         const double proj = bias_num * bias_num / (bias_denom + epsilon);
-                        err_num += bias_lambda * proj;
+                        p_norm = std::isfinite(proj) ? proj : 0.0;
                     }
+                    row_proj_norm.push_back(p_norm);
                 }
-
-                const double denom = row_sq_norm[row_idx] + epsilon;
-                slice_err += err_num / denom;
 
                 offset += (size_t)n_per_row;
             }
+            // Trimmed sum to avoid outlier rows dominating the results
+            auto trimmed_sum = [&](std::vector<double> & v) -> double {
+                if (v.empty()) { return 0.0; }
+                const int64_t n = (int64_t)v.size();
+                if (n < 50) {
+                    double s = 0.0;
+                    for (const double z : v) { s += z; }
+                    return s;
+                }
+
+                int64_t k = (int64_t) std::floor(0.02 * (double)n); // trim 2% on each side
+                k = std::max<int64_t>(0, std::min(k, n / 32));      // but not more than 3.125%
+                std::nth_element(v.begin(), v.begin() + k, v.end());
+                std::nth_element(v.begin() + k, v.begin() + (n - k), v.end());
+                double s = 0.0;
+                for (int64_t i = k; i < n - k; ++i) {
+                    s += v[i];
+                }
+
+                return s;
+            };
+
             const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
-            total_err += slice_err * scale_rows;
-            if (!std::isfinite(total_err)) { return infinity; }
+
+            total_mse += trimmed_sum(row_mse_norm) * scale_rows;
+            if (activations) { total_proj += trimmed_sum(row_proj_norm) * scale_rows; }
+
+            if (!std::isfinite(total_mse) || !std::isfinite(total_proj)) {
+                if (out_mse) { *out_mse = infinity; }
+                if (out_proj) { *out_proj = 0.0; }
+
+                return infinity;
+            }
         }
 
+        if (out_mse) { *out_mse = total_mse; }
+        if (out_proj) { *out_proj = total_proj; }
+
+        const double total_err = total_mse + bias_lambda * total_proj;
         return std::isfinite(total_err) ? total_err : infinity;
     };
 
-    // Higher precision but much longer to compute
+    // Higher precision but longer to compute
    auto precise_lambda = [&](const ggml_tensor * t,
         const std::vector<float> & f32_sample,
         const std::vector<int64_t> & sample_rows_per_slice,
@@ -936,22 +987,17 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const int64_t n_per_row = t->ne[0];
         const size_t total_sampled_rows = f32_sample.size() / n_per_row;
         size_t max_row_sz = 0;
-        for (auto pt : probes) {
-            max_row_sz = std::max(max_row_sz, ggml_row_size(pt, n_per_row));
-        }
+        for (auto pt : probes) max_row_sz = std::max(max_row_sz, ggml_row_size(pt, n_per_row));
         std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
         std::vector<float> dequantized_buffer(f32_sample.size());
+
         std::vector<double> ratios; ratios.reserve(probes.size());
         for (const auto pt : probes) {
-            // err at lambda=0 => pure weighted MSE part
-            double err0 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 0.0f);
-            // err at lambda=1 => weighted MSE + projection penalty
-            const double err1 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 1.0f);
-
-            const double p = std::max(0.0, err1 - err0); // projection term contribution
-            const double m = std::max(0.0, err0); // MSE term contribution
+            double m = 0.0;
+            double p = 0.0;
+            (void)estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 0.0f, &m, &p);
             if (p > epsilon && std::isfinite(m) && std::isfinite(p)) {
                 ratios.push_back(m / p);
             }
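
Two things in this refactor are worth spelling out. First, the tensor error is affine in the penalty weight, total_err(lambda) = total_mse + lambda * total_proj, which is why a single estimate_error call at lambda = 0 with the new out_mse/out_proj parameters recovers both components that the old probe loop reconstructed from two full quantize/dequantize passes (err0 and err1 - err0). Second, each slice is now accumulated with a trimmed sum, so a handful of pathological rows can no longer dominate a tensor's score. The trimming logic can be exercised on its own; the following is a minimal standalone harness that mirrors trimmed_sum from the patch (hypothetical test code, not part of the change):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Sum v with the k smallest and k largest entries dropped, where
    // k = min(floor(0.02 * n), n / 32), as in the patch. std::nth_element
    // partitions in O(n) on average instead of the O(n log n) of a full sort.
    static double trimmed_sum(std::vector<double> & v) {
        if (v.empty()) { return 0.0; }
        const int64_t n = (int64_t) v.size();
        if (n < 50) { // too few rows for trimming to be meaningful: plain sum
            double s = 0.0;
            for (const double z : v) { s += z; }
            return s;
        }

        int64_t k = (int64_t) std::floor(0.02 * (double) n); // trim 2% on each side
        k = std::max<int64_t>(0, std::min(k, n / 32));       // but not more than 3.125%
        std::nth_element(v.begin(), v.begin() + k, v.end());           // k smallest in front
        std::nth_element(v.begin() + k, v.begin() + (n - k), v.end()); // k largest at the back
        double s = 0.0;
        for (int64_t i = k; i < n - k; ++i) { s += v[i]; }

        return s;
    }

    int main() {
        std::vector<double> v(100, 1.0);
        v[7] = 1e9; // one outlier row error
        // n = 100 gives k = 2, so the outlier is trimmed away:
        // this prints 96.0 (96 kept rows) instead of ~1e9.
        std::printf("%.1f\n", trimmed_sum(v));
    }

Because nth_element only partitions, the kept range [k, n - k) stays unsorted, which is fine here since only its sum is needed.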
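
The per-row values feeding those trimmed sums follow the formulas visible above: the (optionally importance-weighted) squared row error is normalized by the row's squared norm, and the bias penalty is the squared projection of the error onto the mean activations, normalized by a per-slice denominator. A toy single-row sketch of both terms (all numbers hypothetical; bias_denominator_per_slice is stood in by ||a||^2, one plausible choice, since its construction lies outside this hunk):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        const double epsilon = 1e-12;               // guard against zero denominators (assumed value)
        std::vector<double> x = { 1.0, -2.0, 0.5 }; // reference f32 row
        std::vector<double> y = { 0.9, -2.1, 0.6 }; // dequantized row
        std::vector<double> a = { 0.3,  0.1, 0.7 }; // mean activations for this slice

        double weighted_mse = 0.0, bias_num = 0.0, denom_x = 0.0, bias_denom = 0.0;
        for (std::size_t j = 0; j < x.size(); ++j) {
            const double e = y[j] - x[j];
            weighted_mse += e * e;       // unweighted here; the patch applies values[j] when present
            bias_num     += e * a[j];    // alignment of the error with the activations
            denom_x      += x[j] * x[j]; // row_sq_norm[row_idx] in the patch
            bias_denom   += a[j] * a[j]; // stand-in for bias_denominator_per_slice[slice]
        }

        const double m_norm = weighted_mse / (denom_x + epsilon);           // row MSE term
        const double p_norm = bias_num * bias_num / (bias_denom + epsilon); // row projection term
        std::printf("m_norm=%g p_norm=%g\n", m_norm, p_norm);
    }

In the patch these per-row values land in row_mse_norm and row_proj_norm, are trim-summed per slice, scaled by nrows / rs to extrapolate from the sampled rows to the whole tensor, and only combined as total_mse + bias_lambda * total_proj at the very end.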