ggml-blas: force dequant routine to use max logical cores
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
This commit is contained in:
parent
447057973c
commit
6dff031caa
|
|
@ -6,6 +6,7 @@
|
|||
#include "ggml-backend.h"
|
||||
|
||||
#include <future>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <cstdint>
|
||||
|
|
@ -27,6 +28,10 @@ struct ggml_backend_blas_buffer {
|
|||
size_t size;
|
||||
};
|
||||
|
||||
struct ggml_backend_blas_buffer_type_context {
|
||||
int n_threads;
|
||||
};
|
||||
|
||||
// BLAS backend - buffer
|
||||
|
||||
static void ggml_backend_blas_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
|
|
@ -95,7 +100,7 @@ static void ggml_backend_blas_buffer_set_tensor(
|
|||
GGML_ASSERT(tensor);
|
||||
memcpy((char *)tensor->data + offset, data, size);
|
||||
|
||||
// ggml_backend_blas_buffer_context * buf_ctx = (ggml_backend_blas_buffer_context *)buffer->buft->context;
|
||||
ggml_backend_blas_buffer_type_context * buft_ctx = (ggml_backend_blas_buffer_type_context *)buffer->buft->context;
|
||||
ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)tensor->extra;
|
||||
|
||||
const int64_t ne00 = tensor->ne[0];
|
||||
|
|
@ -125,7 +130,7 @@ static void ggml_backend_blas_buffer_set_tensor(
|
|||
|
||||
const int min_cols_per_thread = 4096;
|
||||
const int min_rows_per_thread = std::max((int)(min_cols_per_thread / ne00), 1);
|
||||
const int n_threads = std::max(std::min(8, (int)(ne01 / min_rows_per_thread)), 1);
|
||||
const int n_threads = std::max(std::min(buft_ctx->n_threads, (int)(ne01 / min_rows_per_thread)), 1);
|
||||
|
||||
#pragma omp parallel for num_threads(n_threads)
|
||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||
|
|
@ -134,8 +139,6 @@ static void ggml_backend_blas_buffer_set_tensor(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_blas_buffer_get_tensor(
|
||||
|
|
@ -205,6 +208,10 @@ static bool ggml_backend_blas_buffer_type_is_host(ggml_backend_buffer_type_t buf
|
|||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_blas_buffer_type(void) {
|
||||
static ggml_backend_blas_buffer_type_context buft_ctx = {
|
||||
/* .n_threads = */ (int)std::thread::hardware_concurrency(),
|
||||
};
|
||||
|
||||
static ggml_backend_buffer_type ggml_backend_blas_buffer_type = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_blas_buffer_type_get_name,
|
||||
|
|
@ -215,7 +222,7 @@ static ggml_backend_buffer_type_t ggml_backend_blas_buffer_type(void) {
|
|||
/* .is_host = */ ggml_backend_blas_buffer_type_is_host,
|
||||
},
|
||||
/* .device = */ NULL,
|
||||
/* .context = */ NULL,
|
||||
/* .context = */ &buft_ctx,
|
||||
};
|
||||
|
||||
return &ggml_backend_blas_buffer_type;
|
||||
|
|
@ -419,7 +426,6 @@ bool ggml_backend_is_blas(ggml_backend_t backend) {
|
|||
}
|
||||
|
||||
void ggml_backend_blas_set_n_threads(ggml_backend_t backend, int n_threads) {
|
||||
// TODO: IMPL
|
||||
GGML_ASSERT(ggml_backend_is_blas(backend));
|
||||
|
||||
ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
|
||||
|
|
|
|||
Loading…
Reference in New Issue