diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index ebd7ce8bb2..62887d0190 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -6,6 +6,7 @@ #include "ggml-backend.h" #include +#include #include #include #include @@ -27,6 +28,10 @@ struct ggml_backend_blas_buffer { size_t size; }; +struct ggml_backend_blas_buffer_type_context { + int n_threads; +}; + // BLAS backend - buffer static void ggml_backend_blas_buffer_free_buffer(ggml_backend_buffer_t buffer) { @@ -95,7 +100,7 @@ static void ggml_backend_blas_buffer_set_tensor( GGML_ASSERT(tensor); memcpy((char *)tensor->data + offset, data, size); - // ggml_backend_blas_buffer_context * buf_ctx = (ggml_backend_blas_buffer_context *)buffer->buft->context; + ggml_backend_blas_buffer_type_context * buft_ctx = (ggml_backend_blas_buffer_type_context *)buffer->buft->context; ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)tensor->extra; const int64_t ne00 = tensor->ne[0]; @@ -125,7 +130,7 @@ static void ggml_backend_blas_buffer_set_tensor( const int min_cols_per_thread = 4096; const int min_rows_per_thread = std::max((int)(min_cols_per_thread / ne00), 1); - const int n_threads = std::max(std::min(8, (int)(ne01 / min_rows_per_thread)), 1); + const int n_threads = std::max(std::min(buft_ctx->n_threads, (int)(ne01 / min_rows_per_thread)), 1); #pragma omp parallel for num_threads(n_threads) for (int64_t i01 = 0; i01 < ne01; i01++) { @@ -134,8 +139,6 @@ static void ggml_backend_blas_buffer_set_tensor( } } } - - GGML_UNUSED(buffer); } static void ggml_backend_blas_buffer_get_tensor( @@ -205,6 +208,10 @@ static bool ggml_backend_blas_buffer_type_is_host(ggml_backend_buffer_type_t buf } static ggml_backend_buffer_type_t ggml_backend_blas_buffer_type(void) { + static ggml_backend_blas_buffer_type_context buft_ctx = { + /* .n_threads = */ (int)std::thread::hardware_concurrency(), + }; + static ggml_backend_buffer_type ggml_backend_blas_buffer_type = { /* .iface = */ { /* .get_name = */ ggml_backend_blas_buffer_type_get_name, @@ -215,7 +222,7 @@ static ggml_backend_buffer_type_t ggml_backend_blas_buffer_type(void) { /* .is_host = */ ggml_backend_blas_buffer_type_is_host, }, /* .device = */ NULL, - /* .context = */ NULL, + /* .context = */ &buft_ctx, }; return &ggml_backend_blas_buffer_type; @@ -419,7 +426,6 @@ bool ggml_backend_is_blas(ggml_backend_t backend) { } void ggml_backend_blas_set_n_threads(ggml_backend_t backend, int n_threads) { - // TODO: IMPL GGML_ASSERT(ggml_backend_is_blas(backend)); ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;