diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index 1685e91c3c..3565facf73 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -150,72 +150,82 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
 }
 
 static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0]; // weights
-    const ggml_tensor * src1 = dst->src[1]; // inputs
-    const ggml_tensor * src2 = dst->src[2]; // ids
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * ids  = dst->src[2];
 
-    GGML_TENSOR_TERNARY_OP_LOCALS
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    const ggml_type type = src0->type;
-
-    GGML_ASSERT(ne10 == ne00);
-    GGML_ASSERT(ne21 == ne12);
-    GGML_ASSERT(ne22 == 1 || ne22 == ne13);
-    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+    const enum ggml_type type = src0->type;
 
     GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
     GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1 && nb1 <= nb2 && nb2 <= nb3);
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
 
-    const int64_t n_used = (int64_t)ne20;
-    GGML_ASSERT(n_used <= ne02);
+    GGML_ASSERT(ne03 == 1);
+    GGML_ASSERT(ne13 == 1);
+    GGML_ASSERT(ne3  == 1);
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type  == GGML_TYPE_I32);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    GGML_UNUSED(r2);
+    GGML_UNUSED(r3);
+
+    const int64_t ne_plane      = ne01*ne00;
+    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
-
-    const int64_t ne_plane = ne01 * ne00;
-    const size_t desired_wsize = (type == GGML_TYPE_F32) ? 0 : ne03 * ne02 * ne_plane * sizeof(float);
 
     if (ctx->work_size < desired_wsize) {
         ctx->work_data.reset(new char[desired_wsize]);
         ctx->work_size = desired_wsize;
     }
 
     void * wdata = ctx->work_data.get();
 
+    // convert src0 to float
     if (type != GGML_TYPE_F32) {
         const auto * type_traits = ggml_get_type_traits(type);
-        ggml_to_float_t to_float = type_traits->to_float;
+        ggml_to_float_t const to_float = type_traits->to_float;
 
-        for (int64_t i03 = 0; i03 < ne03; ++i03) {
-            for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                const void * x = (char *)src0->data + i02*nb02 + i03*nb03;
-                float * wplane = (float *)wdata + i02*ne_plane + i03*ne02*ne_plane;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void  *        x      = (char *) src0->data + i02*nb02 + i03*nb03;
+                float       * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
 
                 const int min_cols_per_thread = 4096;
-                const int min_rows_per_thread = std::max((int)(min_cols_per_thread / ne00), 1);
-                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01 / min_rows_per_thread)), 1);
+                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
+                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
 
 #ifdef GGML_USE_OPENMP
                 #pragma omp parallel for num_threads(n_threads)
-                for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                    to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                 }
 #else
                 for (int i = 1; i < n_threads; i++) {
-                    const int64_t start = i * ne01/n_threads;
-                    const int64_t end = (i + 1) * ne01/n_threads;
+                    const int64_t start =       i*ne01/n_threads;
+                    const int64_t end   = (i + 1)*ne01/n_threads;
                     if (start < end) {
                         ctx->tasks.push_back(std::async(std::launch::async, [=]() {
-                            for (int64_t i01 = start; i01 < end; ++i01) {
-                                to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                            for (int64_t i01 = start; i01 < end; i01++) {
+                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                             }
                         }));
                     }
                 }
                 {
+                    // reuse the current thread for the first task
                     const int64_t start = 0;
-                    const int64_t end = ne01/n_threads;
+                    const int64_t end   = ne01/n_threads;
                     for (int64_t i01 = start; i01 < end; i01++) {
-                        to_float((const char *)x + i01*nb01, wplane + i01*ne00, ne00);
+                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
                     }
                 }
 #endif
@@ -223,65 +233,49 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t
             }
         }
 
 #ifndef GGML_USE_OPENMP
-        for (auto & task: ctx->tasks) {
+        // wait for all tasks to finish
+        for (auto & task : ctx->tasks) {
             task.get();
         }
         ctx->tasks.clear();
 #endif
     }
 
-#ifdef OPENBLAS_VERSION
+#if defined(OPENBLAS_VERSION)
     openblas_set_num_threads(ctx->n_threads);
 #endif
 
-#ifdef GGML_BLAS_USE_BLIS
+#if defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
 #endif
 
-#ifdef GGML_BLAS_USE_NVPL
+#if defined(GGML_BLAS_USE_NVPL)
     nvpl_blas_set_num_threads(ctx->n_threads);
 #endif
 
-    for (int64_t i13 = 0; i13 < ne13; ++i13) {
-        for (int64_t j = 0; j < ne12; ++j) {
-            const int64_t ids_batch_index = (ne22 > 1 ? i13 : 0);
-            const int32_t * ids_row = (const int32_t *)((char *)src2->data + ids_batch_index*nb22 + j*nb21);
-            float * out_ptr = (float *)((char *)dst->data + i13*nb3 + j*nb2);
+    const int n_ids    = ids->ne[0];
+    const int n_tokens = ids->ne[1];
 
-            for (int iE = 0; iE < n_used; ++iE) {
-                const int expert_id = ids_row[iE];
-                GGML_ASSERT(expert_id < ne02);
+    for (int t = 0; t < n_tokens; ++t) {
+        for (int e = 0; e < n_ids; ++e) {
+            const int32_t expert = *(const int32_t *) ((const char *) ids->data + e*ids->nb[0] + t*ids->nb[1]);
+            GGML_ASSERT(expert >= 0 && expert < ne02);
 
-                const float * wmat;
-                if (type == GGML_TYPE_F32) {
-                    wmat = (const float *)((char *)src0->data + expert_id*nb02);
-                } else {
-                    wmat = (const float *)((char *)wdata + expert_id * ne_plane * sizeof(float));
-                }
+            const int e_src1 = e % ne11;
 
-                if (ne03 > 1) {
-                    int64_t w_batch_index = (ne03 == ne13 ? i13 : 0);
-                    wmat = (const float *)((char *)wdata + (w_batch_index * ne02 + expert_id) * ne_plane * sizeof(float));
-                }
+            const float * a = (float *) ((char *) src0->data + expert*nb02);
+            const float * b = (float *) ((char *) src1->data + e_src1*nb11 + t*nb12);
+            float       * d = (float *) ((char *)  dst->data + e*nb1 + t*nb2);
 
-                const float * inp = (const float *)((char *)src1->data
-                    + ((ne11 == 1 ? 0 : iE) * nb11)
-                    + j * nb12 + i13 * nb13);
-
-                if (iE == 0) {
-                    cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00,
-                                1.0f, wmat, (int)ne00,
-                                inp, 1,
-                                0.0f,
-                                out_ptr, 1);
-                } else {
-                    cblas_sgemv(CblasRowMajor, CblasNoTrans, (int)ne01, (int)ne00,
-                                1.0f, wmat, (int)ne00,
-                                inp, 1,
-                                1.0f,
-                                out_ptr, 1);
-                }
-            }
+            if (type != GGML_TYPE_F32) {
+                a = (float *) wdata + expert*ne_plane;
+            }
+
+            cblas_sgemv(CblasRowMajor, CblasNoTrans,
+                    ne01, ne00,
+                    1.0f, a, ne00,
+                    b, 1,
+                    0.0f, d, 1);
         }
     }
 }
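
For reference, the rewritten kernel issues one GEMV per (token, expert) pair: for token t and selected expert e, it computes d = W_expert * b, where W_expert is the expert's ne01 x ne00 row-major weight plane (from src0, or from the f32-converted wdata) and b is the token's input column from src1. Below is a minimal standalone sketch of what that cblas_sgemv call computes; it is not part of the patch, and the names and toy sizes are illustrative only.

// Plain-loop equivalent of
//   cblas_sgemv(CblasRowMajor, CblasNoTrans, ne01, ne00, 1.0f, a, ne00, b, 1, 0.0f, d, 1)
// i.e. d[i] = sum_j a[i*ne00 + j] * b[j] for i in [0, ne01).
#include <cstdint>
#include <cstdio>
#include <vector>

static void sgemv_ref(const float * a, const float * b, float * d,
                      int64_t ne00 /* cols */, int64_t ne01 /* rows */) {
    for (int64_t i = 0; i < ne01; i++) {
        float sum = 0.0f;
        for (int64_t j = 0; j < ne00; j++) {
            sum += a[i*ne00 + j]*b[j];
        }
        d[i] = sum; // beta == 0.0f: overwrite, do not accumulate
    }
}

int main() {
    // toy setup: 2 experts, each a 3x4 row-major plane of ones; input of twos
    const int64_t ne00 = 4, ne01 = 3;
    std::vector<float> experts(2*ne01*ne00, 1.0f);
    std::vector<float> b(ne00, 2.0f);
    std::vector<float> d(ne01);

    const int32_t expert = 1; // as if read from the ids tensor
    const float * a = experts.data() + expert*ne01*ne00; // contiguous f32 plane

    sgemv_ref(a, b.data(), d.data(), ne00, ne01);
    printf("%.1f %.1f %.1f\n", d[0], d[1], d[2]); // prints: 8.0 8.0 8.0
}

Note the design change visible in the diff: the old code accumulated all used experts into one output row per (i13, j), alternating beta between 0.0f and 1.0f, whereas the new code writes each (token, expert) result to its own dst row (e*nb1 + t*nb2) with beta fixed at 0.0f, so no accumulation path is needed.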