#include "ggml.h" #include "mmf.hpp" void ggml_blas_mul_mat_f( const ggml_backend_blas_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_TENSOR_BINARY_OP_LOCALS const ggml_type type = src0->type; GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); GGML_ASSERT(ne2 == ne12); GGML_ASSERT(ne3 == ne13); // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); GGML_ASSERT(nb10 == ggml_type_size(src1->type)); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); GGML_ASSERT(nb0 <= nb1); GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); // broadcast factors const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; const int64_t ne_plane = ne01*ne00; const ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)src0->extra; for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { const int64_t i03 = i13/r3; const int64_t i02 = i12/r2; const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03); const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); // switch to dequantized F32 data if (type != GGML_TYPE_F32) { x = (float *)extra->data + i02*ne_plane + i03*ne02*ne_plane; } cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne1, ne01, ne10, 1.0f, y, ne10, x, ne00, 0.0f, d, ne01); } } GGML_UNUSED(ctx); }