#include "ggml.h"
|
|
#include "mmf.hpp"
|
|
|
|
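// multiply src0 by src1 with BLAS: one sgemm call per 2D plane,
// broadcasting src0 across the batch dimensions of src1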
void ggml_blas_mul_mat_f(
        const ggml_backend_blas_context * ctx,
        const ggml_tensor * src0,
        const ggml_tensor * src1,
        ggml_tensor * dst) {
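    // declares the local shape/stride variables ne0*/nb0* (src0),
    // ne1*/nb1* (src1) and ne*/nb* (dst)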
    GGML_TENSOR_BINARY_OP_LOCALS

    const ggml_type type = src0->type;

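    // dst must have the shape implied by mul_mat: (ne01, ne11, ne12, ne13)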
    GGML_ASSERT(ne0 == ne01);
    GGML_ASSERT(ne1 == ne11);
    GGML_ASSERT(ne2 == ne12);
    GGML_ASSERT(ne3 == ne13);

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(type));
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    // broadcast factors: each src0 matrix is reused for r2*r3 src1 matrices
    const int64_t r2 = ne12/ne02;
    const int64_t r3 = ne13/ne03;

    // number of elements in one 2D plane of src0
    const int64_t ne_plane = ne01*ne00;

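    // for quantized src0, the weights are expected to have been dequantized
    // to F32 into the buffer's extra data beforehand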
    const ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *) src0->extra;

    for (int64_t i13 = 0; i13 < ne13; i13++) {
        for (int64_t i12 = 0; i12 < ne12; i12++) {
            // map the src1/dst batch indices to the broadcast src0 indices
            const int64_t i03 = i13/r3;
            const int64_t i02 = i12/r2;

            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);

            // for quantized types, read src0 from the dequantized F32 copy instead
            if (type != GGML_TYPE_F32) {
                x = (float *) extra->data + i02*ne_plane + i03*ne02*ne_plane;
            }

            // d(ne1 x ne01) = y(ne1 x ne10) * x(ne01 x ne00)^T, with ne10 == ne00
            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                        ne1, ne01, ne10,
                        1.0f,   y, ne10,
                                x, ne00,
                        0.0f,   d, ne01);
        }
    }

    GGML_UNUSED(ctx);
}