ggml-blas: refactor backend

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
This commit is contained in:
Aaron Teo 2025-12-20 15:11:45 +08:00
parent 2ee4d5fe2f
commit adbfbf9086
No known key found for this signature in database
6 changed files with 156 additions and 111 deletions

View File

@ -11,9 +11,10 @@ find_package(BLAS)
if (BLAS_FOUND)
message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
ggml_add_backend_library(ggml-blas
ggml-blas.cpp
)
file(GLOB GGML_SOURCES_BLAS "*.c" "*.cpp")
file(GLOB GGML_HEADERS_BLAS "*.h" "*.hpp")
ggml_add_backend_library(ggml-blas ${GGML_HEADERS_BLAS} ${GGML_SOURCES_BLAS})
if (${GGML_BLAS_VENDOR} MATCHES "Apple")
add_compile_definitions(ACCELERATE_NEW_LAPACK)

View File

@ -0,0 +1,67 @@
#pragma once
#include "ggml.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include <vector>
#include <memory>
#include <future>
#if defined(GGML_BLAS_USE_ACCELERATE)
# include <Accelerate/Accelerate.h>
#elif defined(GGML_BLAS_USE_MKL)
# include <mkl.h>
#elif defined(GGML_BLAS_USE_BLIS)
# include <blis.h>
#elif defined(GGML_BLAS_USE_NVPL)
# include <nvpl_blas.h>
#else
# include <cblas.h>
#endif
#define GGML_BLAS_NAME "BLAS"
#define GGML_BLAS_VERSION GGML_BACKEND_API_VERSION
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_backend_blas_buffer {
void * data; // dequantized data
size_t size; // ggml_nelements * sizeof(float)
};
struct ggml_backend_blas_buffer_context {
void * data;
size_t size;
std::vector<ggml_backend_blas_buffer *> buffers;
~ggml_backend_blas_buffer_context() {
ggml_aligned_free(data, size);
for (auto * extra : buffers) {
ggml_aligned_free(extra->data, extra->size);
delete extra;
}
}
};
struct ggml_backend_blas_buffer_type_context {
int n_threads;
#ifndef GGML_USE_OPENMP
std::vector<std::future<void>> tasks;
#endif
};
struct ggml_backend_blas_context {
int n_threads;
};
struct ggml_backend_blas_device_context {
char _dummy; // Prevent empty struct warning
};
#ifdef __cplusplus
}
#endif

View File

@ -1,54 +1,29 @@
#include "ggml-impl.h"
#include "ggml-blas.h"
#include "ggml-backend-impl.h"
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-blas.h"
#include "ggml-blas/common.hpp"
#include "ggml-blas/mmf.hpp"
#include <cstdint>
#include <cstring>
#include <future>
#include <thread>
#include <vector>
#include <cstring>
#include <cstdint>
#if defined(GGML_BLAS_USE_ACCELERATE)
# include <Accelerate/Accelerate.h>
#elif defined(GGML_BLAS_USE_MKL)
# include <mkl.h>
#elif defined(GGML_BLAS_USE_BLIS)
# include <blis.h>
#elif defined(GGML_BLAS_USE_NVPL)
# include <nvpl_blas.h>
#else
# include <cblas.h>
#endif
// BLAS backend - graph compute
struct ggml_backend_blas_buffer {
void * data; // dequantized data
size_t size; // ggml_nelements * sizeof(float)
};
static void ggml_blas_compute_forward_mul_mat(
const ggml_backend_blas_context * ctx,
ggml_tensor * dst) {
struct ggml_backend_blas_buffer_context {
void * data;
size_t size;
std::vector<ggml_backend_blas_buffer *> buffers;
const ggml_tensor * src0 = dst->src[0]; // weights
const ggml_tensor * src1 = dst->src[1]; // inputs
~ggml_backend_blas_buffer_context() {
ggml_aligned_free(data, size);
for (auto * extra : buffers) {
ggml_aligned_free(extra->data, extra->size);
delete extra;
}
}
};
struct ggml_backend_blas_buffer_type_context {
int n_threads;
#ifndef GGML_USE_OPENMP
std::vector<std::future<void>> tasks;
#endif
};
ggml_blas_mul_mat_f(ctx, src0, src1, dst);
}
// BLAS backend - buffer
@ -298,69 +273,6 @@ static ggml_backend_buffer_type_t ggml_backend_blas_buffer_type(void) {
return &ggml_backend_blas_buffer_type;
}
struct ggml_backend_blas_context {
int n_threads;
};
static void ggml_backend_blas_mul_mat(
ggml_backend_blas_context * ctx,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
const ggml_type type = src0->type;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
const int64_t ne_plane = ne01*ne00;
const ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)src0->extra;
for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
// switch to dequantized F32 data
if (type != GGML_TYPE_F32) {
x = (float *)extra->data + i02*ne_plane + i03*ne02*ne_plane;
}
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne1, ne01, ne10,
1.0f, y, ne10,
x, ne00,
0.0f, d, ne01);
}
}
GGML_UNUSED(ctx);
}
static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
return "BLAS";
@ -390,9 +302,8 @@ static ggml_status ggml_backend_blas_graph_compute(
switch (node->op) {
case GGML_OP_MUL_MAT:
{
ggml_backend_blas_mul_mat(ctx, node);
ggml_blas_compute_forward_mul_mat(ctx, node);
} break;
default:
GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
}
@ -471,8 +382,6 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend, int n_threads) {
#endif
}
struct ggml_backend_blas_device_context {};
static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
return "BLAS";

View File

@ -0,0 +1,59 @@
#include "ggml.h"
#include "mmf.hpp"
void ggml_blas_mul_mat_f(
const ggml_backend_blas_context * ctx,
const ggml_tensor * src0,
const ggml_tensor * src1,
ggml_tensor * dst) {
GGML_TENSOR_BINARY_OP_LOCALS
const ggml_type type = src0->type;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
const int64_t ne_plane = ne01*ne00;
const ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)src0->extra;
for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
// switch to dequantized F32 data
if (type != GGML_TYPE_F32) {
x = (float *)extra->data + i02*ne_plane + i03*ne02*ne_plane;
}
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne1, ne01, ne10,
1.0f, y, ne10,
x, ne00,
0.0f, d, ne01);
}
}
GGML_UNUSED(ctx);
}

View File

@ -0,0 +1,9 @@
#pragma once
#include "common.hpp"
void ggml_blas_mul_mat_f(
const ggml_backend_blas_context * ctx,
const ggml_tensor * src0,
const ggml_tensor * src1,
ggml_tensor * dst);