From f6823746134ae4f89aecff968150a2986693cd0a Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Thu, 11 Dec 2025 20:51:02 +0800 Subject: [PATCH] ggml-blas: initial mmid impl Signed-off-by: Aaron Teo --- ggml/src/ggml-blas/ggml-blas.cpp | 154 +++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 5b888cdd8c..1685e91c3c 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -149,6 +149,143 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg } } +static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // weights + const ggml_tensor * src1 = dst->src[1]; // inputs + const ggml_tensor * src2 = dst->src[2]; // ids + + GGML_TENSOR_TERNARY_OP_LOCALS + + const ggml_type type = src0->type; + + GGML_ASSERT(ne10 == ne00); + GGML_ASSERT(ne21 == ne12); + GGML_ASSERT(ne22 == 1 || ne22 == ne13); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1 && nb1 <= nb2 && nb2 <= nb3); + + const int64_t n_used = (int64_t)ne20; + GGML_ASSERT(n_used <= ne02); + + const int64_t ne_plane = ne01 * ne00; + const size_t desired_wsize = (type == GGML_TYPE_F32) ? 0 : ne03 * ne02 * ne_plane * sizeof(float); + if (ctx->work_size < desired_wsize) { + ctx->work_data.reset(new char[desired_wsize]); + ctx->work_size = desired_wsize; + } + void * wdata = ctx->work_data.get(); + + if (type != GGML_TYPE_F32) { + const auto * type_traits = ggml_get_type_traits(type); + ggml_to_float_t to_float = type_traits->to_float; + + for (int64_t i03 = 0; i03 < ne03; ++i03) { + for (int64_t i02 = 0; i02 < ne02; ++i02) { + const void * x = (char *)src0->data + i02*nb02 + i03*nb03; + float * wplane = (float *)wdata + i02*ne_plane + i03*ne02*ne_plane; + + const int min_cols_per_thread = 4096; + const int min_rows_per_thread = std::max((int)(min_cols_per_thread / ne00), 1); + const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01 / min_rows_per_thread)), 1); + +#ifdef GGML_USE_OPENMP + #pragma omp parallel for num_threads(n_threads) + for (int64_t i01 = 0; i01 < ne01; ++i01) { + to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00); + } +#else + for (int i = 1; i < n_threads; i++) { + const int64_t start = i * ne01/n_threads; + const int64_t end = (i + 1) * ne01/n_threads; + if (start < end) { + ctx->tasks.push_back(std::async(std::launch::async, [=]() { + for (int64_t i01 = start; i01 < end; ++i01) { + to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00); + } + })); + } + } + { + const int64_t start = 0; + const int64_t end = ne01/n_threads; + for (int64_t i01 = start; i01 < end; i01++) { + to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00); + } + } +#endif + } + } + +#ifndef GGML_USE_OPENMP + for (auto & task: ctx->tasks) { + task.get(); + } + ctx->tasks.clear(); +#endif + } + +#ifdef OPENBLAS_VERSION + openblas_set_num_threads(ctx->n_threads); +#endif + +#ifdef GGML_BLAS_USE_BLIS + bli_thread_set_num_threads(ctx->n_threads); +#endif + +#ifdef GGML_BLAS_USE_NVPL + nvpl_blas_set_num_threads(ctx->n_threads); +#endif + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t j = 0; j < ne12; ++j) { + const int64_t ids_batch_index = (ne22 > 1 ? i13 : 0); + const int32_t * ids_row = (const int32_t *)((char *)src2->data + ids_batch_index*nb22 + j*nb21); + float * out_ptr = (float *)((char *)dst->data + i13*nb3 + j*nb2); + + for (int iE = 0; iE < n_used; ++iE) { + const int expert_id = ids_row[iE]; + GGML_ASSERT(expert_id < ne02); + + const float * wmat; + if (type == GGML_TYPE_F32) { + wmat = (const float *)((char *)src0->data + expert_id*nb02); + } else { + wmat = (const float *)((char *)wdata + expert_id * ne_plane * sizeof(float)); + } + + if (ne03 > 1) { + int64_t w_batch_index = (ne03 == ne13 ? i13 : 0); + wmat = (const float *)((char *)wdata + (w_batch_index * ne02 + expert_id) * ne_plane * sizeof(float)); + } + + const float * inp = (const float *)((char *)src1->data + + ((ne11 == 1 ? 0 : iE) * nb11) + + j * nb12 + i13 * nb13); + + if (iE == 0) { + cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00, + 1.0f, wmat, (int)ne00, + inp, 1, + 0.0f, + out_ptr, 1); + } else { + cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00, + 1.0f, wmat, (int)ne00, + inp, 1, + 1.0f, + out_ptr, 1); + } + } + } + } +} + static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -235,6 +372,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, ggml_backend_blas_mul_mat(ctx, node); break; + case GGML_OP_MUL_MAT_ID: + ggml_backend_blas_mul_mat_id(ctx, node); + break; + case GGML_OP_OUT_PROD: ggml_backend_blas_out_prod(ctx, node); break; @@ -418,6 +559,19 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const s (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL); } + case GGML_OP_MUL_MAT_ID: + { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + + // GGML_LOG_INFO("%s: op=GGML_OP_MUL_MAT_ID src0_type=%s src1_type=%s src2_type=%s ne0=%lld ne1=%lld ne2=%lld ne3=%lld\n", + // __func__, ggml_type_name(src0->type), ggml_type_name(src1->type), ggml_type_name(src2->type), + // op->ne[0], op->ne[1], op->ne[2], op->ne[3]); + + return src2->type == GGML_TYPE_I32; + } + case GGML_OP_OUT_PROD: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 &&