ggml-blas: refactor backend

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

commit adbfbf9086 (parent 2ee4d5fe2f)
ggml-blas/CMakeLists.txt
@@ -11,9 +11,10 @@ find_package(BLAS)
 if (BLAS_FOUND)
     message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
 
-    ggml_add_backend_library(ggml-blas
-                             ggml-blas.cpp
-                            )
+    file(GLOB GGML_SOURCES_BLAS "*.c" "*.cpp")
+    file(GLOB GGML_HEADERS_BLAS "*.h" "*.hpp")
+
+    ggml_add_backend_library(ggml-blas ${GGML_HEADERS_BLAS} ${GGML_SOURCES_BLAS})
 
     if (${GGML_BLAS_VENDOR} MATCHES "Apple")
         add_compile_definitions(ACCELERATE_NEW_LAPACK)
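Note: replacing the hard-coded source list with file(GLOB ...) means the new files introduced below (common.hpp, mmf.cpp, mmf.hpp) are compiled into ggml-blas automatically, and future kernel files dropped into this directory need no further CMakeLists.txt edits (with the usual caveat that file(GLOB) is evaluated at configure time, so brand-new files require re-running CMake).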
ggml-blas/common.hpp (new file)
@@ -0,0 +1,67 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+
+#include <vector>
+#include <memory>
+#include <future>
+
+#if defined(GGML_BLAS_USE_ACCELERATE)
+# include <Accelerate/Accelerate.h>
+#elif defined(GGML_BLAS_USE_MKL)
+# include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+# include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+# include <nvpl_blas.h>
+#else
+# include <cblas.h>
+#endif
+
+#define GGML_BLAS_NAME    "BLAS"
+#define GGML_BLAS_VERSION GGML_BACKEND_API_VERSION
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_backend_blas_buffer {
+    void * data; // dequantized data
+    size_t size; // ggml_nelements * sizeof(float)
+};
+
+struct ggml_backend_blas_buffer_context {
+    void * data;
+    size_t size;
+    std::vector<ggml_backend_blas_buffer *> buffers;
+
+    ~ggml_backend_blas_buffer_context() {
+        ggml_aligned_free(data, size);
+        for (auto * extra : buffers) {
+            ggml_aligned_free(extra->data, extra->size);
+            delete extra;
+        }
+    }
+};
+
+struct ggml_backend_blas_buffer_type_context {
+    int n_threads;
+
+#ifndef GGML_USE_OPENMP
+    std::vector<std::future<void>> tasks;
+#endif
+};
+
+struct ggml_backend_blas_context {
+    int n_threads;
+};
+
+struct ggml_backend_blas_device_context {
+    char _dummy; // Prevent empty struct warning
+};
+
+#ifdef __cplusplus
+}
+#endif
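The destructor above makes ggml_backend_blas_buffer_context the sole owner of the dequantized "extra" buffers. Below is a minimal sketch of the intended allocation pattern, assuming the ggml_aligned_malloc/ggml_aligned_free pair declared in ggml-impl.h; blas_buffer_alloc_extra is a hypothetical helper for illustration, not part of this commit:

#include "ggml-blas/common.hpp"

// Hypothetical helper (illustration only): allocate a dequantized F32
// extra of `size` bytes and register it with the buffer context, which
// releases it in ~ggml_backend_blas_buffer_context().
static ggml_backend_blas_buffer * blas_buffer_alloc_extra(
        ggml_backend_blas_buffer_context * bctx, size_t size) {
    auto * extra = new ggml_backend_blas_buffer;
    extra->data = ggml_aligned_malloc(size);
    extra->size = size;
    bctx->buffers.push_back(extra);
    return extra; // typically stored in tensor->extra for later lookup
}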
ggml-blas/ggml-blas.cpp
@@ -1,54 +1,29 @@
-#include "ggml-impl.h"
-#include "ggml-blas.h"
-#include "ggml-backend-impl.h"
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+#include "ggml-blas.h"
+
+#include "ggml-blas/common.hpp"
+#include "ggml-blas/mmf.hpp"
 
+#include <cstdint>
+#include <cstring>
 #include <future>
 #include <thread>
 #include <vector>
-#include <cstring>
-#include <cstdint>
 
-#if defined(GGML_BLAS_USE_ACCELERATE)
-# include <Accelerate/Accelerate.h>
-#elif defined(GGML_BLAS_USE_MKL)
-# include <mkl.h>
-#elif defined(GGML_BLAS_USE_BLIS)
-# include <blis.h>
-#elif defined(GGML_BLAS_USE_NVPL)
-# include <nvpl_blas.h>
-#else
-# include <cblas.h>
-#endif
+// BLAS backend - graph compute
 
-struct ggml_backend_blas_buffer {
-    void * data; // dequantized data
-    size_t size; // ggml_nelements * sizeof(float)
-};
+static void ggml_blas_compute_forward_mul_mat(
+    const ggml_backend_blas_context * ctx,
+          ggml_tensor * dst) {
 
-struct ggml_backend_blas_buffer_context {
-    void * data;
-    size_t size;
-    std::vector<ggml_backend_blas_buffer *> buffers;
+    const ggml_tensor * src0 = dst->src[0]; // weights
+    const ggml_tensor * src1 = dst->src[1]; // inputs
 
-    ~ggml_backend_blas_buffer_context() {
-        ggml_aligned_free(data, size);
-        for (auto * extra : buffers) {
-            ggml_aligned_free(extra->data, extra->size);
-            delete extra;
-        }
-    }
-};
-
-struct ggml_backend_blas_buffer_type_context {
-    int n_threads;
-
-#ifndef GGML_USE_OPENMP
-    std::vector<std::future<void>> tasks;
-#endif
-};
+    ggml_blas_mul_mat_f(ctx, src0, src1, dst);
+}
 
 // BLAS backend - buffer
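With this split, ggml-blas.cpp keeps only per-op dispatch: ggml_blas_compute_forward_mul_mat unpacks the weight (src0) and input (src1) operands and forwards them to ggml_blas_mul_mat_f, whose definition moves to the new mmf.cpp below, while the shared structs and vendor BLAS includes it previously carried now come from ggml-blas/common.hpp.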
@@ -298,69 +273,6 @@ static ggml_backend_buffer_type_t ggml_backend_blas_buffer_type(void) {
     return &ggml_backend_blas_buffer_type;
 }
 
-struct ggml_backend_blas_context {
-    int n_threads;
-};
-
-static void ggml_backend_blas_mul_mat(
-    ggml_backend_blas_context * ctx,
-    ggml_tensor * dst) {
-
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const ggml_type type = src0->type;
-
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
-    // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == ggml_type_size(type));
-    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-    const int64_t ne_plane = ne01*ne00;
-
-    const ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)src0->extra;
-
-    for (int64_t i13 = 0; i13 < ne13; i13++) {
-        for (int64_t i12 = 0; i12 < ne12; i12++) {
-            const int64_t i03 = i13/r3;
-            const int64_t i02 = i12/r2;
-
-            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
-            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);
-
-            // switch to dequantized F32 data
-            if (type != GGML_TYPE_F32) {
-                x = (float *)extra->data + i02*ne_plane + i03*ne02*ne_plane;
-            }
-
-            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne1, ne01, ne10,
-                        1.0f, y, ne10,
-                              x, ne00,
-                        0.0f, d, ne01);
-        }
-    }
-
-    GGML_UNUSED(ctx);
-}
-
-
 static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
     return "BLAS";
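Nothing removed here is lost: the function body reappears essentially unchanged as ggml_blas_mul_mat_f in the new mmf.cpp later in this diff, and the duplicated ggml_backend_blas_context definition is now provided by common.hpp.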
@@ -390,9 +302,8 @@ static ggml_status ggml_backend_blas_graph_compute(
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 {
-                    ggml_backend_blas_mul_mat(ctx, node);
+                    ggml_blas_compute_forward_mul_mat(ctx, node);
                 } break;
 
             default:
                 GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
         }
@@ -471,8 +382,6 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend, int n_threads) {
 #endif
 }
 
-struct ggml_backend_blas_device_context {};
-
 static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
     return "BLAS";
ggml-blas/mmf.cpp (new file)
@@ -0,0 +1,59 @@
+#include "ggml.h"
+#include "mmf.hpp"
+
+void ggml_blas_mul_mat_f(
+    const ggml_backend_blas_context * ctx,
+    const ggml_tensor * src0,
+    const ggml_tensor * src1,
+          ggml_tensor * dst) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const ggml_type type = src0->type;
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+    const int64_t ne_plane = ne01*ne00;
+
+    const ggml_backend_blas_buffer * extra = (ggml_backend_blas_buffer *)src0->extra;
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            const int64_t i03 = i13/r3;
+            const int64_t i02 = i12/r2;
+
+            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
+            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);
+
+            // switch to dequantized F32 data
+            if (type != GGML_TYPE_F32) {
+                x = (float *)extra->data + i02*ne_plane + i03*ne02*ne_plane;
+            }
+
+            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne1, ne01, ne10,
+                        1.0f, y, ne10,
+                              x, ne00,
+                        0.0f, d, ne01);
+        }
+    }
+
+    GGML_UNUSED(ctx);
+}
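For reference, the cblas_sgemm call above maps the ggml tensors onto a single row-major GEMM: with CblasNoTrans on y and CblasTrans on x it computes d = src1 * src0^T, where M = ne1 (rows of the input), N = ne01 (rows of the weight), and K = ne10 = ne00 (the shared inner dimension). A standalone sketch of the same call pattern on toy matrices, assuming a CBLAS implementation is available as <cblas.h>:

#include <cblas.h>
#include <cstdio>

int main() {
    // m = ne1 (input rows), n = ne01 (weight rows), k = ne10 == ne00
    const int m = 2, n = 3, k = 2;
    const float y[m * k] = { 1, 2,
                             3, 4 };   // "src1": row-major m x k
    const float x[n * k] = { 1, 0,
                             0, 1,
                             1, 1 };   // "src0": row-major n x k
    float d[m * n] = { 0 };            // "dst":  row-major m x n

    // d = 1.0f * y * x^T + 0.0f * d, same argument layout as in mmf.cpp
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                m, n, k,
                1.0f, y, k,
                      x, k,
                0.0f, d, n);

    for (int i = 0; i < m; i++) {          // prints:  1 2 3
        for (int j = 0; j < n; j++) {      //          3 4 7
            printf("%g ", d[i * n + j]);
        }
        printf("\n");
    }
    return 0;
}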
ggml-blas/mmf.hpp (new file)
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "common.hpp"
+
+void ggml_blas_mul_mat_f(
+    const ggml_backend_blas_context * ctx,
+    const ggml_tensor * src0,
+    const ggml_tensor * src1,
+          ggml_tensor * dst);
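Since mmf.hpp includes common.hpp, kernel files such as mmf.cpp pick up the backend context types and the vendor-selected CBLAS header transitively, which is why mmf.cpp needs only ggml.h and mmf.hpp.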