diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 62887d0190..ca896c2541 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -91,11 +91,11 @@ static void ggml_backend_blas_buffer_memset_tensor( } static void ggml_backend_blas_buffer_set_tensor( - ggml_backend_buffer_t buffer, - ggml_tensor * tensor, - const void * data, - size_t offset, - size_t size) { + ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { GGML_ASSERT(tensor); memcpy((char *)tensor->data + offset, data, size); @@ -143,10 +143,10 @@ static void ggml_backend_blas_buffer_set_tensor( static void ggml_backend_blas_buffer_get_tensor( ggml_backend_buffer_t buffer, - const ggml_tensor * tensor, - void * data, - size_t offset, - size_t size) { + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { GGML_ASSERT(tensor); memcpy(data, (const char *)tensor->data + offset, size); @@ -292,18 +292,6 @@ static void ggml_backend_blas_mul_mat( const ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)src0->extra; -#if defined(OPENBLAS_VERSION) - openblas_set_num_threads(ctx->n_threads); -#endif - -#if defined(GGML_BLAS_USE_BLIS) - bli_thread_set_num_threads(ctx->n_threads); -#endif - -#if defined(GGML_BLAS_USE_NVPL) - nvpl_blas_set_num_threads(ctx->n_threads); -#endif - for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { const int64_t i03 = i13/r3; @@ -430,6 +418,18 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend, int n_threads) { ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; ctx->n_threads = n_threads; + +#if defined(OPENBLAS_VERSION) + openblas_set_num_threads(ctx->n_threads); +#endif + +#if defined(GGML_BLAS_USE_BLIS) + bli_thread_set_num_threads(ctx->n_threads); +#endif + +#if defined(GGML_BLAS_USE_NVPL) + nvpl_blas_set_num_threads(ctx->n_threads); +#endif } // TODO: maybe implement description?