ggml-blas: move global blas n threads to set_n_threads
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
This commit is contained in:
parent
6dff031caa
commit
e481be6da6
|
|
@@ -91,11 +91,11 @@ static void ggml_backend_blas_buffer_memset_tensor(
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_blas_buffer_set_tensor(
|
static void ggml_backend_blas_buffer_set_tensor(
|
||||||
ggml_backend_buffer_t buffer,
|
ggml_backend_buffer_t buffer,
|
||||||
ggml_tensor * tensor,
|
ggml_tensor * tensor,
|
||||||
const void * data,
|
const void * data,
|
||||||
size_t offset,
|
size_t offset,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
|
|
||||||
GGML_ASSERT(tensor);
|
GGML_ASSERT(tensor);
|
||||||
memcpy((char *)tensor->data + offset, data, size);
|
memcpy((char *)tensor->data + offset, data, size);
|
||||||
|
|
@@ -143,10 +143,10 @@ static void ggml_backend_blas_buffer_set_tensor(
|
||||||
|
|
||||||
static void ggml_backend_blas_buffer_get_tensor(
|
static void ggml_backend_blas_buffer_get_tensor(
|
||||||
ggml_backend_buffer_t buffer,
|
ggml_backend_buffer_t buffer,
|
||||||
const ggml_tensor * tensor,
|
const ggml_tensor * tensor,
|
||||||
void * data,
|
void * data,
|
||||||
size_t offset,
|
size_t offset,
|
||||||
size_t size) {
|
size_t size) {
|
||||||
|
|
||||||
GGML_ASSERT(tensor);
|
GGML_ASSERT(tensor);
|
||||||
memcpy(data, (const char *)tensor->data + offset, size);
|
memcpy(data, (const char *)tensor->data + offset, size);
|
||||||
|
|
@@ -292,18 +292,6 @@ static void ggml_backend_blas_mul_mat(
|
||||||
|
|
||||||
const ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)src0->extra;
|
const ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)src0->extra;
|
||||||
|
|
||||||
#if defined(OPENBLAS_VERSION)
|
|
||||||
openblas_set_num_threads(ctx->n_threads);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(GGML_BLAS_USE_BLIS)
|
|
||||||
bli_thread_set_num_threads(ctx->n_threads);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(GGML_BLAS_USE_NVPL)
|
|
||||||
nvpl_blas_set_num_threads(ctx->n_threads);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
||||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||||
const int64_t i03 = i13/r3;
|
const int64_t i03 = i13/r3;
|
||||||
|
|
@@ -430,6 +418,18 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend, int n_threads) {
|
||||||
|
|
||||||
ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
|
ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
|
||||||
ctx->n_threads = n_threads;
|
ctx->n_threads = n_threads;
|
||||||
|
|
||||||
|
#if defined(OPENBLAS_VERSION)
|
||||||
|
openblas_set_num_threads(ctx->n_threads);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(GGML_BLAS_USE_BLIS)
|
||||||
|
bli_thread_set_num_threads(ctx->n_threads);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(GGML_BLAS_USE_NVPL)
|
||||||
|
nvpl_blas_set_num_threads(ctx->n_threads);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: maybe implement description?
|
// TODO: maybe implement description?
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue