ggml-blas: code cleanup

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
Aaron Teo 2025-12-11 21:27:13 +08:00
parent 19c8ec9964
commit 1926e07e1a
1 changed file with 31 additions and 32 deletions


@@ -27,13 +27,16 @@ struct ggml_backend_blas_context {
 #endif
 };
 
-static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+static void ggml_backend_blas_mul_mat(
+        ggml_backend_blas_context * ctx,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const enum ggml_type type = src0->type;
+    const ggml_type type = src0->type;
 
     GGML_ASSERT(ne0 == ne01);
     GGML_ASSERT(ne1 == ne11);
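
The change from const enum ggml_type to const ggml_type is safe because this backend is compiled as C++, where the enum (and struct) keyword in a declaration is optional and names the same type. A minimal sketch of the difference, using a stand-in enum rather than the real ggml header:

    // In C the elaborated form "enum ggml_type" is required; C++ also
    // accepts the bare type name. Stand-in enum for illustration only.
    enum ggml_type { GGML_TYPE_F32, GGML_TYPE_F16 };

    void demo() {
        const enum ggml_type a = GGML_TYPE_F32; // valid in C and C++
        const ggml_type      b = GGML_TYPE_F16; // valid in C++ only
        (void) a; (void) b; // silence unused-variable warnings
    }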
@@ -84,8 +87,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
     }
 #else
     for (int i = 1; i < n_threads; i++) {
-        const int64_t start =       i*ne01/n_threads;
-        const int64_t end   = (i + 1)*ne01/n_threads;
+        const int64_t start = (i + 0) * ne01/n_threads;
+        const int64_t end   = (i + 1) * ne01/n_threads;
         if (start < end) {
             ctx->tasks.push_back(std::async(std::launch::async, [=]() {
                 for (int64_t i01 = start; i01 < end; i01++) {
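
The (i + 0)/(i + 1) spelling only aligns the two bounds visually; the split of ne01 rows across n_threads workers is unchanged. A standalone sketch of that integer partition, with assumed example values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne01      = 10; // example row count
        const int     n_threads = 4;  // example worker count

        // Slice i covers [(i + 0)*ne01/n_threads, (i + 1)*ne01/n_threads).
        // Consecutive slices share their boundary, so the ranges tile
        // [0, ne01) exactly, with no gaps or overlaps.
        for (int i = 0; i < n_threads; i++) {
            const int64_t start = (i + 0)*ne01/n_threads;
            const int64_t end   = (i + 1)*ne01/n_threads;
            printf("slice %d: rows [%lld, %lld)\n", i, (long long) start, (long long) end);
        }
        return 0;
    }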
@@ -149,14 +152,17 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
     }
 }
 
-static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * ids  = dst->src[2];
+static void ggml_backend_blas_mul_mat_id(
+        ggml_backend_blas_context * ctx,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * ids  = dst->src[2];
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const enum ggml_type type = src0->type;
+    const ggml_type type = src0->type;
 
     GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == ggml_type_size(src1->type));
@@ -173,15 +179,10 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type  == GGML_TYPE_I32);
 
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    GGML_UNUSED(r2);
-    GGML_UNUSED(r3);
-
     const int64_t ne_plane      = ne01*ne00;
-    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
+    const size_t  desired_wsize = type == GGML_TYPE_F32
+                                      ? 0
+                                      : ne03*ne02*ne_plane*sizeof(float);
 
     if (ctx->work_size < desired_wsize) {
         ctx->work_data.reset(new char[desired_wsize]);
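
The reflowed ternary keeps the existing sizing policy: F32 weights need no conversion scratch, any other type needs one float per element across all ne03*ne02 planes of src0, and the buffer only ever grows. A reduced sketch of that grow-only pattern (blas_ctx is a stand-in for the fields of ggml_backend_blas_context used here):

    #include <cstddef>
    #include <memory>

    struct blas_ctx {
        std::unique_ptr<char[]> work_data;
        size_t                  work_size = 0;
    };

    // Reallocate only when the request exceeds current capacity, so
    // repeated graph evaluations reuse the same scratch buffer.
    static void ensure_work_size(blas_ctx * ctx, size_t desired_wsize) {
        if (ctx->work_size < desired_wsize) {
            ctx->work_data.reset(new char[desired_wsize]);
            ctx->work_size = desired_wsize;
        }
    }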
@@ -210,7 +211,7 @@ static void ggml_backend_blas_mul_mat_id(ggml_backend_blas_context * ctx, ggml_t
     }
 #else
     for (int i = 1; i < n_threads; i++) {
-        const int64_t start =       i*ne01/n_threads;
+        const int64_t start = (i + 0)*ne01/n_threads;
         const int64_t end   = (i + 1)*ne01/n_threads;
         if (start < end) {
             ctx->tasks.push_back(std::async(std::launch::async, [=]() {
@@ -555,15 +556,13 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const s
         case GGML_OP_MUL_MAT_ID:
             {
-                const struct ggml_tensor * src0 = op->src[0];
-                const struct ggml_tensor * src1 = op->src[1];
-                const struct ggml_tensor * src2 = op->src[2];
-
-                // GGML_LOG_INFO("%s: op=GGML_OP_MUL_MAT_ID src0_type=%s src1_type=%s src2_type=%s ne0=%lld ne1=%lld ne2=%lld ne3=%lld\n",
-                //     __func__, ggml_type_name(src0->type), ggml_type_name(src1->type), ggml_type_name(src2->type),
-                //     op->ne[0], op->ne[1], op->ne[2], op->ne[3]);
-                return src2->type == GGML_TYPE_I32;
+                const ggml_tensor * src0 = op->src[0];
+                const ggml_tensor * src1 = op->src[1];
+
+                return ggml_is_contiguous(src0) &&
+                       ggml_is_contiguous(src1) &&
+                       src1->type == GGML_TYPE_F32 &&
+                       (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
             }
         case GGML_OP_OUT_PROD:
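
The new return value tightens GGML_OP_MUL_MAT_ID support from "the ids tensor is I32" to the conditions the kernel can actually execute: contiguous operands, F32 activations, and weights that are either F32 or convertible to float. A sketch of the same predicate with the tensor queries flattened into booleans (parameter names are illustrative, not ggml API):

    // Mirrors the shape of the new supports_op check for GGML_OP_MUL_MAT_ID.
    static bool blas_supports_mul_mat_id(bool src0_contiguous, bool src1_contiguous,
                                         bool src1_is_f32,     bool src0_is_f32,
                                         bool src0_has_to_float) {
        return src0_contiguous &&
               src1_contiguous &&
               src1_is_f32 &&
               (src0_is_f32 || src0_has_to_float);
    }

The I32 requirement on the ids tensor is not lost: ggml_backend_blas_mul_mat_id still enforces it at runtime via GGML_ASSERT(ids->type == GGML_TYPE_I32), as seen in the hunk above.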