diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index 69950f2bbd..ceac29cc09 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -326,19 +326,6 @@ static ggml_status ggml_backend_blas_graph_compute(
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 {
-                    const ggml_tensor * src1 = node->src[1];
-
-                    const int64_t ne10 = src1->ne[0];
-                    const int64_t ne0  = node->ne[0];
-                    const int64_t ne1  = node->ne[1];
-
-                    // TODO: find the optimal value
-                    const int64_t min_batch = 32;
-
-                    if (ne0 <= min_batch && ne1 <= min_batch && ne10 <= min_batch) {
-                        return GGML_STATUS_FAILED;
-                    }
-
                     ggml_blas_compute_forward_mul_mat(ctx, node);
                 } break;
             case GGML_OP_OUT_PROD:
@@ -508,12 +495,20 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const g
     switch (dst->op) {
         case GGML_OP_MUL_MAT:
             {
+                const int64_t ne10 = src1->ne[0];
+                const int64_t ne0  = dst->ne[0];
+                const int64_t ne1  = dst->ne[1];
+
+                // TODO: find the optimal value
+                const int64_t min_batch = 32;
+
                 return ggml_is_contiguous(src0)
                     && ggml_is_contiguous(src1)
                     && src1->type == GGML_TYPE_F32
                     // NOTE: llama-bench creates views that somehow does not go through init_tensor
                     // this prevents the uninitialized views from being used in BLAS
                     && src0->view_src == nullptr && src1->view_src == nullptr
+                    && (ne0 >= min_batch || ne1 >= min_batch || ne10 >= min_batch)
                     && (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
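The net effect of the patch is to move the min_batch heuristic from compute time into ggml_backend_blas_device_supports_op: instead of returning GGML_STATUS_FAILED in the middle of graph execution, the BLAS backend now declines small mat-muls up front, so the scheduler can route them to another backend. One boundary subtlety: the removed check rejected a node only when all three dimensions were <= 32, while the added condition accepts an op when any dimension is >= 32, so a mat-mul whose dimensions are all exactly 32 flips from rejected to accepted. A minimal standalone sketch of the two predicates (hypothetical helper names, not part of the patch):

    #include <cstdint>

    // Before the patch: checked at graph-compute time; the node *failed*
    // when every dimension was at most min_batch.
    static bool old_rejects(int64_t ne0, int64_t ne1, int64_t ne10) {
        const int64_t min_batch = 32; // TODO in the patch: find the optimal value
        return ne0 <= min_batch && ne1 <= min_batch && ne10 <= min_batch;
    }

    // After the patch: reported from supports_op; the op is *claimed*
    // when any dimension reaches min_batch.
    static bool new_supports(int64_t ne0, int64_t ne1, int64_t ne10) {
        const int64_t min_batch = 32;
        return ne0 >= min_batch || ne1 >= min_batch || ne10 >= min_batch;
    }

    // new_supports() is not the exact negation of old_rejects(): for
    // ne0 == ne1 == ne10 == 32 the old code rejected the mat-mul
    // (32 <= 32) while the new code accepts it (32 >= 32).

The exact negation of the old check would use strict comparisons (ne0 > min_batch || ...); since min_batch is an unoptimized threshold per the TODO, the off-by-one at the boundary is probably harmless, but it is a behavior change.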