Minor factoring for efficiency and correctness

Ed Addario 2025-08-30 10:14:46 +01:00
parent 556f6b04fe
commit eab8708244
1 changed file with 60 additions and 66 deletions

@@ -596,7 +596,7 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float *
     return new_size;
 }
 
-// Returns per-tensor type overrides to meet target BPW at lowest error
+// Returns tensor type overrides to meet a global bpw target
 static std::unordered_map<std::string, ggml_type> target_bpw_type(
     llama_model_loader & ml,
     std::vector<no_init<uint8_t>> & buffer,
@@ -650,6 +650,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
     };
 
     constexpr double epsilon = 1e-12;
+    constexpr double infinity = std::numeric_limits<double>::infinity();
 
     auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
         const int64_t n_per_row = t->ne[0];
@@ -680,7 +681,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
     auto name_tn = LLM_TN(model.arch);
 
     auto can_quantize = [&](const ggml_tensor * t) -> bool {
-        // This list should be kept in sync with llama_tensor_quantize_impl()
+        // This list should be kept in sync with llama_tensor_quantize_impl() to avoid drift
         const std::string name = ggml_get_name(t);
         bool q = name.rfind("weight") == name.size() - 6;
         q &= ggml_n_dims(t) >= 2;
@@ -730,9 +731,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
         const size_t sample_element_count = f32_sample.size();
-        const size_t sample_row_count = sample_element_count / (size_t)n_per_row;
+        const size_t sample_row_count = n_per_row > 0 ? sample_element_count / (size_t)n_per_row : 0;
         if (sample_row_count == 0) { return 0.0; }
 
+        size_t expected_rows = 0;
+        for (int64_t s = 0; s < ne2; ++s) {
+            expected_rows += (size_t)sample_rows_per_slice[s];
+        }
+        if (expected_rows != sample_row_count) { return infinity; }
+
         const size_t row_sz = ggml_row_size(quant_type, n_per_row);
         const size_t buffer_sz = row_sz * sample_row_count;
@@ -750,15 +757,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 const float * activations = activations_sample + s * n_per_row;
                 double denom = 0.0;
                 for (int64_t j = 0; j < n_per_row; ++j) {
+                    const double w = values ? std::max(0.0f, values[j]) : 1.0;
                     const double a = activations[j];
-                    const double w = values ? values[j] : 1.0;
                     denom += w * a * a;
                 }
                 bias_denominator_per_slice[s] = denom;
             }
         }
 
-        // Compute per-row squared norms with weighting (if values are provided)
+        // Per-row squared norms with weighting
         std::vector<double> row_sq_norm(sample_row_count, 0.0);
         {
             size_t offset = 0;
@@ -768,15 +775,14 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 if (rs == 0) { continue; }
                 const float * values = has_values ? values_sample + s * n_per_row : nullptr;
                 for (int64_t r = 0; r < rs; ++r, ++row_idx) {
                     const float * x = f32_sample.data() + offset;
                     double rsn = 0.0;
                     if (values) {
                         for (int64_t j = 0; j < n_per_row; ++j) {
-                            const double v = values[j];
+                            const double w = std::max(0.0f, values[j]);
                             const double xx = x[j];
-                            rsn += v * xx * xx;
+                            rsn += w * xx * xx;
                         }
                     } else {
                         for (int64_t j = 0; j < n_per_row; ++j) {
@@ -790,7 +796,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             }
         }
 
-        // Quantize sampled rows slice-by-slice into quantized_buffer
+        // Quantize sampled rows per slice -> quantized_buffer
         {
             size_t q_offset = 0;
             size_t f_offset = 0;
@@ -800,35 +806,32 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 const float * value = has_values ? values_sample + slice * n_per_row : nullptr;
                 (void)ggml_quantize_chunk(quant_type, f32_sample.data() + f_offset, quantized_buffer.data() + q_offset, 0, rs, n_per_row, value);
                 q_offset += row_sz * (size_t)rs;
                 f_offset += (size_t)rs * (size_t)n_per_row;
             }
         }
 
-        // Dequantize into dequantized_buffer
+        // quantized_buffer -> dequantized_buffer
         {
             const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
-            auto row_to_float = [&](size_t r) {
+            const bool is_fp16 = quant_type == GGML_TYPE_F16;
+            const bool is_bf16 = quant_type == GGML_TYPE_BF16;
+            if (!is_fp16 && !is_bf16 && traits && traits->to_float) {
+                traits->to_float(quantized_buffer.data(), dequantized_buffer.data(), (int)(sample_row_count * (size_t)n_per_row));
+            } else {
+                for (size_t r = 0; r < sample_row_count; ++r) {
                     uint8_t * src = quantized_buffer.data() + r * row_sz;
                     float * dst = dequantized_buffer.data() + r * (size_t) n_per_row;
-                    if (quant_type == GGML_TYPE_F16) {
+                    if (is_fp16) {
                         ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, (int)n_per_row);
-                    } else if (quant_type == GGML_TYPE_BF16) {
+                    } else if (is_bf16) {
                         ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, (int)n_per_row);
                     } else {
-                        if (!traits || !traits->to_float) {
-                            LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(quant_type));
-                            return false;
-                        }
+                        if (!traits || !traits->to_float) { return infinity; }
                         traits->to_float(src, dst, (int)n_per_row);
                     }
+                }
-                return true;
-            };
-            for (size_t r = 0; r < sample_row_count; ++r) {
-                if (!row_to_float(r)) { return 1e35; }
             }
         }
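
For readers skimming the hunk above: the refactor replaces the per-row `row_to_float` lambda with a single bulk `to_float` call whenever the type is neither F16 nor BF16 and a conversion trait exists, keeping the per-row loop only as a fallback. A minimal standalone sketch of that dispatch, using only the ggml helpers already visible in the diff (`dequantize_rows` is a hypothetical name, not part of the patch):

#include "ggml.h"

#include <cstddef>
#include <cstdint>

// Dequantize n_rows contiguous rows of `type` from src into dst.
// Returns false when the type has no usable to_float conversion.
static bool dequantize_rows(ggml_type type, const uint8_t * src, float * dst,
                            size_t n_rows, int64_t n_per_row) {
    const ggml_type_traits * traits = ggml_get_type_traits(type);
    const bool is_fp16 = type == GGML_TYPE_F16;
    const bool is_bf16 = type == GGML_TYPE_BF16;
    const size_t row_sz = ggml_row_size(type, n_per_row);

    if (!is_fp16 && !is_bf16 && traits && traits->to_float) {
        // Bulk path: the sampled rows are stored back to back, so one call
        // converts n_rows * n_per_row elements at once.
        traits->to_float(src, dst, (int)(n_rows * (size_t)n_per_row));
        return true;
    }
    for (size_t r = 0; r < n_rows; ++r) {
        const uint8_t * s = src + r * row_sz;
        float * d = dst + r * (size_t)n_per_row;
        if (is_fp16) {
            ggml_fp16_to_fp32_row((const ggml_fp16_t *) s, d, (int)n_per_row);
        } else if (is_bf16) {
            ggml_bf16_to_fp32_row((const ggml_bf16_t *) s, d, (int)n_per_row);
        } else if (traits && traits->to_float) {
            traits->to_float(s, d, (int)n_per_row);
        } else {
            return false;
        }
    }
    return true;
}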
@@ -836,6 +839,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         size_t offset = 0;
         size_t row_idx = 0;
         double total_err = 0.0;
+
         for (int64_t slice = 0; slice < ne2; ++slice) {
             const int64_t rs = sample_rows_per_slice[slice];
             if (rs == 0) { continue; }
@@ -843,9 +847,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             const float * values = has_values ? values_sample + slice * n_per_row : nullptr;
             const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr;
             const double bias_denom = has_activations ? bias_denominator_per_slice[slice] : 0.0;
-
             double slice_err = 0.0;
-
             for (int64_t r = 0; r < rs; ++r, ++row_idx) {
                 const float * x = f32_sample.data() + offset;
                 const float * y = dequantized_buffer.data() + offset;
@@ -853,17 +855,17 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 double bias_num = 0.0;
                 if (values && activations) {
                     for (int64_t j = 0; j < n_per_row; ++j) {
-                        const double v = values[j];
+                        const double w = std::max(0.0f, values[j]);
                         const double e = y[j] - x[j];
                         const double a = activations[j];
-                        weighted_mse += v * e * e;
-                        bias_num += v * e * a;
+                        weighted_mse += w * e * e;
+                        bias_num += w * e * a;
                     }
                 } else if (values) {
                     for (int64_t j = 0; j < n_per_row; ++j) {
-                        const double v = values[j];
+                        const double w = std::max(0.0f, values[j]);
                         const double e = y[j] - x[j];
-                        weighted_mse += v * e * e;
+                        weighted_mse += w * e * e;
                     }
                 } else if (activations) {
                     for (int64_t j = 0; j < n_per_row; ++j) {
@@ -881,26 +883,28 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 double err_num = weighted_mse;
                 if (activations && bias_lambda != 0.0f) {
+                    if (bias_denom > 0.0) {
                         const double proj = bias_num * bias_num / (bias_denom + epsilon);
-                        err_num += (double)bias_lambda * proj;
+                        err_num += bias_lambda * proj;
+                    }
                 }
 
-                const double err_den = row_sq_norm[row_idx] + epsilon;
-                slice_err += err_num / err_den;
+                const double denom = row_sq_norm[row_idx] + epsilon;
+                slice_err += err_num / denom;
                 offset += (size_t)n_per_row;
             }
 
             const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
             total_err += slice_err * scale_rows;
+            if (!std::isfinite(total_err)) { return infinity; }
         }
 
-        return std::isfinite(total_err) ? total_err : 1e35;
+        return std::isfinite(total_err) ? total_err : infinity;
     };
 
+    // Scaling factor to increase lambda when activations are concentrated
     auto directional_scale = [&](const float * values, const float * activations, int64_t n_per_row) {
         if (!activations) { return 1.0f; }
-        // Compute dominance = ||sqrt(v).*a||_2 / (RMS(a)*sqrt(sum(v)))
-        // If no values, use v=1
        double sum_v = 0.0;
        double sum_aw2 = 0.0;
        double sum_a2 = 0.0;
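
The estimate_error lambda above scores a candidate type with a weighted relative error: per row it accumulates an importance-weighted MSE plus, when activations are present, a bias_lambda-scaled projection of the error onto the activations, and normalizes by the weighted squared norm of the original row. A compact sketch of the per-row term under those definitions (`row_error` is a hypothetical helper; `bias_denom` is the per-slice sum of w_j * a_j^2 computed earlier in the function):

#include <algorithm>
#include <cstdint>

// x: original row, y: dequantized row, w: importance weights (clamped >= 0),
// a: activations (may be null); lambda scales the bias-projection penalty.
static double row_error(const float * x, const float * y,
                        const float * w, const float * a,
                        int64_t n, double lambda, double bias_denom) {
    constexpr double epsilon = 1e-12;
    double weighted_mse = 0.0; // sum_j w_j * (y_j - x_j)^2
    double bias_num     = 0.0; // sum_j w_j * (y_j - x_j) * a_j
    double row_sq_norm  = 0.0; // sum_j w_j * x_j^2
    for (int64_t j = 0; j < n; ++j) {
        const double wj = w ? std::max(0.0f, w[j]) : 1.0;
        const double e  = (double)y[j] - (double)x[j];
        weighted_mse += wj * e * e;
        if (a) { bias_num += wj * e * (double)a[j]; }
        row_sq_norm  += wj * (double)x[j] * (double)x[j];
    }
    double err_num = weighted_mse;
    if (a && lambda != 0.0 && bias_denom > 0.0) {
        err_num += lambda * (bias_num * bias_num) / (bias_denom + epsilon);
    }
    return err_num / (row_sq_norm + epsilon); // relative to the weighted row norm
}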
@@ -915,13 +919,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const double denom = std::sqrt(std::max(epsilon, sum_v)) * std::max(epsilon, rms_a);
         const double scale = denom > 0.0 ? std::sqrt(sum_aw2) / denom : 1.0;
-        // Clamp to a reasonable range
         return (float)std::clamp(scale, 0.5, 2.0);
     };
 
-    // Returns an adaptive lambda for this tensor using a small probe set
-    // bias_lambda adjusts the trade-off between systematic bias (introduced by blockwise scaling) and MSE
-    // larger value favours quantisation types that produce smaller bias even if the MSE is slightly larger
+    // Higher precision but much longer to compute
     auto precise_lambda = [&](const ggml_tensor * t,
         const std::vector<float> & f32_sample,
         const std::vector<int64_t> & sample_rows_per_slice,
@@ -929,10 +930,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const float * activations,
         const std::vector<ggml_type> & compatible_candidates) -> float
     {
-        // No activations => no projection term
         if (!activations) { return 0.0f; }
-        // pick a tiny probe set: try to spread around mid-range types
         std::vector<ggml_type> probes;
         probes.reserve(3);
         auto push_if = [&](const ggml_type tiny) {
@@ -941,7 +940,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             }
         };
 
-        // Prefer family-consistent probes; fall back to whatever exists
         push_if(GGML_TYPE_Q4_K);
         push_if(GGML_TYPE_Q3_K);
         push_if(GGML_TYPE_Q5_K);
@@ -953,19 +951,18 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             }
         }
         if (probes.empty()) { return 0.0f; }
 
-        // Scratch buffers (reused)
+        // Scratch buffers
         const int64_t n_per_row = t->ne[0];
         const size_t total_sampled_rows = f32_sample.size() / n_per_row;
         size_t max_row_sz = 0;
         for (auto pt : probes) {
             max_row_sz = std::max(max_row_sz, ggml_row_size(pt, n_per_row));
         }
         std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
         std::vector<float> dequantized_buffer(f32_sample.size());
 
         std::vector<double> ratios;
         ratios.reserve(probes.size());
         for (const auto pt : probes) {
             // err at lambda=0 => pure weighted MSE part
             double err0 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 0.0f);
@@ -984,17 +981,17 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         std::nth_element(ratios.begin(), ratios.begin() + ratios.size() / 2, ratios.end());
         double lambda = ratios[ratios.size() / 2];
 
-        // activations directional scale
         const float scale = directional_scale(values, activations, n_per_row);
         lambda *= scale;
 
-        // clamp to safe range
         lambda = std::clamp(lambda, 0.0, 8.0);
 
         return (float)lambda;
     };
 
+    // Faster to compute but lower precision. Best option for the vast majority of models
     auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row) {
         if (!activations) { return 0.0f; }
         double s = 0.0;
         double s2 = 0.0;
         for (int64_t j = 0; j < n_per_row; ++j) {
@@ -1004,17 +1001,13 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             s += aw2;
             s2 += aw2 * aw2;
         }
 
         if (s2 <= 0.0) { return 0.0f; }
 
         const auto d = (double)n_per_row;
-        //const double p = s * s / (d * s2 + epsilon);
-        //const double lambda = 8.0 * std::clamp(1.0 - p, 0.0, 1.0);
-        // Map p in (0,1] to lambda in [0,8] decreasing
         double base = 1.0 - s * s / (d * s2 + epsilon);
         base = std::clamp(base, 0.0, 1.0);
 
-        // activations directional scale
         const double scale = directional_scale(values, activations, n_per_row);
-        // clamp to safe range
         const double lambda = std::clamp(base * scale, 0.0, 1.0) * 8.0;
 
         return (float)lambda;
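
fast_lambda above derives bias_lambda from how concentrated the importance-weighted activation energy w_j * a_j^2 is across the row: s^2 / (n * s2) is close to 1 when that energy is uniform and close to 1/n when a single column dominates, so the resulting lambda grows with concentration. A sketch of that mapping (`fast_lambda_sketch` and the `dir_scale` parameter are illustrative; the non-negative clamp on values mirrors the convention used elsewhere in this patch and is an assumption for the lines not shown in the hunk):

#include <algorithm>
#include <cstdint>

static float fast_lambda_sketch(const float * w, const float * a, int64_t n, double dir_scale) {
    constexpr double epsilon = 1e-12;
    if (!a) { return 0.0f; }
    double s = 0.0, s2 = 0.0;
    for (int64_t j = 0; j < n; ++j) {
        const double wj  = w ? std::max(0.0f, w[j]) : 1.0;
        const double aw2 = wj * (double)a[j] * (double)a[j];
        s  += aw2;       // first moment of the weighted activation energy
        s2 += aw2 * aw2; // second moment, used to measure concentration
    }
    if (s2 <= 0.0) { return 0.0f; }
    // base is 0 for a uniform energy profile and approaches 1 when the energy
    // concentrates in a few columns, i.e. when bias matters more than MSE.
    double base = 1.0 - s * s / ((double)n * s2 + epsilon);
    base = std::clamp(base, 0.0, 1.0);
    return (float)(std::clamp(base * dir_scale, 0.0, 1.0) * 8.0);
}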
@@ -1036,13 +1029,13 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         }
         ml.load_data_for(t);
 
-        // Dequantize only sampled rows into f32_sample
+        // Dequantize sampled rows into f32_sample
         const int64_t n_per_row = t->ne[0];
         const int64_t nrows_total = t->ne[1];
         const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
-        // Larger sample_rows_per_expert values may result in more accurate error estimates, but will take longer to compute
-        constexpr int sample_rows_per_expert = 384;
+        // Larger sample_rows_per_expert values may result in more accurate error estimates, but it will take much longer to compute
+        constexpr int sample_rows_per_expert = 256;
         std::vector<float> f32_sample;
         f32_sample.reserve((size_t)ne2 * (size_t)std::min<int64_t>(nrows_total, sample_rows_per_expert) * (size_t)n_per_row);
@@ -1096,6 +1089,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const std::string key = remap_imatrix(tensor_name, mapped);
         const auto it = m->find(key);
         if (it == m->end()) { return {nullptr, 0}; }
+
         return { it->second.data(), it->second.size() };
     };
@@ -1104,7 +1098,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const size_t want = (size_t)ne2 * (size_t)n_per_row;
         dst.clear();
         if (!src || src_sz == 0) { return; }
-
         if (src_sz == want) {
             dst.resize(want);
             std::memcpy(dst.data(), src, want * sizeof(float));
@@ -1160,7 +1153,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         std::sort(compatible_candidates.begin(), compatible_candidates.end());
         compatible_candidates.erase(std::unique(compatible_candidates.begin(), compatible_candidates.end()), compatible_candidates.end());
 
-        // Compute adaptive bias_lambda for this tensor
+        // Adjusts the trade-off between systematic bias (introduced by blockwise scaling) and MSE.
+        // Larger values favours quantisation types that produce smaller bias even if the MSE is slightly bigger
         float bias_lambda = 0.0f;
         {
             const float * values = values_sample.empty() ? nullptr : values_sample.data();