diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 3565facf73..fecf5fc702 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -27,13 +27,16 @@ struct ggml_backend_blas_context { #endif }; -static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; +static void ggml_backend_blas_mul_mat( + ggml_backend_blas_context * ctx, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; GGML_TENSOR_BINARY_OP_LOCALS - const enum ggml_type type = src0->type; + const ggml_type type = src0->type; GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); @@ -70,8 +73,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; const int min_cols_per_thread = 4096; const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1); @@ -84,8 +87,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg } #else for (int i = 1; i < n_threads; i++) { - const int64_t start = i*ne01/n_threads; - const int64_t end = (i + 1)*ne01/n_threads; + const int64_t start = (i + 0) * ne01/n_threads; + const int64_t end = (i + 1) * ne01/n_threads; if (start < end) { ctx->tasks.push_back(std::async(std::launch::async, [=]() { for (int64_t i01 = start; i01 < end; i01++) { @@ -149,14 +152,17 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg } } -static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * ids = dst->src[2]; +static void ggml_backend_blas_mul_mat_id( + ggml_backend_blas_context * ctx, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * ids = dst->src[2]; GGML_TENSOR_BINARY_OP_LOCALS - const enum ggml_type type = src0->type; + const ggml_type type = src0->type; GGML_ASSERT(nb00 == ggml_type_size(type)); GGML_ASSERT(nb10 == ggml_type_size(src1->type)); @@ -173,15 +179,10 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); - // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; - - GGML_UNUSED(r2); - GGML_UNUSED(r3); - const int64_t ne_plane = ne01*ne00; - const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float); + const size_t desired_wsize = type == GGML_TYPE_F32 + ? 0 + : ne03*ne02*ne_plane*sizeof(float); if (ctx->work_size < desired_wsize) { ctx->work_data.reset(new char[desired_wsize]); @@ -196,8 +197,8 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane; const int min_cols_per_thread = 4096; const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1); @@ -210,7 +211,7 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t } #else for (int i = 1; i < n_threads; i++) { - const int64_t start = i*ne01/n_threads; + const int64_t start = (i + 0)*ne01/n_threads; const int64_t end = (i + 1)*ne01/n_threads; if (start < end) { ctx->tasks.push_back(std::async(std::launch::async, [=]() { @@ -555,15 +556,13 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const s case GGML_OP_MUL_MAT_ID: { - const struct ggml_tensor * src0 = op->src[0]; - const struct ggml_tensor * src1 = op->src[1]; - const struct ggml_tensor * src2 = op->src[2]; + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; - // GGML_LOG_INFO("%s: op=GGML_OP_MUL_MAT_ID src0_type=%s src1_type=%s src2_type=%s ne0=%lld ne1=%lld ne2=%lld ne3=%lld\n", - // __func__, ggml_type_name(src0->type), ggml_type_name(src1->type), ggml_type_name(src2->type), - // op->ne[0], op->ne[1], op->ne[2], op->ne[3]); - - return src2->type == GGML_TYPE_I32; + return ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + src1->type == GGML_TYPE_F32 && + (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL); } case GGML_OP_OUT_PROD: