General refactoring

Ed Addario 2025-09-20 21:31:31 +01:00
parent ad70fca5b2
commit 14fae69a7b
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 39 additions and 36 deletions


@@ -729,19 +729,19 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
     auto estimate_error = [&](const ggml_tensor * t,
                               const ggml_type quant_type,
                               const std::vector<float> & f32_sample,
-                              const std::vector<int64_t> & sample_rows_per_slice,
+                              const std::vector<int64_t> & rows_sample,
                               const float * values_sample,
                               const float * activations_sample,
                               std::vector<uint8_t> & quantized_buffer,
                               std::vector<float> & dequantized_buffer,
-                              float bias_lambda,
+                              float tensor_bias_lambda,
+                              const float * slice_bias_lambda,
                               double * out_mse = nullptr,
                               double * out_proj = nullptr) -> double
     {
         const int64_t n_per_row = t->ne[0];
         const int64_t nrows = t->ne[1];
         const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
         const size_t sample_element_count = f32_sample.size();
         const size_t sample_row_count = n_per_row > 0 ? sample_element_count / (size_t)n_per_row : 0;
         if (sample_row_count == 0) {
@@ -753,8 +753,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         size_t expected_rows = 0;
         for (int64_t s = 0; s < ne2; ++s) {
-            expected_rows += (size_t)sample_rows_per_slice[s];
+            expected_rows += (size_t)rows_sample[s];
         }
         if (expected_rows != sample_row_count) {
             if (out_mse) { *out_mse = infinity; }
             if (out_proj) { *out_proj = 0.0; }
@@ -783,17 +784,18 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                     const double a = activations[j];
                     denom += w * a * a;
                 }
                 bias_denominator_per_slice[s] = denom;
             }
         }
-        // Per-row squared norms with weighting
+        // Weighted per-row squared norms
         std::vector<double> row_sq_norm(sample_row_count, 0.0);
         {
             size_t offset = 0;
             size_t row_idx = 0;
             for (int64_t s = 0; s < ne2; ++s) {
-                const int64_t rs = sample_rows_per_slice[s];
+                const int64_t rs = rows_sample[s];
                 if (rs == 0) { continue; }
                 const float * values = has_values ? values_sample + s * n_per_row : nullptr;
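Read in isolation, the two accumulations this hunk touches are a per-slice bias denominator, the sum over j of w_j * a_j^2, and a weighted squared norm per sampled row. A minimal standalone sketch of both, assuming flat per-row arrays; the helper names are illustrative, not from the patch:

#include <cstddef>

// Per-slice denominator: sum_j w_j * a_j^2 over importance weights w and
// activations a. A null `values` falls back to uniform weights, matching the
// has_values guard in the hunk above.
static double bias_denominator(const float * values, const float * activations, size_t n) {
    double denom = 0.0;
    for (size_t j = 0; j < n; ++j) {
        const double w = values ? values[j] : 1.0;
        const double a = activations[j];
        denom += w * a * a;
    }
    return denom;
}

// Weighted squared norm of one sampled row: sum_j w_j * x_j^2.
static double weighted_row_sq_norm(const float * row, const float * values, size_t n) {
    double s = 0.0;
    for (size_t j = 0; j < n; ++j) {
        const double w = values ? values[j] : 1.0;
        s += w * (double)row[j] * (double)row[j];
    }
    return s;
}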
@@ -823,7 +825,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         size_t q_offset = 0;
         size_t f_offset = 0;
         for (int64_t slice = 0; slice < ne2; ++slice) {
-            const int64_t rs = sample_rows_per_slice[slice];
+            const int64_t rs = rows_sample[slice];
             if (rs == 0) { continue; }
             const float * value = has_values ? values_sample + slice * n_per_row : nullptr;
@@ -843,21 +845,19 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         } else {
             for (size_t r = 0; r < sample_row_count; ++r) {
                 uint8_t * src = quantized_buffer.data() + r * row_sz;
-                float * dst = dequantized_buffer.data() + r * (size_t) n_per_row;
+                float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
                 if (is_fp16) {
-                    ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, (int) n_per_row);
-                }
-                else if (is_bf16) {
-                    ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, (int) n_per_row);
-                }
-                else {
+                    ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row);
+                } else if (is_bf16) {
+                    ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row);
+                } else {
                     if (!traits || !traits->to_float) {
                         if (out_mse) { *out_mse = infinity; }
                         if (out_proj) { *out_proj = 0.0; }
                         return infinity;
                     }
-                    traits->to_float(src, dst, (int) n_per_row);
+                    traits->to_float(src, dst, (int)n_per_row);
                 }
             }
         }
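The branch order above (fp16, bf16, then generic type traits) and the infinite-error bail-out when a type has no to_float handler can be summarised outside ggml. A standalone sketch of the same shape, with a function-pointer stand-in for the ggml row converters; the signature is an assumption for illustration only:

#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for ggml_fp16_to_fp32_row / ggml_bf16_to_fp32_row / traits->to_float.
using row_to_float_fn = void (*)(const void * src, float * dst, int n);

// Unpacks n_rows rows of row_sz bytes each into floats. Returns false when no
// decoder is available, mirroring the infinity-error path in the hunk above.
static bool dequantize_rows(const std::vector<uint8_t> & packed, size_t row_sz,
                            size_t n_rows, int n_per_row,
                            row_to_float_fn decode, std::vector<float> & out) {
    if (!decode) { return false; }
    out.resize(n_rows * (size_t)n_per_row);
    for (size_t r = 0; r < n_rows; ++r) {
        const uint8_t * src = packed.data() + r * row_sz;
        float * dst = out.data() + r * (size_t)n_per_row;
        decode(src, dst, n_per_row);
    }
    return true;
}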
@@ -1098,20 +1098,20 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             offset = dist(rng);
         }
-        for (int64_t r = offset; r < nrows_total && current_sampled_rows < sample_rows_max; r += stride) {
+        for (int64_t r = offset; r < nrows_total && current_sampled_rows < rows_sample_max; r += stride) {
             if (src_type == GGML_TYPE_F32) {
-                const float * src_row = (const float *)t->data + slice * (n_per_row * nrows_total) + r * n_per_row;
+                const float * src_row = (const float *)tensor->data + slice * (n_per_row * nrows_total) + r * n_per_row;
                 f32_sample.insert(f32_sample.end(), src_row, src_row + n_per_row);
             } else if (src_type == GGML_TYPE_F16) {
-                const auto * src_row = (const ggml_fp16_t *)((const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz);
+                const auto * src_row = (const ggml_fp16_t *)((const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz);
                 ggml_fp16_to_fp32_row(src_row, row_buffer.data(), (int)n_per_row);
                 f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end());
             } else if (src_type == GGML_TYPE_BF16) {
-                const auto * src_row = (const ggml_bf16_t *)((const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz);
+                const auto * src_row = (const ggml_bf16_t *)((const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz);
                 ggml_bf16_to_fp32_row(src_row, row_buffer.data(), (int)n_per_row);
                 f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end());
             } else if (src_is_quant) {
-                const uint8_t * qrow = (const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz;
+                const uint8_t * qrow = (const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz;
                 if (!src_traits || !src_traits->to_float) {
                     throw std::runtime_error(format("cannot dequantize type %s for sampling", ggml_type_name(src_type)));
                 }
@@ -1120,9 +1120,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             } else {
                 throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(src_type)));
            }
            ++current_sampled_rows;
        }
-        sample_rows_per_slice[slice] = current_sampled_rows;
+
+        rows_sample[slice] = current_sampled_rows;
    }
    auto side_data = [&](const std::unordered_map<std::string, std::vector<float>> * m, const std::string & tensor_name) -> std::pair<const float*, size_t> {
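For reference, the sampling scheme feeding rows_sample: each slice is walked with a fixed stride from a random offset, capped at rows_sample_max rows, so the error estimate sees rows spread across the tensor rather than a prefix. A self-contained sketch of the index selection; names mirror the patch, the helper itself is illustrative:

#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

static std::vector<int64_t> sample_row_indices(int64_t nrows_total, int64_t stride,
                                               int64_t rows_sample_max, std::mt19937 & rng) {
    std::uniform_int_distribution<int64_t> dist(0, std::max<int64_t>(stride - 1, 0));
    const int64_t offset = dist(rng);   // random start, as in `offset = dist(rng)` above
    std::vector<int64_t> rows;
    for (int64_t r = offset; r < nrows_total && (int64_t)rows.size() < rows_sample_max; r += stride) {
        rows.push_back(r);              // every stride-th row of the slice
    }
    return rows;
}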
@@ -1160,7 +1162,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
     if (values_all) { copy_or_broadcast(values_all, values_sz, values_sample); }
     if (activations_all) { copy_or_broadcast(activations_all, activations_sz, activations_sample); }
-    const int64_t nelem = ggml_nelements(t);
+    const int64_t nelem = ggml_nelements(tensor);
     tensor_info info;
     info.w = tw;
     info.n_elements = nelem;
@@ -1185,8 +1187,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                     __func__, ggml_type_name(ts_type), name.c_str());
             continue;
         }
-        ggml_type tt = make_compatible(t, ts_type);
-        if (!is_compatible(t, tt)) { continue; }
+
+        ggml_type tt = make_compatible(tensor, ts_type);
+        if (!is_compatible(tensor, tt)) { continue; }
         compatible_candidates.push_back(tt);
         max_row_sz = std::max(max_row_sz, ggml_row_size(tt, n_per_row));
     }
@@ -1222,16 +1225,16 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             // thread-local scratch
             std::vector<uint8_t> tl_quantized_buffer(quantized_buffer.size());
             std::vector<float> tl_dequantised_buffer(dequantised_buffer.size());
             for (;;) {
                 const size_t i = cidx.fetch_add(1, std::memory_order_relaxed);
                 if (i >= compatible_candidates.size()) { break; }
-                const ggml_type tt = compatible_candidates[i];
-                const auto bpw = (float)tensor_bpw(t, tt);
-                const size_t bytes = tensor_bytes(t, tt);
-                const auto err = estimate_error(t, tt, f32_sample, sample_rows_per_slice, values, activations, tl_quantized_buffer, tl_dequantised_buffer, bias_lambda);
-                eval_candidates[i] = candidate_types{ tt, bpw, bytes, err };
+                const ggml_type tensor_types = compatible_candidates[i];
+                const auto bpw = (float)tensor_bpw(tensor, tensor_types);
+                const size_t bytes = tensor_bytes(tensor, tensor_types);
+                const auto err = estimate_error(tensor, tensor_types, f32_sample, rows_sample, values, activations,
+                                                tl_quantized_buffer, tl_dequantised_buffer, tensor_lambda, slice_lambda);
+                eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err };
             }
         });
     }
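The loop this hunk edits is a lock-free work queue: a shared atomic counter hands out candidate indices, and each worker keeps its own quantize/dequantize scratch so threads never contend on buffers. A minimal sketch of the pattern with std::thread; the evaluate callback stands in for the estimate_error call:

#include <atomic>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

static void for_each_candidate(size_t n_candidates, unsigned n_threads,
                               const std::function<void(size_t)> & evaluate) {
    std::atomic<size_t> cidx{0};
    std::vector<std::thread> workers;
    for (unsigned t = 0; t < n_threads; ++t) {
        workers.emplace_back([&] {
            // each thread would also own its tl_* scratch buffers here
            for (;;) {
                const size_t i = cidx.fetch_add(1, std::memory_order_relaxed);
                if (i >= n_candidates) { break; }
                evaluate(i); // no two workers ever claim the same index
            }
        });
    }
    for (auto & w : workers) { w.join(); }
}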
@@ -1244,8 +1247,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         if (info.candidate.empty()) {
             // As a last resort, keep original type
-            float bpw = ggml_nbytes(t) * 8.0f / nelem;
-            info.candidate.push_back(candidate_types{ t->type, bpw, ggml_nbytes(t), 0.0 });
+            float bpw = ggml_nbytes(tensor) * 8.0f / nelem;
+            info.candidate.push_back(candidate_types{ tensor->type, bpw, ggml_nbytes(tensor), 0.0 });
         }
         // Keep only the pareto-optimal candidates: if A has >= bytes and >= error than B, drop A.
@@ -1274,6 +1277,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 // same bytes: we already sorted by error; skip
             }
         }
+
         info.candidate.swap(pruned);
     }
@@ -1299,6 +1303,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 }
                 convex.push_back(p);
             }
+
             info.candidate.swap(convex);
         }
     }
@@ -1312,7 +1317,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
     if (all.empty()) { return {}; }
-    // Lagrangian relaxation to minimise error subject to a bpw target constraint
     auto total_bytes = [&]() -> size_t {
         size_t tb = 0;
         for (const auto & ti : all) {
@@ -1359,6 +1363,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         return emit_overrides();
     }
+    // Lagrangian relaxation to minimise error subject to a bpw target constraint
     auto lagrange_penalty = [&](const double mu, std::vector<int> & choice, size_t & bytes, double & err) {
         choice.resize(all.size());
         bytes = 0;
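What lagrange_penalty computes, stripped of the surrounding state: for a fixed multiplier mu, each tensor independently takes the candidate minimising err + mu * bytes, and the summed bytes/error tell the caller whether mu is too weak or too strong. A sketch over the candidate struct from the earlier block; it assumes every per-tensor list is non-empty, which the last-resort candidate above guarantees:

static void lagrange_penalty_sketch(double mu,
                                    const std::vector<std::vector<candidate>> & per_tensor,
                                    std::vector<int> & choice, size_t & bytes, double & err) {
    choice.assign(per_tensor.size(), 0);
    bytes = 0;
    err = 0.0;
    for (size_t t = 0; t < per_tensor.size(); ++t) {
        int best = 0;
        double best_cost = per_tensor[t][0].err + mu * (double)per_tensor[t][0].bytes;
        for (int i = 1; i < (int)per_tensor[t].size(); ++i) {
            const double cost = per_tensor[t][i].err + mu * (double)per_tensor[t][i].bytes;
            if (cost < best_cost) { best_cost = cost; best = i; }
        }
        choice[t] = best;                  // per-tensor argmin of err + mu * bytes
        bytes += per_tensor[t][best].bytes;
        err += per_tensor[t][best].err;
    }
}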
@@ -1406,6 +1411,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         if (bytes_hi <= budget_bytes) {
             break;
         }
+
         mu_hi *= 2.0;
         if (++expand > 60) {
             break;
@@ -1422,11 +1428,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         lagrange_penalty(mu, choice_mid, bytes_mid, err_mid);
         const double gap = std::abs((double)bytes_mid - (double)budget_bytes);
         if (bytes_mid > budget_bytes) {
-            // Too big, need stronger penalty
             mu_lo = mu;
-
             if (gap < best_over_gap - epsilon || (std::abs(gap - best_over_gap) <= epsilon && err_mid < best_over_err)) {
                 best_over_gap = gap;
                 best_over_err = err_mid;
@@ -1435,7 +1439,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         } else {
             // Under budget, good candidate
             mu_hi = mu;
-
             if (gap < best_under_gap - epsilon || (std::abs(gap - best_under_gap) <= epsilon && err_mid < best_under_err)) {
                 best_under_gap = gap;
                 best_under_err = err_mid;
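And the driver around these hunks, reduced to its skeleton: grow mu exponentially until the selection fits the byte budget, then bisect between the last infeasible and first feasible multipliers. Larger mu penalises bytes harder, so the selected size is non-increasing in mu. The 60-expansion cap follows the code above; the fixed bisection count and the rest are illustrative, reusing lagrange_penalty_sketch from the previous block:

static double solve_mu(const std::vector<std::vector<candidate>> & per_tensor,
                       size_t budget_bytes) {
    std::vector<int> choice;
    size_t bytes = 0;
    double err = 0.0;
    double mu_lo = 0.0;
    double mu_hi = 1e-9;
    for (int expand = 0; expand <= 60; ++expand) {  // exponential search for a feasible mu
        lagrange_penalty_sketch(mu_hi, per_tensor, choice, bytes, err);
        if (bytes <= budget_bytes) { break; }
        mu_hi *= 2.0;
    }
    for (int iter = 0; iter < 100; ++iter) {        // bisect: bytes(mu) is non-increasing
        const double mu = 0.5 * (mu_lo + mu_hi);
        lagrange_penalty_sketch(mu, per_tensor, choice, bytes, err);
        if (bytes > budget_bytes) { mu_lo = mu; } else { mu_hi = mu; }
    }
    return mu_hi;                                   // feasible multiplier closest to budget
}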