Memory optimisations (AI assisted)

This commit is contained in:
Ed Addario 2026-01-22 11:39:26 +00:00
parent 2ede173218
commit ff3b9b4cae
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 42 additions and 28 deletions

View File

@ -966,7 +966,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
}
const size_t row_sz = ggml_row_size(quant_type, n_per_row);
if (quantized_buffer.size() < row_sz * sample_rows) { quantized_buffer.resize(row_sz * sample_rows); }
constexpr size_t SAFETY_PADDING = 256;
if (quantized_buffer.size() < row_sz * sample_rows + SAFETY_PADDING) { quantized_buffer.resize(row_sz * sample_rows + SAFETY_PADDING); }
if (dequantized_buffer.size() < sample_elems) { dequantized_buffer.resize(sample_elems); }
const bool has_vals = values_sample != nullptr;
@ -1230,9 +1231,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
double s2 = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
const double w = v ? std::max(0.0f, v[j]) : 1.0;
const double aw = std::sqrt(w) * a[j]; // z = w * a^2
s1 += aw * aw;
s2 += aw * aw * aw * aw;
const double aw2 = w * a[j] * a[j];
s1 += aw2;
s2 += aw2 * aw2;
}
if (s1 > 0.0) {
@ -1259,6 +1260,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const std::string name = ggml_get_name(tensor);
if (bpw_stop.load(std::memory_order_relaxed)) { return std::nullopt; }
const std::string remapped_name = remap_imatrix(name, mapped);
// Check cache
if (auto tn = bpw_data.find(name); tn != bpw_data.end()) {
type_choice tc;
@ -1315,7 +1318,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const ggml_type_traits * traits = ggml_get_type_traits(src_type);
for (int64_t slice = 0; slice < ne2; ++slice) {
std::mt19937 rng(std::hash<std::string>{}(name) ^ HASH_MAGIC ^ slice);
std::mt19937 rng(djb2_hash((const uint8_t*)name.data(), name.size()) ^ HASH_MAGIC ^ slice);
const int64_t limit = std::max<int64_t>(1, std::min<int64_t>(nrows_total, rows_to_sample));
const int64_t stride = std::max<int64_t>(1, nrows_total / limit);
int64_t offset = stride > 1 ? std::uniform_int_distribution<int64_t>(0, stride - 1)(rng) : 0;
@ -1343,7 +1346,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
// Prepare side data
auto get_side_data = [&](const auto * m) {
if (!m) { return std::pair<const float *, size_t>{nullptr, 0}; }
auto it = m->find(remap_imatrix(name, mapped));
auto it = m->find(remapped_name);
return it != m->end() ? std::pair{it->second.data(), it->second.size()} : std::pair<const float*, size_t>{nullptr, 0};
};
@ -1353,29 +1356,36 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
// Cache WCE stats once per tensor to avoid repeated map lookups/regex inside compute_quant_error
float h_norm = 1.0f;
if (valid_wce && statistics_data) {
const std::string key = remap_imatrix(name, mapped);
if (auto it = statistics_data->find(key); it != statistics_data->end() && !it->second.empty()) {
if (auto it = statistics_data->find(remapped_name); it != statistics_data->end() && !it->second.empty()) {
h_norm = it->second.size() > 3 ? it->second[1] : 1.0f;
}
}
std::vector<float> val_vec;
std::vector<float> act_vec;
auto prepare_broadcast = [&](const float* src, size_t sz, std::vector<float>& dst) {
if (!src) { return; }
std::vector<float> val_storage;
std::vector<float> act_storage;
const float * val_vec_ptr = nullptr;
const float * act_vec_ptr = nullptr;
auto prepare_broadcast = [&](const float* src, size_t sz, std::vector<float>& storage, const float*& out_ptr) {
if (!src) {
out_ptr = nullptr;
return;
}
size_t req = (size_t)ne2 * n_per_row;
if (sz == req) { dst.assign(src, src + req); }
if (sz == req) { out_ptr = src; }
else if (sz == (size_t)n_per_row) {
dst.resize(req);
for (int s = 0; s < ne2; ++s) { std::memcpy(dst.data() + s * n_per_row, src, n_per_row * sizeof(float)); }
storage.resize(req);
for (int s = 0; s < ne2; ++s) { std::memcpy(storage.data() + s * n_per_row, src, n_per_row * sizeof(float)); }
out_ptr = storage.data();
} else {
std::lock_guard<std::mutex> lock(log_mutex);
out_ptr = nullptr;
LLAMA_LOG_WARN("%s: side data mismatch for %s\n", func, name.c_str());
}
};
prepare_broadcast(val_ptr, val_sz, val_vec);
prepare_broadcast(act_ptr, act_sz, act_vec);
prepare_broadcast(val_ptr, val_sz, val_storage, val_vec_ptr);
prepare_broadcast(act_ptr, act_sz, act_storage, act_vec_ptr);
// Precompute WCE reference stats
wce_cache ref_wce;
@ -1383,13 +1393,13 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
size_t total_rows_sampled = 0;
for (int64_t r : rows_sample) { total_rows_sampled += r; }
if (valid_wce && !val_vec.empty() && !act_vec.empty()) {
if (valid_wce && val_vec_ptr && act_vec_ptr) {
ref_wce.row_sq_norm.reserve(total_rows_sampled);
size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
if (rs == 0) { continue; }
const float * v = val_vec.data() + s * n_per_row;
const float * v = val_vec_ptr + s * n_per_row;
for (int64_t r = 0; r < rs; ++r) {
const float * wx = f32_sample.data() + off;
double norm_x = 0.0;
@ -1405,13 +1415,13 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
// Precompute MSE reference stats
ref_mse.row_sq_norm.reserve(total_rows_sampled);
ref_mse.bias_denominator.assign(ne2, 0.0);
const bool has_acts = !act_vec.empty();
const bool has_vals = !val_vec.empty();
const bool has_acts = act_vec_ptr != nullptr;
const bool has_vals = val_vec_ptr != nullptr;
if (has_acts) {
for (int64_t s = 0; s < ne2; ++s) {
const float * v = has_vals ? val_vec.data() + s * n_per_row : nullptr;
const float * a = act_vec.data() + s * n_per_row;
const float * v = has_vals ? val_vec_ptr + s * n_per_row : nullptr;
const float * a = act_vec_ptr + s * n_per_row;
double denom = 0.0;
if (v) {
for (int64_t j = 0; j < n_per_row; ++j) { denom += std::max(0.0f, v[j]) * a[j] * a[j]; }
@ -1426,7 +1436,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
size_t off = 0;
for (int64_t s = 0; s < ne2; ++s) {
const int64_t rs = rows_sample[s];
const float * v = has_vals ? val_vec.data() + s * n_per_row : nullptr;
const float * v = has_vals ? val_vec_ptr + s * n_per_row : nullptr;
for (int64_t r = 0; r < rs; ++r) {
const float * x = f32_sample.data() + off;
double sum = 0.0;
@ -1447,7 +1457,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
std::vector<ggml_type> valid_types;
valid_types.reserve(std::size(quant_types));
size_t max_row_sz = 0;
const bool valid_matrix = !val_vec.empty();
const bool valid_matrix = val_vec_ptr != nullptr;
for (auto t : quant_types) {
if (is_iq(t) && !valid_matrix) { continue; }
@ -1461,7 +1471,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
valid_types.erase(std::unique(valid_types.begin(), valid_types.end()), valid_types.end());
float tensor_lambda = 0.0f;
std::vector<float> slice_lambdas = estimate_lambda(val_vec.empty()?nullptr:val_vec.data(), act_vec.empty()?nullptr:act_vec.data(), n_per_row, ne2);
std::vector<float> slice_lambdas = estimate_lambda(val_vec_ptr, act_vec_ptr, n_per_row, ne2);
if (!slice_lambdas.empty()) {
double sum = 0;
for(float l : slice_lambdas) { sum += l; }
@ -1473,6 +1483,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
evaluations.reserve(valid_types.size());
std::vector<uint8_t> q_buf;
std::vector<float> dq_buf;
if (total_rows_sampled > 0 && max_row_sz > 0) {
q_buf.reserve(total_rows_sampled * max_row_sz + 256); // safety padding
dq_buf.reserve(total_rows_sampled * n_per_row);
}
for (ggml_type vt : valid_types) {
if (bpw_stop.load(std::memory_order_relaxed)) { return std::nullopt; }
@ -1484,8 +1498,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
vt,
f32_sample,
rows_sample,
val_vec.empty() ? nullptr : val_vec.data(),
act_vec.empty() ? nullptr : act_vec.data(),
val_vec_ptr,
act_vec_ptr,
q_buf,
dq_buf,
tensor_lambda,