diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 67de29df87..b3e4b3cbf7 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -737,10 +737,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const int64_t n_per_row = t->ne[0];
         const int64_t nrows = t->ne[1];
         const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
-        const size_t sample_element_count = f32_sample.size();
-        const size_t sample_row_count = n_per_row > 0 ? sample_element_count / (size_t)n_per_row : 0;
-        if (sample_row_count == 0) {
+        const size_t sample_elems = f32_sample.size();
+        const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0;
+
+        if (sample_rows == 0) {
             if (out_mse) { *out_mse = 0.0; }
             if (out_proj) { *out_proj = 0.0; }
             return 0.0;
         }
@@ -751,100 +752,99 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
            expected_rows += (size_t)rows_sample[s];
        }

-        if (expected_rows != sample_row_count) {
+        if (expected_rows != sample_rows) {
            if (out_mse) { *out_mse = infinity; }
            if (out_proj) { *out_proj = 0.0; }
            return infinity;
        }

        const size_t row_sz = ggml_row_size(quant_type, n_per_row);
-        const size_t buffer_sz = row_sz * sample_row_count;
+        const size_t buf_sz = row_sz * sample_rows;

-        if (quantized_buffer.size() < buffer_sz) { quantized_buffer.resize(buffer_sz); }
-        if (dequantized_buffer.size() < sample_element_count) { dequantized_buffer.resize(sample_element_count); }
+        if (quantized_buffer.size() < buf_sz) { quantized_buffer.resize(buf_sz); }
+        if (dequantized_buffer.size() < sample_elems) { dequantized_buffer.resize(sample_elems); }

        const bool has_values = values_sample != nullptr;
        const bool has_activations = activations_sample != nullptr;

        // Bias denominators per slice
-        std::vector<double> bias_denominator_per_slice(ne2, 0.0);
+        std::vector<double> bias_denom(ne2, 0.0);
        if (has_activations) {
            for (int64_t s = 0; s < ne2; ++s) {
-                const float * values = has_values ? values_sample + s * n_per_row : nullptr;
-                const float * activations = activations_sample + s * n_per_row;
+                const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+                const float * a = activations_sample + s * n_per_row;
                double denom = 0.0;
                for (int64_t j = 0; j < n_per_row; ++j) {
-                    const double w = values ? std::max(0.0f, values[j]) : 1.0;
-                    const double a = activations[j];
-                    denom += w * a * a;
+                    const double w = v ? std::max(0.0f, v[j]) : 1.0;
+                    const double aj = a[j];
+                    denom += w * aj * aj;
                }
-                bias_denominator_per_slice[s] = denom;
+                bias_denom[s] = denom;
            }
        }

-        // Weighted per-row squared norms
-        std::vector<double> row_sq_norm(sample_row_count, 0.0);
+        // Row squared norms (weighted if values present)
+        std::vector<double> row_sq_norm(sample_rows, 0.0);
        {
-            size_t offset = 0;
-            size_t row_idx = 0;
+            size_t off = 0;
+            size_t ridx = 0;
            for (int64_t s = 0; s < ne2; ++s) {
                const int64_t rs = rows_sample[s];
                if (rs == 0) { continue; }
-                const float * values = has_values ? values_sample + s * n_per_row : nullptr;
-                for (int64_t r = 0; r < rs; ++r, ++row_idx) {
-                    const float * x = f32_sample.data() + offset;
-                    double rsn = 0.0;
-                    if (values) {
+                const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+                for (int64_t r = 0; r < rs; ++r, ++ridx) {
+                    const float * x = f32_sample.data() + off;
+                    double sum = 0.0;
+                    if (v) {
                        for (int64_t j = 0; j < n_per_row; ++j) {
-                            const double w = std::max(0.0f, values[j]);
+                            const double w = std::max(0.0f, v[j]);
                            const double xx = x[j];
-                            rsn += w * xx * xx;
+                            sum += w * xx * xx;
                        }
                    } else {
                        for (int64_t j = 0; j < n_per_row; ++j) {
                            const double xx = x[j];
-                            rsn += xx * xx;
+                            sum += xx * xx;
                        }
                    }
-                    row_sq_norm[row_idx] = rsn;
-                    offset += (size_t)n_per_row;
+
+                    row_sq_norm[ridx] = sum;
+                    off += (size_t)n_per_row;
                }
            }
        }

-        // Quantize sampled rows per slice -> quantized_buffer
+        // Quantize per slice into quantized_buffer
        {
-            size_t q_offset = 0;
-            size_t f_offset = 0;
-            for (int64_t slice = 0; slice < ne2; ++slice) {
-                const int64_t rs = rows_sample[slice];
+            size_t qoff = 0;
+            size_t foff = 0;
+            for (int64_t s = 0; s < ne2; ++s) {
+                const int64_t rs = rows_sample[s];
                if (rs == 0) { continue; }
-                const float * value = has_values ? values_sample + slice * n_per_row : nullptr;
-                (void)ggml_quantize_chunk(quant_type, f32_sample.data() + f_offset, quantized_buffer.data() + q_offset, 0, rs, n_per_row, value);
-                q_offset += row_sz * (size_t)rs;
-                f_offset += (size_t)rs * (size_t)n_per_row;
+                const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+                (void)ggml_quantize_chunk(quant_type, f32_sample.data() + foff, quantized_buffer.data() + qoff, 0, rs, n_per_row, v);
+                qoff += row_sz * (size_t)rs;
+                foff += (size_t)rs * (size_t)n_per_row;
            }
        }

-        // quantized_buffer -> dequantized_buffer
+        // Dequantize into dequantized_buffer
        {
            const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
-            const bool is_fp16 = quant_type == GGML_TYPE_F16;
-            const bool is_bf16 = quant_type == GGML_TYPE_BF16;
-            if (!is_fp16 && !is_bf16 && traits && traits->to_float) {
-                traits->to_float(quantized_buffer.data(), dequantized_buffer.data(), (int)(sample_row_count * (size_t)n_per_row));
+            if (traits && traits->to_float && quant_type != GGML_TYPE_F16 && quant_type != GGML_TYPE_BF16) {
+                traits->to_float(quantized_buffer.data(), dequantized_buffer.data(), (int)(sample_rows * (size_t)n_per_row));
            } else {
-                for (size_t r = 0; r < sample_row_count; ++r) {
-                    uint8_t * src = quantized_buffer.data() + r * row_sz;
+                for (size_t r = 0; r < sample_rows; ++r) {
+                    const uint8_t * src = quantized_buffer.data() + r * row_sz;
                    float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
-                    if (is_fp16) {
+                    if (quant_type == GGML_TYPE_F16) {
                        ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row);
-                    } else if (is_bf16) {
+                    } else if (quant_type == GGML_TYPE_BF16) {
                        ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row);
                    } else {
                        if (!traits || !traits->to_float) {
                            if (out_mse) { *out_mse = infinity; }
                            if (out_proj) { *out_proj = 0.0; }
                            return infinity;
                        }
                        traits->to_float(src, dst, (int)n_per_row);
@@ -858,93 +858,77 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
            }
        }

-        // Compute error
-        size_t offset = 0;
-        size_t row_idx = 0;
+        // Compute error per slice with trimmed aggregation
+        auto trimmed_sum = [&](std::vector<double> & v) -> double {
+            const int64_t n = (int64_t)v.size();
+            if (n == 0) { return 0.0; }
+            if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); }
+            int64_t k = (int64_t)std::floor(0.02 * (double)n); // trim 2% on each side
+            k = std::clamp<int64_t>(k, 0, n / 32);             // but no more than ~3.125%
+            std::nth_element(v.begin(), v.begin() + k, v.end());
+            std::nth_element(v.begin() + k, v.begin() + (n - k), v.end());
+            return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
+        };
+
+        size_t off = 0;
+        size_t ridx = 0;
        double total_mse = 0.0;
        double total_proj = 0.0;
        double total_bias = 0.0;

-        for (int64_t slice = 0; slice < ne2; ++slice) {
-            const int64_t rs = rows_sample[slice];
+        for (int64_t s = 0; s < ne2; ++s) {
+            const int64_t rs = rows_sample[s];
            if (rs == 0) { continue; }
-            const float * values = has_values ? values_sample + slice * n_per_row : nullptr;
-            const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr;
-            const double bias_denom = has_activations ? bias_denominator_per_slice[slice] : 0.0;
+            const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+            const float * a = has_activations ? activations_sample + s * n_per_row : nullptr;
+            const double denom_bias = has_activations ? bias_denom[s] : 0.0;

            std::vector<double> row_mse_norm;
-            std::vector<double> row_proj_norm;
            row_mse_norm.reserve(rs);
-            if (activations) { row_proj_norm.reserve(rs); }
+            std::vector<double> row_proj_norm;
+            if (a) { row_proj_norm.reserve(rs); }

-            for (int64_t r = 0; r < rs; ++r, ++row_idx) {
-                const float * x = f32_sample.data() + offset;
-                const float * y = dequantized_buffer.data() + offset;
-                double weighted_mse = 0.0;
+            for (int64_t r = 0; r < rs; ++r, ++ridx) {
+                const float * x = f32_sample.data() + off;
+                const float * y = dequantized_buffer.data() + off;
+                double w_mse = 0.0;
                double bias_num = 0.0;
-                if (values && activations) {
-                    for (int64_t j = 0; j < n_per_row; ++j) {
-                        const double w = std::max(0.0f, values[j]);
-                        const double e = y[j] - x[j];
-                        const double a = activations[j];
-                        weighted_mse += w * e * e;
-                        bias_num += w * e * a;
-                    }
-                } else if (values) {
-                    for (int64_t j = 0; j < n_per_row; ++j) {
-                        const double w = std::max(0.0f, values[j]);
-                        const double e = y[j] - x[j];
-                        weighted_mse += w * e * e;
-                    }
-                } else {
-                    for (int64_t j = 0; j < n_per_row; ++j) {
-                        const double e = y[j] - x[j];
-                        weighted_mse += e * e;
-                    }
+                for (int64_t j = 0; j < n_per_row; ++j) {
+                    const double wj = v ? std::max(0.0f, v[j]) : 1.0;
+                    const double e = y[j] - x[j];
+                    w_mse += wj * e * e;
+                    if (a) { bias_num += wj * e * a[j]; }
                }

-                const double denom_x = row_sq_norm[row_idx];
-                double m_norm = weighted_mse / (denom_x + epsilon);
+                const double denom_x = row_sq_norm[ridx];
+                const double m_norm = w_mse / (denom_x + epsilon);
                row_mse_norm.push_back(std::isfinite(m_norm) ? m_norm : infinity);

-                if (activations) {
+                if (a) {
                    double p_norm = 0.0;
-                    if (bias_denom > 0.0) {
-                        const double proj = bias_num * bias_num / (bias_denom + epsilon);
+                    if (denom_bias > 0.0) {
+                        const double proj = bias_num * bias_num / (denom_bias + epsilon);
                        p_norm = std::isfinite(proj) ? proj : 0.0;
                    }
+                    row_proj_norm.push_back(p_norm);
                }

-                offset += (size_t)n_per_row;
+                off += (size_t)n_per_row;
            }

-            // Trimmed sum to avoid outlier rows dominating the results
-            auto trimmed_sum = [&](std::vector<double> & v) -> double {
-                const int64_t n = (int64_t)v.size();
-                if (n == 0) { return 0.0; }
-                if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); }
-
-                int64_t k = (int64_t)std::floor(0.02 * (double)n); // trim 2% each side
-                k = std::clamp(k, 0, n / 32);                      // cap at ~3.125%
-                std::nth_element(v.begin(), v.begin() + k, v.end());
-                std::nth_element(v.begin() + k, v.begin() + (n - k), v.end());
-                return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
-            };
-
            const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
            const double slice_mse = trimmed_sum(row_mse_norm) * scale_rows;
-            const double slice_proj = activations ? trimmed_sum(row_proj_norm) * scale_rows : 0.0;
+            const double slice_proj = a ? trimmed_sum(row_proj_norm) * scale_rows : 0.0;

            total_mse += slice_mse;
            total_proj += slice_proj;

-            // per-slice lambda if provided, otherwise use scalar
-            const double bl = slice_bias_lambda ? (double)std::max(0.0f, slice_bias_lambda[slice]) : (double)tensor_bias_lambda;
+            const double bl = slice_bias_lambda ? (double)std::max(0.0f, slice_bias_lambda[s]) : (double)tensor_bias_lambda;
            total_bias += bl * slice_proj;

            if (!std::isfinite(total_mse) || !std::isfinite(total_proj) || !std::isfinite(total_bias)) {
                if (out_mse) { *out_mse = infinity; }
                if (out_proj) { *out_proj = 0.0; }
                return infinity;
            }
        }
@@ -954,5 +938,5 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
        if (out_proj) { *out_proj = total_proj; }

        const double total_err = slice_bias_lambda ? total_mse + total_bias : total_mse + tensor_bias_lambda * total_proj;
        return std::isfinite(total_err) ? total_err : infinity;
    };
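
A few notes on what the refactored code computes, with standalone sketches. All helper names and `main` harnesses below are illustrative, not part of the patch.

First, the error model measures a quantize/dequantize round-trip per slice. A minimal sketch of that round-trip, assuming the same ggml API the patch itself uses (`ggml_row_size`, `ggml_quantize_chunk`, `ggml_get_type_traits`); F16/BF16 would take the dedicated `ggml_fp16_to_fp32_row`/`ggml_bf16_to_fp32_row` paths exactly as the patch does:

```cpp
// Sketch of the quantize -> dequantize round-trip measured above.
// Assumes ggml headers/library are available; error handling elided.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

#include "ggml.h"

int main() {
    const int64_t n_per_row = 256; // Q4_K rows must be a multiple of 256
    const int64_t nrows     = 2;

    std::vector<float> src(n_per_row * nrows);
    for (size_t i = 0; i < src.size(); ++i) { src[i] = (float)std::sin(0.1 * (double)i); }

    const ggml_type qt     = GGML_TYPE_Q4_K;
    const size_t    row_sz = ggml_row_size(qt, n_per_row);

    std::vector<uint8_t> q(row_sz * nrows);
    std::vector<float>   deq(src.size());

    // nullptr imatrix -> unweighted quantization, as in the no-values path
    (void)ggml_quantize_chunk(qt, src.data(), q.data(), 0, nrows, n_per_row, nullptr);

    const ggml_type_traits * traits = ggml_get_type_traits(qt);
    traits->to_float(q.data(), deq.data(), (int64_t)src.size());

    std::printf("x[0] = %f -> y[0] = %f\n", src[0], deq[0]);
}
```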
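The per-row MSE term is normalized by the row's (weighted) squared norm, which keeps it scale-free and comparable across tensors. A self-contained sketch of that term (helper name and epsilon value are assumptions):

```cpp
// Sketch of the per-row relative error term computed above:
//   m_norm = sum_j w_j*(y_j - x_j)^2 / (sum_j w_j*x_j^2 + eps)
// with w_j = max(0, values[j]) when imatrix values exist, else 1.
#include <algorithm>
#include <cstddef>
#include <cstdio>

static double row_m_norm(const float * x, const float * y,
                         const float * values, size_t n, double eps = 1e-12) {
    double num = 0.0;
    double den = 0.0;
    for (size_t j = 0; j < n; ++j) {
        const double w = values ? std::max(0.0f, values[j]) : 1.0;
        const double e = (double)y[j] - (double)x[j];
        num += w * e * e;                       // weighted squared error
        den += w * (double)x[j] * (double)x[j]; // weighted squared norm of x
    }
    return num / (den + eps);
}

int main() {
    const float x[4] = {1.0f, -2.0f, 3.0f, 0.5f};
    const float y[4] = {1.1f, -1.9f, 2.9f, 0.6f}; // pretend dequantized row
    std::printf("m_norm = %g\n", row_m_norm(x, y, nullptr, 4));
}
```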
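The projection ("bias") term catches systematic error that a plain MSE misses: it squares the weighted correlation between the error and the mean activations, so it is large when the error shifts the output in a consistent direction and near zero when the error looks like uncorrelated noise. The patch hoists the denominator out per slice (`bias_denom[s]`); this sketch recomputes it per row for brevity (names are illustrative):

```cpp
// Sketch of the projection term computed above:
//   p_norm = (sum_j w_j*e_j*a_j)^2 / (sum_j w_j*a_j^2 + eps)
// where e = y - x is the quantization error and a the mean activations.
#include <algorithm>
#include <cstddef>
#include <cstdio>

static double row_p_norm(const float * x, const float * y, const float * values,
                         const float * acts, size_t n, double eps = 1e-12) {
    double num = 0.0;
    double den = 0.0;
    for (size_t j = 0; j < n; ++j) {
        const double w = values ? std::max(0.0f, values[j]) : 1.0;
        const double e = (double)y[j] - (double)x[j];
        num += w * e * (double)acts[j];
        den += w * (double)acts[j] * (double)acts[j];
    }
    return den > 0.0 ? (num * num) / (den + eps) : 0.0;
}

int main() {
    const float x[3] = {1.0f, 2.0f, 3.0f};
    const float y[3] = {1.2f, 2.2f, 3.2f}; // error is consistently +0.2
    const float a[3] = {0.5f, 0.5f, 0.5f};
    std::printf("p_norm = %g\n", row_p_norm(x, y, nullptr, a, 3)); // prints 0.12
}
```

Slice sums of both terms are then scaled by `nrows / rs` to extrapolate from the sampled rows to the whole tensor, and combined as `total_mse + lambda * total_proj`, using the per-slice lambda when `slice_bias_lambda` is provided.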
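Finally, hoisting `trimmed_sum` out of the slice loop (it was previously redefined inside every iteration) is a nice cleanup; the aggregation itself is unchanged. It drops the `k` smallest and `k` largest per-row errors with two `std::nth_element` partial partitions, where `k` is 2% of `n` capped at `n / 32` (about 3.125%), so a handful of outlier rows cannot dominate a slice's score. A standalone sketch (function name is illustrative; it copies its input, unlike the patch's in-place lambda):

```cpp
// Sketch of the trimmed aggregation used above.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

static double trimmed_sum_sketch(std::vector<double> v) {
    const int64_t n = (int64_t)v.size();
    if (n == 0) { return 0.0; }
    if (n < 50) { return std::accumulate(v.begin(), v.end(), 0.0); } // too few rows to trim
    int64_t k = (int64_t)std::floor(0.02 * (double)n); // trim 2% on each side
    k = std::clamp<int64_t>(k, 0, n / 32);             // cap at n/32 (~3.125%)
    // Partition so v[0..k) holds the k smallest and v[n-k..n) the k largest;
    // neither tail needs to be fully sorted just to exclude it from the sum.
    std::nth_element(v.begin(), v.begin() + k, v.end());
    std::nth_element(v.begin() + k, v.begin() + (n - k), v.end());
    return std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
}

int main() {
    std::vector<double> rows(100, 1.0);
    rows[0] = 1e9; // a single outlier row
    // 100 rows -> k = 2: the outlier lands in the trimmed top tail,
    // so the result is the sum of the middle 96 values, i.e. 96.
    std::printf("trimmed = %g\n", trimmed_sum_sketch(rows));
}
```

Note that `std::clamp` deduces a single type for all three arguments, hence the explicit `std::clamp<int64_t>` when mixing an `int64_t` with a literal `0`.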