General code refactor

This commit is contained in:
Ed Addario 2025-08-21 19:18:54 +01:00
parent 9e11f82e8f
commit 5b6f1e9fde
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 196 additions and 219 deletions

View File

@ -596,10 +596,7 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float *
return new_size; return new_size;
} }
// Returns per-tensor overrides of quantization types to meet target BPW with the lowest ppl // Returns per-tensor type overrides to meet target BPW at lowest ppl
// sample_rows_per_expert: Larger values will result in more accurate error estimates, but will take longer to compute
// bias_lambda: Affects the weight of the bias term in the weigthed MSE error function. 0.0 means no bias (standard MSE),
// 1.0 means equal weight for bias and error, 2.0 means twice as much weight for bias
static std::unordered_map<std::string, ggml_type> target_bpw_type( static std::unordered_map<std::string, ggml_type> target_bpw_type(
llama_model_loader & ml, llama_model_loader & ml,
std::vector<no_init<uint8_t>> & buffer, std::vector<no_init<uint8_t>> & buffer,
@ -609,9 +606,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const std::unordered_map<std::string, std::vector<float>> * values_data, const std::unordered_map<std::string, std::vector<float>> * values_data,
const std::unordered_map<std::string, std::vector<float>> * activations_data, const std::unordered_map<std::string, std::vector<float>> * activations_data,
const llama_model_quantize_params * params, const llama_model_quantize_params * params,
int nthread, int nthread
int sample_rows_per_expert = 512,
float bias_lambda = 1.0
) { ) {
struct candidate_types { struct candidate_types {
ggml_type type; ggml_type type;
@ -621,15 +616,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
}; };
struct tensor_info { struct tensor_info {
const llama_model_loader::llama_tensor_weight * w; const llama_model_loader::llama_tensor_weight * w = nullptr;
std::vector<candidate_types> candidate; std::vector<candidate_types> candidate = {};
int choice = -1; int choice = -1;
float min_bpw = 0.0; float min_bpw = 0.0;
float max_bpw = 0.0; float max_bpw = 0.0;
size_t n_elements = 0; size_t n_elements = 0;
}; };
const ggml_type k_candidates[] = { constexpr ggml_type k_quants[] = {
GGML_TYPE_Q2_K, GGML_TYPE_Q2_K,
GGML_TYPE_Q3_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_0, GGML_TYPE_Q4_0,
@ -648,7 +643,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
#endif #endif
}; };
const ggml_type iq_candidates[] = { constexpr ggml_type iq_quants[] = {
GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_S,
GGML_TYPE_IQ1_M, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XXS,
@ -665,9 +660,49 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
GGML_TYPE_Q6_K GGML_TYPE_Q6_K
}; };
auto name_tn = LLM_TN(model.arch); auto get_values = [&](const std::string & tensor_name) -> const float * {
float target_bpw = params->target_bpw; if (!values_data) { return nullptr; }
const auto it = values_data->find(remap_imatrix(tensor_name, mapped));
if (it == values_data->end()) { return nullptr; }
return it->second.data();
};
auto get_activations = [&](const std::string & tensor_name) -> const float * {
if (!activations_data) { return nullptr; }
const auto it = activations_data->find(remap_imatrix(tensor_name, mapped));
if (it == activations_data->end()) { return nullptr; }
return it->second.data();
};
auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
const int64_t n_per_row = t->ne[0];
const int64_t nrows = t->ne[1];
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
const size_t row_sz = ggml_row_size(typ, n_per_row);
return (size_t)ne2 * (size_t)nrows * row_sz;
};
auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double {
const int64_t nelem = ggml_nelements(t);
const size_t bytes = tensor_bytes(t, typ);
return (double)bytes * 8.0 / (double)nelem;
};
auto is_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> bool {
const int64_t n_per_row = t->ne[0];
const int64_t blck = ggml_blck_size(typ);
if (blck <= 1) { return true; }
return n_per_row % blck == 0;
};
auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type {
if (is_compatible(t, typ)) { return typ; }
ggml_type fb = fallback_type(typ);
if (is_compatible(t, fb)) { return fb; }
return GGML_TYPE_F16;
};
auto name_tn = LLM_TN(model.arch);
auto can_quantize = [&](const ggml_tensor * t) -> bool { auto can_quantize = [&](const ggml_tensor * t) -> bool {
const std::string name = ggml_get_name(t); const std::string name = ggml_get_name(t);
bool q = name.rfind("weight") == name.size() - 6; bool q = name.rfind("weight") == name.size() - 6;
@ -705,231 +740,182 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
return q; return q;
}; };
auto get_values = [&](const std::string & tensor_name) -> const float * {
if (!values_data) { return nullptr; }
const auto it = values_data->find(remap_imatrix(tensor_name, mapped));
if (it == values_data->end()) { return nullptr; }
return it->second.data();
};
auto get_activations = [&](const std::string & tensor_name) -> const float * {
if (!activations_data) { return nullptr; }
const auto it = activations_data->find(remap_imatrix(tensor_name, mapped));
if (it == activations_data->end()) { return nullptr; }
return it->second.data();
};
auto total_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
const int64_t n_per_row = t->ne[0];
const int64_t nrows = t->ne[1];
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
const size_t row_sz = ggml_row_size(typ, n_per_row);
return (size_t)ne2 * (size_t)nrows * row_sz;
};
auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double {
const int64_t nelem = ggml_nelements(t);
const size_t bytes = total_bytes(t, typ);
return bytes * 8.0 / nelem;
};
auto is_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> bool {
const int64_t n_per_row = t->ne[0];
const int64_t blck = ggml_blck_size(typ);
if (blck <= 1) { return true; }
return n_per_row % blck == 0;
};
auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type {
if (is_compatible(t, typ)) { return typ; }
ggml_type fb = fallback_type(typ);
if (is_compatible(t, fb)) { return fb; }
return GGML_TYPE_F16;
};
// Estimate error for a given type using a sampled subset of rows // Estimate error for a given type using a sampled subset of rows
auto estimate_error = [&](const ggml_tensor * t, auto estimate_error = [&](const ggml_tensor * t,
const ggml_type typ, const ggml_type quant_type,
const std::vector<float> & f32_sample, const std::vector<float> & f32_sample,
const std::vector<int64_t> & sample_rows_per_slice, const std::vector<int64_t> & sample_rows_per_slice,
const float * values_sample, const float * values_sample,
const float * activations_sample, const float * activations_sample,
std::vector<uint8_t> & qbuf, std::vector<uint8_t> & quantized_buffer,
std::vector<float> & deq) -> double std::vector<float> & dequantized_buffer) -> double
{ {
const int64_t n_per_row = t->ne[0]; const int64_t n_per_row = t->ne[0];
const int64_t nrows = t->ne[1]; const int64_t nrows = t->ne[1];
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1; const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
const size_t nels = f32_sample.size(); const size_t sample_element_count = f32_sample.size();
const size_t total_sampled_rows = nels / (size_t)n_per_row; const size_t sample_row_count = sample_element_count / (size_t)n_per_row;
if (total_sampled_rows == 0) { return 0.0; } if (sample_row_count == 0) { return 0.0; }
const size_t row_sz = ggml_row_size(typ, n_per_row); const size_t row_size = ggml_row_size(quant_type, n_per_row);
const size_t need_q = row_sz * total_sampled_rows; const size_t buffer_size = row_size * sample_row_count;
if (qbuf.size() < need_q) { qbuf.resize(need_q); } if (quantized_buffer.size() < buffer_size) { quantized_buffer.resize(buffer_size); }
if (deq.size() < nels) { deq.resize(nels); } if (dequantized_buffer.size() < sample_element_count) { dequantized_buffer.resize(sample_element_count); }
// Precompute denominators: std::vector row_sq_norm(sample_row_count, 0.0);
// - x2_per_row: sum_j w[j]*x[j]^2 if w present else sum_j x[j]^2 std::vector bias_denominator_per_slice(ne2, 0.0);
// - bden_per_slice: sum_j w[j]*a[j]^2 if w & a present; sum_j a[j]^2 if only a present; 0 otherwise
std::vector x2_per_row(total_sampled_rows, 0.0);
std::vector bden_per_slice(ne2, 0.0);
const bool has_w = (values_sample != nullptr); // Precompute bias denominator per slice
const bool has_a = (activations_sample != nullptr); const bool has_values = (values_sample != nullptr);
const bool has_activations = (activations_sample != nullptr);
// Precompute bden per slice (depends only on w,a) if (has_activations) {
if (has_a) {
for (int64_t s = 0; s < ne2; ++s) { for (int64_t s = 0; s < ne2; ++s) {
const float * wv = has_w ? values_sample + s * n_per_row : nullptr; const float * values = has_values ? values_sample + s * n_per_row : nullptr;
const float * act = activations_sample + s * n_per_row; const float * activations = activations_sample + s * n_per_row;
double bden = 0.0; double bias_denominator = 0.0;
if (has_w) { if (has_values) {
for (int64_t j = 0; j < n_per_row; ++j) { for (int64_t j = 0; j < n_per_row; ++j) {
const double a = act[j]; const double a = activations[j];
bden += (double) wv[j] * a * a; bias_denominator += values[j] * a * a;
} }
} else { } else {
for (int64_t j = 0; j < n_per_row; ++j) { for (int64_t j = 0; j < n_per_row; ++j) {
const double a = act[j]; const double a = activations[j];
bden += a * a; bias_denominator += a * a;
} }
} }
bden_per_slice[s] = bden; bias_denominator_per_slice[s] = bias_denominator;
} }
} }
// Precompute x2 per sampled row // Compute squared norms of sampled rows
{ {
size_t off = 0; size_t offset = 0;
size_t row_idx = 0; size_t row_idx = 0;
for (int64_t s = 0; s < ne2; ++s) { for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = sample_rows_per_slice[s]; const int64_t rs = sample_rows_per_slice[s];
if (rs == 0) { continue; } if (rs == 0) { continue; }
const float * wv = has_w ? values_sample + s * n_per_row : nullptr; const float * values = has_values ? values_sample + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r, ++row_idx) { for (int64_t r = 0; r < rs; ++r, ++row_idx) {
const float * x = f32_sample.data() + off; const float * row = f32_sample.data() + offset;
double x2 = 0.0; double rsn = 0.0;
if (has_w) { if (has_values) {
for (int64_t j = 0; j < n_per_row; ++j) { for (int64_t j = 0; j < n_per_row; ++j) {
const double w = wv[j]; const double v = values[j];
const double xx = x[j]; const double x = row[j];
x2 += w * xx * xx; rsn += v * x * x;
} }
} else { } else {
for (int64_t j = 0; j < n_per_row; ++j) { for (int64_t j = 0; j < n_per_row; ++j) {
const double xx = x[j]; const double x = row[j];
x2 += xx * xx; rsn += x * x;
} }
} }
x2_per_row[row_idx] = x2; row_sq_norm[row_idx] = rsn;
off += (size_t)n_per_row; offset += (size_t)n_per_row;
} }
} }
} }
// Quantize sampled rows slice-by-slice into qbuf // Quantize sampled rows slice-by-slice into quantized_buffer
size_t qoff = 0; size_t quantised_offset = 0;
size_t foff = 0; size_t floats_offset = 0;
for (int64_t slice = 0; slice < ne2; ++slice) { for (int64_t slice = 0; slice < ne2; ++slice) {
const int64_t rs = sample_rows_per_slice[slice]; const int64_t rs = sample_rows_per_slice[slice];
if (rs == 0) { continue; } if (rs == 0) { continue; }
const float * value = values_sample ? values_sample + slice * n_per_row : nullptr; const float * value = values_sample ? values_sample + slice * n_per_row : nullptr;
(void)ggml_quantize_chunk(typ, f32_sample.data() + foff, qbuf.data() + qoff, 0, rs, n_per_row, value); (void)ggml_quantize_chunk(quant_type, f32_sample.data() + floats_offset, quantized_buffer.data() + quantised_offset, 0, rs, n_per_row, value);
qoff += row_sz * (size_t)rs; quantised_offset += row_size * (size_t)rs;
foff += (size_t)rs * (size_t)n_per_row; floats_offset += (size_t)rs * (size_t)n_per_row;
} }
// Dequantize into deq (row-wise if needed to avoid int overflow) // Dequantize into dequantized_buffer
{ {
const ggml_type_traits * traits = ggml_get_type_traits(typ); const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
if (typ == GGML_TYPE_F16) { if (quant_type == GGML_TYPE_F16) {
ggml_fp16_to_fp32_row((const ggml_fp16_t *)qbuf.data(), deq.data(), (int)nels); ggml_fp16_to_fp32_row((const ggml_fp16_t *)quantized_buffer.data(), dequantized_buffer.data(), (int)sample_element_count);
} else if (typ == GGML_TYPE_BF16) { } else if (quant_type == GGML_TYPE_BF16) {
ggml_bf16_to_fp32_row((const ggml_bf16_t *)qbuf.data(), deq.data(), (int)nels); ggml_bf16_to_fp32_row((const ggml_bf16_t *)quantized_buffer.data(), dequantized_buffer.data(), (int)sample_element_count);
} else { } else {
if (!traits || !traits->to_float) { if (!traits || !traits->to_float) {
LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(typ)); LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(quant_type));
return 1e35; return 1e35;
} }
size_t done = 0; size_t done = 0;
while (done < nels) { while (done < sample_element_count) {
const size_t chunk = std::min((size_t)n_per_row, nels - done); const size_t chunk = std::min((size_t)n_per_row, sample_element_count - done);
traits->to_float(qbuf.data() + done / n_per_row * row_sz, deq.data() + done, (int)chunk); traits->to_float(quantized_buffer.data() + done / n_per_row * row_size, dequantized_buffer.data() + done, (int)chunk);
done += chunk; done += chunk;
} }
} }
} }
// Compute error // Compute error
const double eps = 1e-12; size_t offset = 0;
size_t off = 0;
size_t row_idx = 0; size_t row_idx = 0;
double total_err = 0.0; double total_err = 0.0;
for (int64_t slice = 0; slice < ne2; ++slice) { for (int64_t slice = 0; slice < ne2; ++slice) {
const int64_t rs = sample_rows_per_slice[slice]; const int64_t rs = sample_rows_per_slice[slice];
if (rs == 0) { continue; } if (rs == 0) { continue; }
const float * wv = has_w ? values_sample + slice * n_per_row : nullptr; const float * values = has_values ? values_sample + slice * n_per_row : nullptr;
const float * act = has_a ? activations_sample + slice * n_per_row : nullptr; const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr;
const double bden = has_a ? bden_per_slice[slice] : 0.0; const double bias_denominator = has_activations ? bias_denominator_per_slice[slice] : 0.0;
double slice_err = 0.0; double slice_err = 0.0;
for (int64_t r = 0; r < rs; ++r, ++row_idx) { for (int64_t r = 0; r < rs; ++r, ++row_idx) {
const float * x = f32_sample.data() + off; const float * x = f32_sample.data() + offset;
const float * y = deq.data() + off; const float * y = dequantized_buffer.data() + offset;
double weighted_mse = 0.0;
double mse_w = 0.0; double bias_numerator = 0.0;
double bnum = 0.0; if (values && activations) {
if (wv && act) {
for (int64_t j = 0; j < n_per_row; ++j) { for (int64_t j = 0; j < n_per_row; ++j) {
const double w = wv[j]; const double v = values[j];
const double e = y[j] - x[j]; const double e = y[j] - x[j];
const double a = act[j]; const double a = activations[j];
mse_w += w * e * e; weighted_mse += v * e * e;
bnum += w * e * a; bias_numerator += v * e * a;
} }
} else if (wv) { } else if (values) {
for (int64_t j = 0; j < n_per_row; ++j) { for (int64_t j = 0; j < n_per_row; ++j) {
const double w = wv[j]; const double v = values[j];
const double e = y[j] - x[j]; const double e = y[j] - x[j];
mse_w += w * e * e; weighted_mse += v * e * e;
} }
} else if (act) { } else if (activations) {
for (int64_t j = 0; j < n_per_row; ++j) { for (int64_t j = 0; j < n_per_row; ++j) {
const double e = y[j] - x[j]; const double e = y[j] - x[j];
const double a = act[j]; const double a = activations[j];
mse_w += e * e; weighted_mse += e * e;
bnum += e * a; bias_numerator += e * a;
} }
} else { } else {
for (int64_t j = 0; j < n_per_row; ++j) { for (int64_t j = 0; j < n_per_row; ++j) {
const double e = y[j] - x[j]; const double e = y[j] - x[j];
mse_w += e * e; weighted_mse += e * e;
} }
} }
// corrected normalization: divide the full numerator by x2 double err_numerator = weighted_mse;
double numer = mse_w; constexpr double epsilon = 1e-12;
if (act && bias_lambda != 0.0) { constexpr float bias_lambda = 1.0;
const double proj = bnum * bnum / (bden + eps); //bias_lambda defines the weight of the bias term in the weigthed MSE error function
numer += bias_lambda * proj; // 0.0 means no bias (standard MSE) 1.0 means equal weight for bias and error,
// 2.0 means twice as much weight for bias, etc
if (activations && bias_lambda != 0.0) {
const double proj = bias_numerator * bias_numerator / (bias_denominator + epsilon);
err_numerator += bias_lambda * proj;
} }
const double denom = x2_per_row[row_idx] + eps; const double err_denominator = row_sq_norm[row_idx] + epsilon;
const double row_err = numer / denom; const double row_err = err_numerator / err_denominator;
slice_err += row_err; slice_err += row_err;
off += (size_t)n_per_row; offset += (size_t)n_per_row;
} }
// scale to full rows (nrows) // scale to full rows (nrows)
@ -942,14 +928,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
std::vector<tensor_info> all; std::vector<tensor_info> all;
all.reserve(tensors.size()); all.reserve(tensors.size());
for (const auto * tw : tensors) { for (const auto * tw : tensors) {
std::vector<std::thread> workers; std::vector<std::thread> workers;
workers.reserve(std::max(1, nthread)); workers.reserve(std::max(1, nthread));
ggml_tensor * t = tw->tensor; ggml_tensor * t = tw->tensor;
const std::string name = ggml_get_name(t); const std::string name = ggml_get_name(t);
if (!can_quantize(t)) { continue; } if (!can_quantize(t)) { continue; }
LLAMA_LOG_INFO("\t%s: - processing tensor %45s \t(%12d elements)\n", __func__, name.c_str(), (int)ggml_nelements(t)); LLAMA_LOG_INFO("\t%s: - processing tensor %45s \t(%12d elements)\n", __func__, name.c_str(), (int)ggml_nelements(t));
@ -959,37 +942,26 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
} }
ml.load_data_for(t); ml.load_data_for(t);
const int64_t nelem = ggml_nelements(t);
std::vector<no_init<float>> f32_conv_buf;
const float * values_all = get_values(name);
const float * activations_all = get_activations(name);
// Dequantize only sampled rows into f32_sample // Dequantize only sampled rows into f32_sample
const int64_t n_per_row = t->ne[0]; const int64_t n_per_row = t->ne[0];
const int64_t nrows_total = t->ne[1]; const int64_t nrows_total = t->ne[1];
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1; const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
// Larger sample_rows_per_expert values may result in more accurate error estimates, but will take longer to compute
int sample_rows_per_expert = 512;
std::vector<float> f32_sample;
f32_sample.reserve((size_t)ne2 * (size_t)std::min<int64_t>(nrows_total, sample_rows_per_expert) * (size_t)n_per_row);
// deterministic sampling seed based on tensor name + fixed constant
std::mt19937 rng(std::hash<std::string>{}(name) ^0xeabada55cafed00d);
std::vector<int64_t> sample_rows_per_slice(ne2, 0);
const int64_t sample_rows_max = std::max<int64_t>(1, std::min<int64_t>(nrows_total, sample_rows_per_expert));
const int64_t stride = std::max<int64_t>(1, nrows_total / sample_rows_max);
std::vector<float> row_buffer(n_per_row);
const ggml_type src_type = t->type; const ggml_type src_type = t->type;
const ggml_type_traits *src_traits = ggml_get_type_traits(src_type); const ggml_type_traits *src_traits = ggml_get_type_traits(src_type);
const bool src_is_quant = ggml_is_quantized(src_type); const bool src_is_quant = ggml_is_quantized(src_type);
const size_t src_row_sz = ggml_row_size(src_type, n_per_row); const size_t src_row_sz = ggml_row_size(src_type, n_per_row);
std::vector<float> f32_sample;
f32_sample.reserve((size_t)ne2 * (size_t)std::min<int64_t>(nrows_total, sample_rows_per_expert) * (size_t)n_per_row);
std::vector<float> values_sample;
std::vector<float> activations_sample;
std::vector<int64_t> sample_rows_per_slice(ne2, 0);
// deterministic sampling seed based on tensor name + fixed constant
std::mt19937 rng(std::hash<std::string>{}(name) ^0xeabada55cafed00d);
const int64_t sample_rows_max = std::max<int64_t>(1, std::min<int64_t>(nrows_total, sample_rows_per_expert));
const int64_t stride = std::max<int64_t>(1, nrows_total / sample_rows_max);
// Temporary buffer for one dequantized row
std::vector<float> rowbuf((size_t)n_per_row);
for (int64_t slice = 0; slice < ne2; ++slice) { for (int64_t slice = 0; slice < ne2; ++slice) {
int64_t current_sampled_rows = 0; int64_t current_sampled_rows = 0;
int64_t offset = 0; int64_t offset = 0;
@ -1004,19 +976,19 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
f32_sample.insert(f32_sample.end(), src_row, src_row + n_per_row); f32_sample.insert(f32_sample.end(), src_row, src_row + n_per_row);
} else if (src_type == GGML_TYPE_F16) { } else if (src_type == GGML_TYPE_F16) {
const ggml_fp16_t * src_row = (const ggml_fp16_t *)((const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz); const ggml_fp16_t * src_row = (const ggml_fp16_t *)((const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz);
ggml_fp16_to_fp32_row(src_row, rowbuf.data(), (int)n_per_row); ggml_fp16_to_fp32_row(src_row, row_buffer.data(), (int)n_per_row);
f32_sample.insert(f32_sample.end(), rowbuf.begin(), rowbuf.end()); f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end());
} else if (src_type == GGML_TYPE_BF16) { } else if (src_type == GGML_TYPE_BF16) {
const ggml_bf16_t * src_row = (const ggml_bf16_t *)((const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz); const ggml_bf16_t * src_row = (const ggml_bf16_t *)((const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz);
ggml_bf16_to_fp32_row(src_row, rowbuf.data(), (int)n_per_row); ggml_bf16_to_fp32_row(src_row, row_buffer.data(), (int)n_per_row);
f32_sample.insert(f32_sample.end(), rowbuf.begin(), rowbuf.end()); f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end());
} else if (src_is_quant) { } else if (src_is_quant) {
const uint8_t * qrow = (const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz; const uint8_t * qrow = (const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz;
if (!src_traits || !src_traits->to_float) { if (!src_traits || !src_traits->to_float) {
throw std::runtime_error(format("cannot dequantize type %s for sampling", ggml_type_name(src_type))); throw std::runtime_error(format("cannot dequantize type %s for sampling", ggml_type_name(src_type)));
} }
src_traits->to_float(qrow, rowbuf.data(), (int)n_per_row); src_traits->to_float(qrow, row_buffer.data(), (int)n_per_row);
f32_sample.insert(f32_sample.end(), rowbuf.begin(), rowbuf.end()); f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end());
} else { } else {
throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(src_type))); throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(src_type)));
} }
@ -1045,6 +1017,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
} }
}; };
const float * values_all = get_values(name);
const float * activations_all = get_activations(name);
std::vector<float> values_sample;
std::vector<float> activations_sample;
if (values_all) { if (values_all) {
// get size from the map (not just the raw pointer) // get size from the map (not just the raw pointer)
auto itv = values_data->find(remap_imatrix(name, mapped)); auto itv = values_data->find(remap_imatrix(name, mapped));
@ -1057,6 +1033,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
copy_or_broadcast(activations_all, sz, activations_sample); copy_or_broadcast(activations_all, sz, activations_sample);
} }
const int64_t nelem = ggml_nelements(t);
tensor_info info; tensor_info info;
info.w = tw; info.w = tw;
info.n_elements = nelem; info.n_elements = nelem;
@ -1067,12 +1044,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
// Build list of candidate types first (compatible ones) // Build list of candidate types first (compatible ones)
std::vector<ggml_type> quant_candidates; std::vector<ggml_type> quant_candidates;
if (is_iq(params->ftype)) { if (is_iq(params->ftype)) {
quant_candidates.assign(std::begin(iq_candidates), std::end(iq_candidates)); quant_candidates.assign(std::begin(iq_quants), std::end(iq_quants));
} else { } else {
quant_candidates.assign(std::begin(k_candidates), std::end(k_candidates)); quant_candidates.assign(std::begin(k_quants), std::end(k_quants));
} }
// Compute maximum row size among compatible candidates (to size qbuf once) // Compute maximum row size among compatible candidates (to size quantized_buffer once)
size_t max_row_sz = 0; size_t max_row_sz = 0;
const bool has_valid_imatrix = !values_sample.empty() && values_sample.size() == (size_t)ne2 * (size_t)n_per_row; const bool has_valid_imatrix = !values_sample.empty() && values_sample.size() == (size_t)ne2 * (size_t)n_per_row;
std::vector<ggml_type> compatible_candidates; std::vector<ggml_type> compatible_candidates;
@ -1092,21 +1069,20 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
compatible_candidates.erase(std::unique(compatible_candidates.begin(), compatible_candidates.end()), compatible_candidates.end()); compatible_candidates.erase(std::unique(compatible_candidates.begin(), compatible_candidates.end()), compatible_candidates.end());
// Now evaluate candidates // Now evaluate candidates
std::vector<candidate_types> cand_out(compatible_candidates.size()); std::vector<candidate_types> eval_candidates(compatible_candidates.size());
const float *vals_ptr = values_sample.empty() ? nullptr : values_sample.data(); const float *values = values_sample.empty() ? nullptr : values_sample.data();
const float *acts_ptr = activations_sample.empty() ? nullptr : activations_sample.data(); const float *activations = activations_sample.empty() ? nullptr : activations_sample.data();
std::vector<uint8_t> qbuf(max_row_sz * total_sampled_rows); std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
std::vector<float> deq(f32_sample.size()); std::vector<float> dequantised_buffer(f32_sample.size());
int n_eval_threads = std::max(1, std::min<int>(nthread, (int)compatible_candidates.size())); int n_eval_threads = std::max(1, std::min<int>(nthread, (int)compatible_candidates.size()));
std::atomic<size_t> cidx{0}; std::atomic<size_t> cidx{0};
std::vector<std::thread> eval_workers; std::vector<std::thread> eval_workers;
eval_workers.reserve(n_eval_threads); eval_workers.reserve(n_eval_threads);
for (int ti = 0; ti < n_eval_threads; ++ti) { for (int ti = 0; ti < n_eval_threads; ++ti) {
eval_workers.emplace_back([&] { eval_workers.emplace_back([&] {
// thread-local scratch // thread-local scratch
std::vector<uint8_t> tl_qbuf(qbuf.size()); std::vector<uint8_t> tl_quantized_buffer(quantized_buffer.size());
std::vector<float> tl_deq(deq.size()); std::vector<float> tl_dequantised_buffer(dequantised_buffer.size());
for (;;) { for (;;) {
const size_t i = cidx.fetch_add(1, std::memory_order_relaxed); const size_t i = cidx.fetch_add(1, std::memory_order_relaxed);
@ -1114,15 +1090,16 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const ggml_type tt = compatible_candidates[i]; const ggml_type tt = compatible_candidates[i];
const auto bpw = (float)tensor_bpw(t, tt); const auto bpw = (float)tensor_bpw(t, tt);
const size_t bytes = total_bytes(t, tt); const size_t bytes = tensor_bytes(t, tt);
const auto err = (float)estimate_error(t, tt, f32_sample, sample_rows_per_slice, vals_ptr, acts_ptr, tl_qbuf, tl_deq); const auto err = (float)estimate_error(t, tt, f32_sample, sample_rows_per_slice, values, activations, tl_quantized_buffer, tl_dequantised_buffer);
cand_out[i] = candidate_types{ tt, bpw, bytes, err }; eval_candidates[i] = candidate_types{ tt, bpw, bytes, err };
} }
}); });
} }
for (auto &th : eval_workers) { th.join(); } for (auto &th : eval_workers) { th.join(); }
for (auto &c : cand_out) { for (auto &c : eval_candidates) {
if (c.bytes > 0) { info.candidate.push_back(c); } if (c.bytes > 0) { info.candidate.push_back(c); }
} }
@ -1132,7 +1109,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
info.candidate.push_back(candidate_types{ t->type, bpw, ggml_nbytes(t), 0.0 }); info.candidate.push_back(candidate_types{ t->type, bpw, ggml_nbytes(t), 0.0 });
} }
// Remove dominated candidates: if A has >= bytes and >= error than B (and > in at least one), drop A. // Keep only the Paretooptimal candidates: if A has >= bytes and >= error than B, drop A.
{ {
std::vector<candidate_types> pruned; std::vector<candidate_types> pruned;
pruned.reserve(info.candidate.size()); pruned.reserve(info.candidate.size());
@ -1155,36 +1132,37 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
info.candidate.swap(pruned); info.candidate.swap(pruned);
} }
std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types & a, const candidate_types & b) {
if (a.bpw != b.bpw) { return a.bpw < b.bpw; }
if (a.error != b.error) { return a.error < b.error; }
return a.bytes < b.bytes;
});
// Collapse candidates with identical storage size (bytes) // Collapse candidates with identical storage size (bytes)
{ {
std::vector<candidate_types> uniq; std::vector<candidate_types> unique;
uniq.reserve(info.candidate.size()); unique.reserve(info.candidate.size());
// Sort by bpw asc, error asc, bytes asc
std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types & a, const candidate_types & b) {
if (a.bpw != b.bpw) { return a.bpw < b.bpw; }
if (a.error != b.error) { return a.error < b.error; }
return a.bytes < b.bytes;
});
for (size_t i = 0; i < info.candidate.size();) { for (size_t i = 0; i < info.candidate.size();) {
size_t j = i + 1; size_t j = i + 1;
candidate_types best = info.candidate[i]; candidate_types best = info.candidate[i];
// group same-byte entries, keep the one with the lowest error // group same-byte entries, keep the one with the lowest error
while (j < info.candidate.size() && info.candidate[j].bytes == info.candidate[i].bytes) { while (j < info.candidate.size() && info.candidate[j].bytes == info.candidate[i].bytes) {
if (info.candidate[j].error < best.error) { best = info.candidate[j]; } if (info.candidate[j].error < best.error) {
best = info.candidate[j];
}
++j; ++j;
} }
uniq.push_back(best); unique.push_back(best);
i = j; i = j;
} }
info.candidate.swap(uniq); info.candidate.swap(unique);
} }
// Initialize choice at the smallest bpw candidate // Initialize choice at the smallest bpw candidate
info.choice = 0; info.choice = 0;
info.min_bpw = info.candidate.front().bpw; info.min_bpw = info.candidate.front().bpw;
info.max_bpw = info.candidate.back().bpw; info.max_bpw = info.candidate.back().bpw;
all.push_back(std::move(info)); all.push_back(std::move(info));
} }
@ -1196,6 +1174,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
for (const auto & ti : all) { for (const auto & ti : all) {
b += ti.candidate[ti.choice].bytes; b += ti.candidate[ti.choice].bytes;
} }
return b; return b;
}; };
@ -1204,6 +1183,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
for (const auto & ti : all) { for (const auto & ti : all) {
w += ti.n_elements; w += ti.n_elements;
} }
return w; return w;
}; };
@ -1215,12 +1195,14 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
// Precompute current bpw // Precompute current bpw
double bpw_now = current_bpw(); double bpw_now = current_bpw();
float target_bpw = params->target_bpw;
// If minimal bpw is already above the target, we're constrained by geometry; return closest (min bpw) // If minimal bpw is already above the target, we're constrained by geometry; return closest (min bpw)
if (bpw_now >= target_bpw) { if (bpw_now >= target_bpw) {
std::unordered_map<std::string, ggml_type> overrides; std::unordered_map<std::string, ggml_type> overrides;
for (const auto & ti : all) { for (const auto & ti : all) {
overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type; overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type;
} }
return overrides; return overrides;
} }
@ -1268,6 +1250,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
best = upgrade{ i, j, err, delta_bytes, ratio }; best = upgrade{ i, j, err, delta_bytes, ratio };
} }
} }
return best; return best;
}; };
@ -1286,16 +1269,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
} }
} }
// We might still be below target but taking any single upgrade overshoots. // We might still be below target so we try to find the best upgrade one last time
// Try to find the best upgrade that overshoots the target_bpw by the least and has the best error-to-size ratio.
{ {
double under_gap = target_bpw - bpw_now;
upgrade best_over{ -1, -1, 0.0, 0, -1.0 }; upgrade best_over{ -1, -1, 0.0, 0, -1.0 };
double best_over_gap = 1e300; double best_over_gap = 1e300;
double under_gap = target_bpw - bpw_now;
size_t now_bytes = current_total_bytes(); size_t now_bytes = current_total_bytes();
for (int i = 0; i < (int) all.size(); ++i) { for (int i = 0; i < (int) all.size(); ++i) {
const auto & ti = all[i]; const auto & ti = all[i];
if (ti.choice >= (int)ti.candidate.size() - 1) { continue; } if (ti.choice >= (int)ti.candidate.size() - 1) { continue; }
@ -1305,19 +1284,16 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const auto & cur = ti.candidate[ti.choice]; const auto & cur = ti.candidate[ti.choice];
const auto & nxt = ti.candidate[j]; const auto & nxt = ti.candidate[j];
size_t delta_bytes = nxt.bytes - cur.bytes; size_t delta_bytes = nxt.bytes - cur.bytes;
if (delta_bytes == 0) { continue; } if (delta_bytes == 0) { continue; }
size_t over_bytes = now_bytes + delta_bytes; size_t over_bytes = now_bytes + delta_bytes;
double bpw_over = (double)over_bytes * 8.0 / (double)tw; double bpw_over = (double)over_bytes * 8.0 / (double)tw;
double over_gap = std::abs(bpw_over - (double)target_bpw);
double err = cur.error - nxt.error; double err = cur.error - nxt.error;
if (err < 0.0) { err = 0.0; } if (err < 0.0) { err = 0.0; }
double ratio = err / (double)(delta_bytes * 8ull); double ratio = err / (double)(delta_bytes * 8ull);
double over_gap = std::abs(bpw_over - (double)target_bpw);
if (over_gap < best_over_gap - 1e-12 || (std::abs(over_gap - best_over_gap) <= 1e-12 && ratio > best_over.ratio)) { if (over_gap < best_over_gap - 1e-12 || (std::abs(over_gap - best_over_gap) <= 1e-12 && ratio > best_over.ratio)) {
best_over_gap = over_gap; best_over_gap = over_gap;
best_over = upgrade{ i, j, err, delta_bytes, ratio }; best_over = upgrade{ i, j, err, delta_bytes, ratio };
@ -1339,6 +1315,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
__func__, ggml_get_name(ti.w->tensor), ggml_type_name(ti.candidate[ti.choice].type), ti.candidate[ti.choice].bpw, ti.candidate[ti.choice].error); __func__, ggml_get_name(ti.w->tensor), ggml_type_name(ti.candidate[ti.choice].type), ti.candidate[ti.choice].bpw, ti.candidate[ti.choice].error);
overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type; overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type;
} }
return overrides; return overrides;
} }