Minor factoring for efficiency and correctness
This commit is contained in:
parent
556f6b04fe
commit
eab8708244
|
|
@ -596,7 +596,7 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float *
|
|||
return new_size;
|
||||
}
|
||||
|
||||
// Returns per-tensor type overrides to meet target BPW at lowest error
|
||||
// Returns tensor type overrides to meet a global bpw target
|
||||
static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||
llama_model_loader & ml,
|
||||
std::vector<no_init<uint8_t>> & buffer,
|
||||
|
|
@ -650,6 +650,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
};
|
||||
|
||||
constexpr double epsilon = 1e-12;
|
||||
constexpr double infinity = std::numeric_limits<double>::infinity();
|
||||
|
||||
auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
|
||||
const int64_t n_per_row = t->ne[0];
|
||||
|
|
@ -680,7 +681,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
|
||||
auto name_tn = LLM_TN(model.arch);
|
||||
auto can_quantize = [&](const ggml_tensor * t) -> bool {
|
||||
// This list should be kept in sync with llama_tensor_quantize_impl()
|
||||
// This list should be kept in sync with llama_tensor_quantize_impl() to avoid drift
|
||||
const std::string name = ggml_get_name(t);
|
||||
bool q = name.rfind("weight") == name.size() - 6;
|
||||
q &= ggml_n_dims(t) >= 2;
|
||||
|
|
@ -730,9 +731,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
|
||||
|
||||
const size_t sample_element_count = f32_sample.size();
|
||||
const size_t sample_row_count = sample_element_count / (size_t)n_per_row;
|
||||
const size_t sample_row_count = n_per_row > 0 ? sample_element_count / (size_t)n_per_row : 0;
|
||||
if (sample_row_count == 0) { return 0.0; }
|
||||
|
||||
size_t expected_rows = 0;
|
||||
for (int64_t s = 0; s < ne2; ++s) {
|
||||
expected_rows += (size_t)sample_rows_per_slice[s];
|
||||
}
|
||||
if (expected_rows != sample_row_count) { return infinity; }
|
||||
|
||||
const size_t row_sz = ggml_row_size(quant_type, n_per_row);
|
||||
const size_t buffer_sz = row_sz * sample_row_count;
|
||||
|
||||
|
|
@ -750,15 +757,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const float * activations = activations_sample + s * n_per_row;
|
||||
double denom = 0.0;
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double w = values ? std::max(0.0f, values[j]) : 1.0;
|
||||
const double a = activations[j];
|
||||
const double w = values ? values[j] : 1.0;
|
||||
denom += w * a * a;
|
||||
}
|
||||
bias_denominator_per_slice[s] = denom;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute per-row squared norms with weighting (if values are provided)
|
||||
// Per-row squared norms with weighting
|
||||
std::vector<double> row_sq_norm(sample_row_count, 0.0);
|
||||
{
|
||||
size_t offset = 0;
|
||||
|
|
@ -768,15 +775,14 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
if (rs == 0) { continue; }
|
||||
|
||||
const float * values = has_values ? values_sample + s * n_per_row : nullptr;
|
||||
|
||||
for (int64_t r = 0; r < rs; ++r, ++row_idx) {
|
||||
const float * x = f32_sample.data() + offset;
|
||||
double rsn = 0.0;
|
||||
if (values) {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double v = values[j];
|
||||
const double w = std::max(0.0f, values[j]);
|
||||
const double xx = x[j];
|
||||
rsn += v * xx * xx;
|
||||
rsn += w * xx * xx;
|
||||
}
|
||||
} else {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
|
|
@ -790,7 +796,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
}
|
||||
}
|
||||
|
||||
// Quantize sampled rows slice-by-slice into quantized_buffer
|
||||
// Quantize sampled rows per slice -> quantized_buffer
|
||||
{
|
||||
size_t q_offset = 0;
|
||||
size_t f_offset = 0;
|
||||
|
|
@ -800,35 +806,32 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
|
||||
const float * value = has_values ? values_sample + slice * n_per_row : nullptr;
|
||||
(void)ggml_quantize_chunk(quant_type, f32_sample.data() + f_offset, quantized_buffer.data() + q_offset, 0, rs, n_per_row, value);
|
||||
|
||||
q_offset += row_sz * (size_t)rs;
|
||||
f_offset += (size_t)rs * (size_t)n_per_row;
|
||||
}
|
||||
}
|
||||
|
||||
// Dequantize into dequantized_buffer
|
||||
// quantized_buffer -> dequantized_buffer
|
||||
{
|
||||
const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
|
||||
auto row_to_float = [&](size_t r) {
|
||||
uint8_t * src = quantized_buffer.data() + r * row_sz;
|
||||
float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
|
||||
if (quant_type == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row);
|
||||
} else if (quant_type == GGML_TYPE_BF16) {
|
||||
ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row);
|
||||
} else {
|
||||
if (!traits || !traits->to_float) {
|
||||
LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(quant_type));
|
||||
return false;
|
||||
|
||||
const bool is_fp16 = quant_type == GGML_TYPE_F16;
|
||||
const bool is_bf16 = quant_type == GGML_TYPE_BF16;
|
||||
if (!is_fp16 && !is_bf16 && traits && traits->to_float) {
|
||||
traits->to_float(quantized_buffer.data(), dequantized_buffer.data(), (int)(sample_row_count * (size_t)n_per_row));
|
||||
} else {
|
||||
for (size_t r = 0; r < sample_row_count; ++r) {
|
||||
uint8_t * src = quantized_buffer.data() + r * row_sz;
|
||||
float * dst = dequantized_buffer.data() + r * (size_t) n_per_row;
|
||||
if (is_fp16) {
|
||||
ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, (int)n_per_row);
|
||||
} else if (is_bf16) {
|
||||
ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, (int)n_per_row);
|
||||
} else {
|
||||
if (!traits || !traits->to_float) { return infinity; }
|
||||
traits->to_float(src, dst, (int)n_per_row);
|
||||
}
|
||||
traits->to_float(src, dst, (int)n_per_row);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
for (size_t r = 0; r < sample_row_count; ++r) {
|
||||
if (!row_to_float(r)) { return 1e35; }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -836,6 +839,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
size_t offset = 0;
|
||||
size_t row_idx = 0;
|
||||
double total_err = 0.0;
|
||||
|
||||
for (int64_t slice = 0; slice < ne2; ++slice) {
|
||||
const int64_t rs = sample_rows_per_slice[slice];
|
||||
if (rs == 0) { continue; }
|
||||
|
|
@ -843,9 +847,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const float * values = has_values ? values_sample + slice * n_per_row : nullptr;
|
||||
const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr;
|
||||
const double bias_denom = has_activations ? bias_denominator_per_slice[slice] : 0.0;
|
||||
|
||||
double slice_err = 0.0;
|
||||
|
||||
for (int64_t r = 0; r < rs; ++r, ++row_idx) {
|
||||
const float * x = f32_sample.data() + offset;
|
||||
const float * y = dequantized_buffer.data() + offset;
|
||||
|
|
@ -853,17 +855,17 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
double bias_num = 0.0;
|
||||
if (values && activations) {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double v = values[j];
|
||||
const double w = std::max(0.0f, values[j]);
|
||||
const double e = y[j] - x[j];
|
||||
const double a = activations[j];
|
||||
weighted_mse += v * e * e;
|
||||
bias_num += v * e * a;
|
||||
weighted_mse += w * e * e;
|
||||
bias_num += w * e * a;
|
||||
}
|
||||
} else if (values) {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double v = values[j];
|
||||
const double w = std::max(0.0f, values[j]);
|
||||
const double e = y[j] - x[j];
|
||||
weighted_mse += v * e * e;
|
||||
weighted_mse += w * e * e;
|
||||
}
|
||||
} else if (activations) {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
|
|
@ -881,26 +883,28 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
|
||||
double err_num = weighted_mse;
|
||||
if (activations && bias_lambda != 0.0f) {
|
||||
const double proj = bias_num * bias_num / (bias_denom + epsilon);
|
||||
err_num += (double)bias_lambda * proj;
|
||||
if (bias_denom > 0.0) {
|
||||
const double proj = bias_num * bias_num / (bias_denom + epsilon);
|
||||
err_num += bias_lambda * proj;
|
||||
}
|
||||
}
|
||||
|
||||
const double err_den = row_sq_norm[row_idx] + epsilon;
|
||||
slice_err += err_num / err_den;
|
||||
const double denom = row_sq_norm[row_idx] + epsilon;
|
||||
slice_err += err_num / denom;
|
||||
offset += (size_t)n_per_row;
|
||||
}
|
||||
|
||||
const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
|
||||
total_err += slice_err * scale_rows;
|
||||
if (!std::isfinite(total_err)) { return infinity; }
|
||||
}
|
||||
|
||||
return std::isfinite(total_err) ? total_err : 1e35;
|
||||
return std::isfinite(total_err) ? total_err : infinity;
|
||||
};
|
||||
|
||||
// Scaling factor to increase lambda when activations are concentrated
|
||||
auto directional_scale = [&](const float * values, const float * activations, int64_t n_per_row) {
|
||||
if (!activations) { return 1.0f; }
|
||||
// Compute dominance = ||sqrt(v).*a||_2 / (RMS(a)*sqrt(sum(v)))
|
||||
// If no values, use v=1
|
||||
double sum_v = 0.0;
|
||||
double sum_aw2 = 0.0;
|
||||
double sum_a2 = 0.0;
|
||||
|
|
@ -915,13 +919,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const double denom = std::sqrt(std::max(epsilon, sum_v)) * std::max(epsilon, rms_a);
|
||||
const double scale = denom > 0.0 ? std::sqrt(sum_aw2) / denom : 1.0;
|
||||
|
||||
// Clamp to a reasonable range
|
||||
return (float)std::clamp(scale, 0.5, 2.0);
|
||||
};
|
||||
|
||||
// Returns an adaptive lambda for this tensor using a small probe set
|
||||
// bias_lambda adjusts the trade-off between systematic bias (introduced by block‑wise scaling) and MSE
|
||||
// larger value favours quantisation types that produce smaller bias even if the MSE is slightly larger
|
||||
// Higher precision but much longer to compute
|
||||
auto precise_lambda = [&](const ggml_tensor * t,
|
||||
const std::vector<float> & f32_sample,
|
||||
const std::vector<int64_t> & sample_rows_per_slice,
|
||||
|
|
@ -929,10 +930,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const float * activations,
|
||||
const std::vector<ggml_type> & compatible_candidates) -> float
|
||||
{
|
||||
// No activations => no projection term
|
||||
if (!activations) { return 0.0f; }
|
||||
|
||||
// pick a tiny probe set: try to spread around mid-range types
|
||||
std::vector<ggml_type> probes;
|
||||
probes.reserve(3);
|
||||
auto push_if = [&](const ggml_type tiny) {
|
||||
|
|
@ -941,7 +940,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
}
|
||||
};
|
||||
|
||||
// Prefer family-consistent probes; fall back to whatever exists
|
||||
push_if(GGML_TYPE_Q4_K);
|
||||
push_if(GGML_TYPE_Q3_K);
|
||||
push_if(GGML_TYPE_Q5_K);
|
||||
|
|
@ -953,19 +951,18 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
}
|
||||
if (probes.empty()) { return 0.0f; }
|
||||
|
||||
// Scratch buffers (reused)
|
||||
// Scratch buffers
|
||||
const int64_t n_per_row = t->ne[0];
|
||||
const size_t total_sampled_rows = f32_sample.size() / n_per_row;
|
||||
size_t max_row_sz = 0;
|
||||
for (auto pt : probes) {
|
||||
max_row_sz = std::max(max_row_sz, ggml_row_size(pt, n_per_row));
|
||||
}
|
||||
|
||||
std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
|
||||
std::vector<float> dequantized_buffer(f32_sample.size());
|
||||
|
||||
std::vector<double> ratios;
|
||||
ratios.reserve(probes.size());
|
||||
|
||||
for (const auto pt : probes) {
|
||||
// err at lambda=0 => pure weighted MSE part
|
||||
double err0 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 0.0f);
|
||||
|
|
@ -984,17 +981,17 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
std::nth_element(ratios.begin(), ratios.begin() + ratios.size() / 2, ratios.end());
|
||||
double lambda = ratios[ratios.size() / 2];
|
||||
|
||||
// activations directional scale
|
||||
const float scale = directional_scale(values, activations, n_per_row);
|
||||
lambda *= scale;
|
||||
|
||||
// clamp to safe range
|
||||
lambda = std::clamp(lambda, 0.0, 8.0);
|
||||
|
||||
return (float)lambda;
|
||||
};
|
||||
|
||||
// Faster to compute but lower precision. Best option for the vast majority of models
|
||||
auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row) {
|
||||
if (!activations) { return 0.0f; }
|
||||
|
||||
double s = 0.0;
|
||||
double s2 = 0.0;
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
|
|
@ -1004,17 +1001,13 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
s += aw2;
|
||||
s2 += aw2 * aw2;
|
||||
}
|
||||
|
||||
if (s2 <= 0.0) { return 0.0f; }
|
||||
const auto d = (double)n_per_row;
|
||||
//const double p = s * s / (d * s2 + epsilon);
|
||||
//const double lambda = 8.0 * std::clamp(1.0 - p, 0.0, 1.0);
|
||||
// Map p in (0,1] to lambda in [0,8] decreasing
|
||||
double base = 1.0 - s * s / (d * s2 + epsilon);
|
||||
base = std::clamp(base, 0.0, 1.0);
|
||||
|
||||
// activations directional scale
|
||||
const double scale = directional_scale(values, activations, n_per_row);
|
||||
// clamp to safe range
|
||||
const double lambda = std::clamp(base * scale, 0.0, 1.0) * 8.0;
|
||||
|
||||
return (float)lambda;
|
||||
|
|
@ -1036,13 +1029,13 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
}
|
||||
ml.load_data_for(t);
|
||||
|
||||
// Dequantize only sampled rows into f32_sample
|
||||
// Dequantize sampled rows into f32_sample
|
||||
const int64_t n_per_row = t->ne[0];
|
||||
const int64_t nrows_total = t->ne[1];
|
||||
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
|
||||
|
||||
// Larger sample_rows_per_expert values may result in more accurate error estimates, but will take longer to compute
|
||||
constexpr int sample_rows_per_expert = 384;
|
||||
// Larger sample_rows_per_expert values may result in more accurate error estimates, but it will take much longer to compute
|
||||
constexpr int sample_rows_per_expert = 256;
|
||||
std::vector<float> f32_sample;
|
||||
f32_sample.reserve((size_t)ne2 * (size_t)std::min<int64_t>(nrows_total, sample_rows_per_expert) * (size_t)n_per_row);
|
||||
|
||||
|
|
@ -1096,6 +1089,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const std::string key = remap_imatrix(tensor_name, mapped);
|
||||
const auto it = m->find(key);
|
||||
if (it == m->end()) { return {nullptr, 0}; }
|
||||
|
||||
return { it->second.data(), it->second.size() };
|
||||
};
|
||||
|
||||
|
|
@ -1104,7 +1098,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const size_t want = (size_t)ne2 * (size_t)n_per_row;
|
||||
dst.clear();
|
||||
if (!src || src_sz == 0) { return; }
|
||||
|
||||
if (src_sz == want) {
|
||||
dst.resize(want);
|
||||
std::memcpy(dst.data(), src, want * sizeof(float));
|
||||
|
|
@ -1160,7 +1153,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
std::sort(compatible_candidates.begin(), compatible_candidates.end());
|
||||
compatible_candidates.erase(std::unique(compatible_candidates.begin(), compatible_candidates.end()), compatible_candidates.end());
|
||||
|
||||
// Compute adaptive bias_lambda for this tensor
|
||||
// Adjusts the trade-off between systematic bias (introduced by block‑wise scaling) and MSE.
|
||||
// Larger values favours quantisation types that produce smaller bias even if the MSE is slightly bigger
|
||||
float bias_lambda = 0.0f;
|
||||
{
|
||||
const float * values = values_sample.empty() ? nullptr : values_sample.data();
|
||||
|
|
|
|||
Loading…
Reference in New Issue