ggml-blas: force dequant routine to use max logical cores

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
This commit is contained in:
Aaron Teo 2025-12-14 21:57:09 +08:00
parent 447057973c
commit 6dff031caa
No known key found for this signature in database
1 changed files with 12 additions and 6 deletions

View File

@ -6,6 +6,7 @@
#include "ggml-backend.h"
#include <future>
#include <thread>
#include <vector>
#include <cstring>
#include <cstdint>
@ -27,6 +28,10 @@ struct ggml_backend_blas_buffer {
size_t size;
};
struct ggml_backend_blas_buffer_type_context {
int n_threads;
};
// BLAS backend - buffer
static void ggml_backend_blas_buffer_free_buffer(ggml_backend_buffer_t buffer) {
@ -95,7 +100,7 @@ static void ggml_backend_blas_buffer_set_tensor(
GGML_ASSERT(tensor);
memcpy((char *)tensor->data + offset, data, size);
// ggml_backend_blas_buffer_context * buf_ctx = (ggml_backend_blas_buffer_context *)buffer->buft->context;
ggml_backend_blas_buffer_type_context * buft_ctx = (ggml_backend_blas_buffer_type_context *)buffer->buft->context;
ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)tensor->extra;
const int64_t ne00 = tensor->ne[0];
@ -125,7 +130,7 @@ static void ggml_backend_blas_buffer_set_tensor(
const int min_cols_per_thread = 4096;
const int min_rows_per_thread = std::max((int)(min_cols_per_thread / ne00), 1);
const int n_threads = std::max(std::min(8, (int)(ne01 / min_rows_per_thread)), 1);
const int n_threads = std::max(std::min(buft_ctx->n_threads, (int)(ne01 / min_rows_per_thread)), 1);
#pragma omp parallel for num_threads(n_threads)
for (int64_t i01 = 0; i01 < ne01; i01++) {
@ -134,8 +139,6 @@ static void ggml_backend_blas_buffer_set_tensor(
}
}
}
GGML_UNUSED(buffer);
}
static void ggml_backend_blas_buffer_get_tensor(
@ -205,6 +208,10 @@ static bool ggml_backend_blas_buffer_type_is_host(ggml_backend_buffer_type_t buf
}
static ggml_backend_buffer_type_t ggml_backend_blas_buffer_type(void) {
static ggml_backend_blas_buffer_type_context buft_ctx = {
/* .n_threads = */ (int)std::thread::hardware_concurrency(),
};
static ggml_backend_buffer_type ggml_backend_blas_buffer_type = {
/* .iface = */ {
/* .get_name = */ ggml_backend_blas_buffer_type_get_name,
@ -215,7 +222,7 @@ static ggml_backend_buffer_type_t ggml_backend_blas_buffer_type(void) {
/* .is_host = */ ggml_backend_blas_buffer_type_is_host,
},
/* .device = */ NULL,
/* .context = */ NULL,
/* .context = */ &buft_ctx,
};
return &ggml_backend_blas_buffer_type;
@ -419,7 +426,6 @@ bool ggml_backend_is_blas(ggml_backend_t backend) {
}
void ggml_backend_blas_set_n_threads(ggml_backend_t backend, int n_threads) {
// TODO: IMPL
GGML_ASSERT(ggml_backend_is_blas(backend));
ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;