Reduce compute time by parallelising tensor processing - courtesy of https://github.com/ddh0
This commit is contained in:
parent 951de2e2c2
commit 12e0524f3a
@@ -15,6 +15,7 @@
#include <regex>
#include <thread>
#include <unordered_map>
#include <optional>

// Quantization types. Changes to this struct must be replicated in quantize.cpp
struct tensor_quantization {
@@ -623,7 +624,6 @@ static void signal_handler(int) {
// Returns tensor type overrides to meet a global bpw target
static std::unordered_map<std::string, ggml_type> target_bpw_type(
    llama_model_loader & ml,
    std::vector<no_init<uint8_t>> & buffer,
    const llama_model & model,
    const std::vector<const llama_model_loader::llama_tensor_weight *> & tensors,
    const std::map<int, std::string> & mapped,
@@ -659,6 +659,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
    GGML_TYPE_IQ3_XXS,
    GGML_TYPE_Q3_K,
    GGML_TYPE_IQ4_XS,
    GGML_TYPE_IQ4_NL,
    GGML_TYPE_Q4_K,
    GGML_TYPE_Q5_K,
    GGML_TYPE_Q6_K,
@@ -1127,16 +1128,22 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
    install_signal_handlers();
    auto bpw_data = load_bpw_state();
    std::vector<tensor_info> all;
    all.reserve(tensors.size());
    for (const auto * tw : tensors) {

    // Significantly reduce compute time by parallelising tensor processing - courtesy of https://github.com/ddh0
    auto process_tensor = [&](const llama_model_loader::llama_tensor_weight * tw,
                              std::vector<no_init<uint8_t>> & thread_local_buffer,
                              std::mutex & loader_mutex,
                              std::mutex & log_mutex) -> std::optional<tensor_info>
    {
        ggml_tensor * tensor = tw->tensor;
        const std::string name = ggml_get_name(tensor);
        if (!can_quantize(tensor)) { continue; }
        check_signal_handler(all);
        if (bpw_stop.load(std::memory_order_relaxed)) {
            return std::nullopt;
        }

        // If we already have fully evaluated this tensor then reuse it
        if (auto it_saved = bpw_data.find(name); it_saved != bpw_data.end()) {
        // check for pre-computed results from a checkpoint file.
        auto it_saved = bpw_data.find(name);
        if (it_saved != bpw_data.end()) {
            tensor_info info;
            info.w = tw;
            info.candidate = it_saved->second.candidate;
@@ -1144,17 +1151,21 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
            info.min_bpw = it_saved->second.min_bpw;
            info.max_bpw = it_saved->second.max_bpw;
            info.n_elements = it_saved->second.n_elements ? it_saved->second.n_elements : (size_t)ggml_nelements(tensor);
            all.push_back(std::move(info));
            continue;
            return info;
        }
        {
            std::lock_guard<std::mutex> lock(log_mutex);
            LLAMA_LOG_INFO("\ttarget_bpw_type: - processing tensor %45s \t(%12" PRId64 " elements)\n", name.c_str(), ggml_nelements(tensor));
        }

        LLAMA_LOG_INFO("\t%s: - processing tensor %45s \t(%12" PRId64 " elements)\n", __func__, name.c_str(), ggml_nelements(tensor));
        if (!ml.use_mmap) {
            if (buffer.size() < ggml_nbytes(tensor)) { buffer.resize(ggml_nbytes(tensor)); }
            tensor->data = buffer.data();
            if (thread_local_buffer.size() < ggml_nbytes(tensor)) { thread_local_buffer.resize(ggml_nbytes(tensor)); }
            tensor->data = thread_local_buffer.data();
        }

        {
            std::lock_guard<std::mutex> lock(loader_mutex);
            ml.load_data_for(tensor);
        }

        // Dequantize sampled rows into f32_sample
        const int64_t n_per_row = tensor->ne[0];
@@ -1170,7 +1181,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
            const int64_t max_rows = 4096;
            int64_t total_rows = std::llround(slice_budget / std::max<int64_t>(1, n));
            total_rows = std::max<int64_t>(min_rows, std::min<int64_t>(total_rows, std::min<int64_t>(rows, max_rows)));
            if (rows <= min_rows * 2) { total_rows = rows; } // use all rows for small tensors
            if (rows <= min_rows * 2) { total_rows = rows; }
            return total_rows;
        };
@@ -1191,17 +1202,16 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                return;
            }
            if (t == GGML_TYPE_F16) {
                ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, (int)n_per_row);
                ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row);
                return;
            }
            if (t == GGML_TYPE_BF16) {
                ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, (int)n_per_row);
                ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row);
                return;
            }

            if (src_is_quant) {
                GGML_ASSERT(src_traits && src_traits->to_float);
                src_traits->to_float(src, dst, (int) n_per_row);
                src_traits->to_float(src, dst, (int)n_per_row);
                return;
            }
@@ -1266,6 +1276,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                return;
            }

            std::lock_guard<std::mutex> lock(log_mutex);
            LLAMA_LOG_WARN("%s: side data size mismatch for %s: got %zu, expected %zu or %zu; ignoring\n", func, name.c_str(), src_sz, (size_t)n_per_row, want);
        };
@@ -1276,12 +1287,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
        if (values_all) { copy_or_broadcast(values_all, values_sz, values_sample); }
        if (activations_all) { copy_or_broadcast(activations_all, activations_sz, activations_sample); }

        const int64_t nelem = ggml_nelements(tensor);
        tensor_info info;
        info.w = tw;
        info.n_elements = nelem;

        // Prepare scratch buffers sized for the largest candidate row size
        info.n_elements = ggml_nelements(tensor);
        size_t total_sampled_rows = f32_sample.size() / n_per_row;

        // Build list of candidate types first (compatible ones)
@@ -1295,7 +1303,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
        for (size_t i = 0; i < base_sz; ++i) {
            ggml_type ts_type = base_arr[i];
            if (is_iq(ts_type) && !has_valid_imatrix) {
                LLAMA_LOG_WARN("%s: skipping %s for %s, no or mismatched imatrix\n", __func__, ggml_type_name(ts_type), name.c_str());
                std::lock_guard<std::mutex> lock(log_mutex);
                LLAMA_LOG_WARN("\t%s: skipping %s for %s, no or mismatched imatrix\n", func, ggml_type_name(ts_type), name.c_str());
                continue;
            }
@@ -1325,19 +1334,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
        std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
        std::vector<float> dequantized_buffer(f32_sample.size());
        const float * slice_lambda = lambdas.empty() ? nullptr : lambdas.data();
        int n_eval_threads = std::max(1, std::min<int>(nthread, (int)compatible_candidates.size()));
        std::atomic<size_t> cidx{0};
        std::vector<std::thread> eval_workers;
        eval_workers.reserve(n_eval_threads);
        for (int ti = 0; ti < n_eval_threads; ++ti) {
            eval_workers.emplace_back([&] {
                // thread-local scratch
                std::vector<uint8_t> tl_quantized_buffer(quantized_buffer.size());
                std::vector<float> tl_dequantized_buffer(dequantized_buffer.size());
                for (;;) {
                    if (bpw_stop.load(std::memory_order_relaxed)) { break; } // stop if a signal arrived
                    const size_t i = cidx.fetch_add(1, std::memory_order_acq_rel);
                    if (i >= compatible_candidates.size()) { break; }
        for (size_t i = 0; i < compatible_candidates.size(); ++i) {
            if (bpw_stop.load(std::memory_order_relaxed)) { break; }

            const ggml_type tensor_types = compatible_candidates[i];
            const auto bpw = (float)tensor_bpw(tensor, tensor_types);
@@ -1345,25 +1343,17 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
            double mse = 0.0;
            double proj = 0.0;
            const auto err = estimate_error(tensor, tensor_types, f32_sample, rows_sample, values, activations,
                                            tl_quantized_buffer, tl_dequantized_buffer, tensor_lambda, slice_lambda, &mse, &proj);
                                            quantized_buffer, dequantized_buffer, tensor_lambda, slice_lambda, &mse, &proj);
            eval_candidates[i] = candidate_types{ tensor_types, bpw, bytes, err, mse, proj };
        }
            });
        }

        for (auto &th : eval_workers) { th.join(); }

        // If interruption happened mid-evaluation, exit without adding a half-baked tensor entry
        if (bpw_stop.load(std::memory_order_relaxed) && cidx.load(std::memory_order_relaxed) < compatible_candidates.size()) {
            check_signal_handler(all);
        }
        if (bpw_stop.load(std::memory_order_relaxed)) { return std::nullopt; }

        // Check if biasing is needed
        bool bias_needed = false;
        if (!lambdas.empty()) {
            int min_mse = -1;
            int min_bias = -1;
            {
                double best_mse = std::numeric_limits<double>::infinity();
                double best_err = std::numeric_limits<double>::infinity();
                for (int i = 0; i < (int)eval_candidates.size(); ++i) {
@@ -1378,7 +1368,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                        min_bias = i;
                    }
                }
            }

            if (min_mse != min_bias) {
                bias_needed = true;
@@ -1388,8 +1377,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                    if (c.bytes == 0) { continue; }
                    const double mse = std::max(c.mse, epsilon);
                    const double bias_term = std::max(0.0, c.error - c.mse);
                    const double rel = bias_term / mse;
                    max_rel_bias = std::max(rel, max_rel_bias);
                    max_rel_bias = std::max(bias_term / mse, max_rel_bias);
                }

                bias_needed = max_rel_bias >= 0.5; // >= 50% of MSE?
@@ -1404,7 +1392,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(

        if (info.candidate.empty()) {
            // As a last resort, keep original type
            float bpw = ggml_nbytes(tensor) * 8.0f / nelem;
            float bpw = ggml_nbytes(tensor) * 8.0f / info.n_elements;
            info.candidate.push_back(candidate_types{ tensor->type, bpw, ggml_nbytes(tensor), 0.0 });
        }
@@ -1416,26 +1404,18 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
            if (a.bytes != b.bytes) { return a.bytes < b.bytes; }
            return a.error < b.error;
        });
        const auto last = std::unique(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) {
        candidates.erase(std::unique(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) {
            return a.bytes == b.bytes;
        });
        candidates.erase(last, candidates.end());

        // Pareto by bytes -> error
        }), candidates.end());
        std::vector<candidate_types> pareto;
        pareto.reserve(candidates.size());
        double best_err = infinity;
        size_t last_b = std::numeric_limits<size_t>::max();
        for (const auto & c : candidates) {
            if (c.bytes != last_b) {
                last_b = c.bytes;
                if (c.error < best_err) {
                    best_err = c.error;
                    pareto.push_back(c);
                }
            }
        }

        candidates.swap(pareto);
        if (candidates.size() < 3) { return; } // need at least 3 points to do convex hull
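The hunk above prunes the candidate quantization types down to a Pareto front over size and error: candidates are sorted by byte count (ties broken by error), duplicate sizes are dropped, and a candidate survives only if it strictly lowers the best error seen so far. A minimal standalone sketch of that prune, assuming a simplified Candidate type rather than the actual candidate_types struct:

// Standalone sketch of the bytes -> error Pareto prune shown above.
// "Candidate" and "pareto_front" are illustrative placeholders, not llama.cpp symbols.
#include <algorithm>
#include <cstddef>
#include <limits>
#include <vector>

struct Candidate {
    std::size_t bytes;  // quantized size of the sampled rows
    double      error;  // estimated quantization error
};

static std::vector<Candidate> pareto_front(std::vector<Candidate> candidates) {
    // Sort by size, breaking ties by error, so the first entry per size is the best one.
    std::sort(candidates.begin(), candidates.end(), [](const Candidate & a, const Candidate & b) {
        if (a.bytes != b.bytes) { return a.bytes < b.bytes; }
        return a.error < b.error;
    });
    std::vector<Candidate> front;
    front.reserve(candidates.size());
    double best_err = std::numeric_limits<double>::infinity();
    std::size_t last_bytes = std::numeric_limits<std::size_t>::max();
    for (const auto & c : candidates) {
        if (c.bytes == last_bytes) { continue; }  // drop duplicate sizes
        last_bytes = c.bytes;
        if (c.error < best_err) {                 // keep only strict error improvements
            best_err = c.error;
            front.push_back(c);
        }
    }
    return front;
}

Because the list is sorted by size first, the surviving points have strictly decreasing error as size grows, which gives the later convex-hull step a clean set of at least monotone points to work with.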
@@ -1470,9 +1450,42 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
        info.choice = 0;
        info.min_bpw = info.candidate.front().bpw;
        info.max_bpw = info.candidate.back().bpw;
        all.push_back(std::move(info));
        check_signal_handler(all); // save after each tensor

        return info;
    };

    std::vector<tensor_info> all; // this vector will be populated by the parallel workers
    {
        std::atomic<size_t> tensor_idx{0}; // shared work queue index for all threads
        const size_t num_tensors_to_process = tensors.size();
        std::mutex loader_mutex;
        std::mutex log_mutex;
        std::mutex results_mutex;
        std::vector<std::thread> workers;
        int num_threads_to_spawn = std::max(1, std::min<int>(nthread, (int)num_tensors_to_process));

        for (int i = 0; i < num_threads_to_spawn; ++i) {
            workers.emplace_back([&]() {
                std::vector<no_init<uint8_t>> thread_local_buffer;
                while (true) {
                    const size_t current_idx = tensor_idx.fetch_add(1);
                    if (current_idx >= num_tensors_to_process) { break; }
                    const auto * tw = tensors[current_idx];
                    if (!can_quantize(tw->tensor)) { continue; }
                    // Execute the main processing logic for this tensor
                    std::optional<tensor_info> result_info = process_tensor(tw, thread_local_buffer, loader_mutex, log_mutex);
                    if (result_info) {
                        std::lock_guard<std::mutex> lock(results_mutex);
                        all.push_back(std::move(*result_info));
                    }
                }
            });
        }

        for (auto & w : workers) { w.join(); }
    }

    check_signal_handler(all);

    if (all.empty()) { return {}; }
@@ -1965,7 +1978,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
        LLAMA_LOG_WARN("%s: imatrix without activations provided, target bpw quantization will be less accurate\n", __func__);
    }
    LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw);
    bpw_overrides = target_bpw_type(ml, read_data, model, tensors, mapped, values_data, activations_data, params, nthread);
    bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, params, nthread);
    } else {
        LLAMA_LOG_WARN("%s: no imatrix provided, target bpw will not apply\n", __func__);
    }
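Taken together, the change moves the parallelism one level up: instead of spinning up a thread pool per tensor to evaluate candidate types, a single pool of workers pulls whole tensors off a shared atomic index, each worker keeps its own read buffer, and the model loader, the log, and the shared results vector are each guarded by their own mutex. A self-contained sketch of that work-queue pattern, with placeholder Task/Result types and a process_one stub rather than the real llama.cpp symbols:

// Standalone sketch of the shared work-queue pattern used above.
// Task, Result and process_one are illustrative placeholders.
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <optional>
#include <thread>
#include <vector>

struct Task   { int id; };
struct Result { int id; double score; };

static std::optional<Result> process_one(const Task & task, std::vector<std::uint8_t> & scratch) {
    scratch.assign(1024, 0);                    // per-thread scratch, reused across tasks
    return Result{ task.id, task.id * 0.5 };    // placeholder for the per-tensor work
}

int main() {
    std::vector<Task> tasks(64);
    for (std::size_t i = 0; i < tasks.size(); ++i) { tasks[i].id = (int)i; }

    std::vector<Result> results;
    std::mutex results_mutex;                   // guards the shared results vector
    std::atomic<std::size_t> next_idx{0};       // shared work-queue index

    const int n_threads = std::max(1, (int)std::min<std::size_t>(std::thread::hardware_concurrency(), tasks.size()));
    std::vector<std::thread> workers;
    workers.reserve(n_threads);
    for (int t = 0; t < n_threads; ++t) {
        workers.emplace_back([&] {
            std::vector<std::uint8_t> scratch;  // thread-local buffer, like thread_local_buffer above
            for (;;) {
                const std::size_t i = next_idx.fetch_add(1);
                if (i >= tasks.size()) { break; }
                if (auto r = process_one(tasks[i], scratch)) {
                    std::lock_guard<std::mutex> lock(results_mutex);
                    results.push_back(*r);
                }
            }
        });
    }
    for (auto & w : workers) { w.join(); }
    return results.empty() ? 1 : 0;             // 0 on success
}

Because each worker owns its scratch buffer and only the loader, the log, and the results vector are shared, the only serialised section is the actual tensor read, which is what lets the candidate evaluation overlap across tensors.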