From 2c1a62aff0884dab2ad5f79509d3aaae1f6668b8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Feb 2026 11:33:21 +0200 Subject: [PATCH 1/2] metal : refactor bin kernels --- ggml/src/ggml-metal/ggml-metal-device.cpp | 75 ++++- ggml/src/ggml-metal/ggml-metal-device.h | 5 +- ggml/src/ggml-metal/ggml-metal-device.m | 4 +- ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 44 ++- ggml/src/ggml-metal/ggml-metal.metal | 385 +++++++++++++++------- 6 files changed, 355 insertions(+), 159 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 6af0dd88d5..236e2d8be5 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1392,34 +1392,73 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v GGML_UNUSED(op); } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( - ggml_metal_library_t lib, - ggml_op op, - int32_t n_fuse, - bool row) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { char base[256]; char name[256]; - const char * op_str = "undefined"; - switch (op) { - case GGML_OP_ADD: op_str = "add"; break; - case GGML_OP_SUB: op_str = "sub"; break; - case GGML_OP_MUL: op_str = "mul"; break; - case GGML_OP_DIV: op_str = "div"; break; + int op_num = -1; + + switch (op->op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; default: GGML_ABORT("fatal error"); }; - if (row) { - snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); - } else { - snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); - } + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t1_str = ggml_type_name(op->src[1]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(name, 256, "%s", base); + const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, n_fuse); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.c4 = is_c4; + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) { + char base[256]; + char name[256]; + + int op_num = -1; + + switch (op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32"); + snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int32(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int32(cv, 1, FC_BIN + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } return res; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 84dcec3083..6ad4fade01 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -53,6 +53,8 @@ struct ggml_metal_pipeline_with_params { int nr1; size_t smem; + + bool c4; }; int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); @@ -134,7 +136,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse ); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c8e737d418..d683c75791 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -350,6 +350,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta /*.nr1 =*/ 0, /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, }; res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); @@ -366,6 +367,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_ /*.nr1 =*/ 0, /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, }; [lib->lock lock]; @@ -1054,7 +1056,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: case GGML_OP_REPEAT: case GGML_OP_SCALE: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 7f73cb97bb..77bb403c15 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -80,6 +80,7 @@ #define FC_SSM_CONV 900 #define FC_SOLVE_TRI 1000 #define FC_COUNT_EQUAL 1100 +#define FC_BIN 1200 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e0ed6c7805..f0f0e1a0c2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -707,7 +707,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2895,8 +2895,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); - bool bcast_row = false; - ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); @@ -2990,18 +2988,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { struct ggml_metal_pipeline_with_params pipeline; - if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true); - - bcast_row = true; - } else { - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false); - } + pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse); if (n_fuse > 1) { bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); @@ -3015,26 +3002,33 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } } + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne10 = ne10/4; + args.ne0 = ne0/4; + } + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, bid_src0, 1); ggml_metal_encoder_set_buffer (enc, bid_src1, 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 3); - if (bcast_row) { - const int64_t n = ggml_nelements(op)/4; + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - } else { - int nth = 32; + int nth = 1; - while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { - nth *= 2; - } - - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + while (4*nth < args.ne0 && nth < nth_max) { + nth *= 2; } + int nb = 1; + while (4*nb < ne01 && nth*nb < nth_max) { + nb *= 2; + } + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nb - 1)/nb, ne02, ne03, nth, nb, 1); + return n_fuse; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 612a42a1ea..dac292ae7b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -895,11 +895,156 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, subtraction, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -template -kernel void kernel_add_fuse_impl( +//// general-purpose kernel for addition, subtraction, multiplication and division of two tensors +//// pros: works for non-contiguous tensors, supports broadcast across all dims +//// cons: not very efficient +//template +//kernel void kernel_add_fuse_impl( +// constant ggml_metal_kargs_bin & args, +// device const char * src0, +// device const char * src1, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// ushort3 tpitg[[thread_position_in_threadgroup]], +// ushort3 ntg[[threads_per_threadgroup]]) { +// const int i03 = tgpig.z; +// const int i02 = tgpig.y; +// const int i01 = tgpig.x; +// +// const int i13 = i03%args.ne13; +// const int i12 = i02%args.ne12; +// const int i11 = i01%args.ne11; +// +// device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); +// device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); +// +// device const float * src1_ptr[F]; +// for (short j = 0; j < F; ++j) { +// src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); +// } +// +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// const int i10 = i0%args.ne10; +// +// float res = src0_ptr[i0]; +// +//#pragma unroll +// for (short j = 0; j < F; ++j) { +// res += src1_ptr[j][i10]; +// } +// +// dst_ptr[i0] = res; +// } +//} +// +//typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; +// +//template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; +//template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; +//template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; +//template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; +//template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; +//template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; +//template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; +//template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; +// +//kernel void kernel_sub_fuse_1( +// constant ggml_metal_kargs_bin & args, +// device const char * src0, +// device const char * src1, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// ushort3 tpitg[[thread_position_in_threadgroup]], +// ushort3 ntg[[threads_per_threadgroup]]) { +// const int i03 = tgpig.z; +// const int i02 = tgpig.y; +// const int i01 = tgpig.x; +// +// const int i13 = i03%args.ne13; +// const int i12 = i02%args.ne12; +// const int i11 = i01%args.ne11; +// +// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; +// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; +// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; +// +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// const int i10 = i0%args.ne10; +// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); +// } +//} +// +//kernel void kernel_mul_fuse_1( +// constant ggml_metal_kargs_bin & args, +// device const char * src0, +// device const char * src1, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// ushort3 tpitg[[thread_position_in_threadgroup]], +// ushort3 ntg[[threads_per_threadgroup]]) { +// const int i03 = tgpig.z; +// const int i02 = tgpig.y; +// const int i01 = tgpig.x; +// +// const int i13 = i03%args.ne13; +// const int i12 = i02%args.ne12; +// const int i11 = i01%args.ne11; +// +// device const float * src0_ptr = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); +// device const float * src1_ptr = (device const float *)(src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]); +// device float * dst_ptr = (device float *)(dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); +// +// if (args.ne10 == 1) { +// const float x = src1_ptr[0]; +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// dst_ptr[i0] = src0_ptr[i0] * x; +// } +// } else { +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// const int i10 = i0 % args.ne10; +// dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; +// } +// } +//} +// +//kernel void kernel_div_fuse_1( +// constant ggml_metal_kargs_bin & args, +// device const char * src0, +// device const char * src1, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// ushort3 tpitg[[thread_position_in_threadgroup]], +// ushort3 ntg[[threads_per_threadgroup]]) { +// const int i03 = tgpig.z; +// const int i02 = tgpig.y; +// const int i01 = tgpig.x; +// +// const int i13 = i03%args.ne13; +// const int i12 = i02%args.ne12; +// const int i11 = i01%args.ne11; +// +// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; +// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; +// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; +// +// if (args.ne10 == 1) { +// const float x = 1.0f / *((device float *)(src1_ptr)); +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; +// } +// } else { +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// const int i10 = i0%args.ne10; +// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); +// } +// } +//} + +constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; +constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; + +template +kernel void kernel_bin_fuse_impl( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -907,138 +1052,150 @@ kernel void kernel_add_fuse_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { + // OP: 0 - add, 1 - sub, 2 - mul, 3 - div + const short OP = FC_bin_op; + const short F = FC_bin_f; + const int i03 = tgpig.z; const int i02 = tgpig.y; - const int i01 = tgpig.x; + const int i01 = tgpig.x*ntg.y + tpitg.y; + + if (i01 >= args.ne01) { + return; + } const int i13 = i03%args.ne13; const int i12 = i02%args.ne12; const int i11 = i01%args.ne11; - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); - device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); - device const float * src1_ptr[F]; - for (short j = 0; j < F; ++j) { - src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - } + if (F == 1) { + device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + if (args.ne10 == 1) { + T1 src1_cur = src1_ptr[0]; - float res = src0_ptr[i0]; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (OP == 0) { + dst_ptr[i0] = src0_ptr[i0] + src1_cur; + } -#pragma unroll - for (short j = 0; j < F; ++j) { - res += src1_ptr[j][i10]; - } + if (OP == 1) { + dst_ptr[i0] = src0_ptr[i0] - src1_cur; + } - dst_ptr[i0] = res; - } -} + if (OP == 2) { + dst_ptr[i0] = src0_ptr[i0] * src1_cur; + } -typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; + if (OP == 3) { + dst_ptr[i0] = src0_ptr[i0] / src1_cur; + } + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; -template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; -template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; -template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; -template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; -template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; -template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; -template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; -template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; + if (OP == 0) { + dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; + } -kernel void kernel_sub_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (OP == 1) { + dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10]; + } - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (OP == 2) { + dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); - } -} - -kernel void kernel_mul_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; - - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; - - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - - if (args.ne10 == 1) { - const float x = *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + if (OP == 3) { + dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10]; + } + } } } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + device const T1 * src1_ptr[8]; + FOR_UNROLL (short j = 0; j < F; ++j) { + src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } + + if (args.ne10 == 1) { + T1 src1_cur[8]; + FOR_UNROLL (short j = 0; j < F; ++j) { + src1_cur[j] = src1_ptr[j][0]; + } + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + T res = src0_ptr[i0]; + + if (OP == 0) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res += src1_cur[j]; + } + } + + if (OP == 1) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res -= src1_cur[j]; + } + } + + if (OP == 2) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res *= src1_cur[j]; + } + } + + if (OP == 3) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res /= src1_cur[j]; + } + } + + dst_ptr[i0] = res; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + + T res = src0_ptr[i0]; + + if (OP == 0) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res += src1_ptr[j][i10]; + } + } + + if (OP == 1) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res -= src1_ptr[j][i10]; + } + } + + if (OP == 2) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res *= src1_ptr[j][i10]; + } + } + + if (OP == 3) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res /= src1_ptr[j][i10]; + } + } + + dst_ptr[i0] = res; + } } } } -kernel void kernel_div_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t; - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; - - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - - if (args.ne10 == 1) { - const float x = 1.0f / *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } - } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); - } - } -} +template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; +template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; kernel void kernel_add_id( constant ggml_metal_kargs_add_id & args, @@ -1057,7 +1214,7 @@ kernel void kernel_add_id( const size_t nb1 = args.ne0 * sizeof(float); const size_t nb2 = args.ne1 * nb1; - device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); + device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11); From 21f6a3734874656bfb60daf923012a9dfe870089 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Feb 2026 14:09:42 +0200 Subject: [PATCH 2/2] cont --- ggml/src/ggml-metal/ggml-metal-device.cpp | 8 +- ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 6 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 23 +- ggml/src/ggml-metal/ggml-metal.metal | 460 +++++----------------- 5 files changed, 115 insertions(+), 383 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 236e2d8be5..f943fd07bf 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1412,8 +1412,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536; + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); - snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, n_fuse); + snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -1421,13 +1423,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); ggml_metal_cv_free(cv); } - res.c4 = is_c4; + res.c4 = is_c4; + res.cnt = is_rb; return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 6ad4fade01..93d7f6a216 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -55,6 +55,7 @@ struct ggml_metal_pipeline_with_params { size_t smem; bool c4; + bool cnt; }; int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index d683c75791..891d70c85a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -346,11 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, /*.c4 =*/ false, + /*.cnt =*/ false, }; res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); @@ -363,11 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, /*.c4 =*/ false, + /*.cnt =*/ false, }; [lib->lock lock]; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index f0f0e1a0c2..dbf25433c2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3014,21 +3014,22 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_src1, 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 3); - const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); - int nth = 1; + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + } else { + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - while (4*nth < args.ne0 && nth < nth_max) { - nth *= 2; + int nth = 1; + + while (2*nth < args.ne0 && nth < nth_max) { + nth *= 2; + } + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); } - int nb = 1; - while (4*nb < ne01 && nth*nb < nth_max) { - nb *= 2; - } - - ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nb - 1)/nb, ne02, ne03, nth, nb, 1); - return n_fuse; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index dac292ae7b..35cc3bbdfd 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -895,153 +895,10 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -//// general-purpose kernel for addition, subtraction, multiplication and division of two tensors -//// pros: works for non-contiguous tensors, supports broadcast across all dims -//// cons: not very efficient -//template -//kernel void kernel_add_fuse_impl( -// constant ggml_metal_kargs_bin & args, -// device const char * src0, -// device const char * src1, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// ushort3 tpitg[[thread_position_in_threadgroup]], -// ushort3 ntg[[threads_per_threadgroup]]) { -// const int i03 = tgpig.z; -// const int i02 = tgpig.y; -// const int i01 = tgpig.x; -// -// const int i13 = i03%args.ne13; -// const int i12 = i02%args.ne12; -// const int i11 = i01%args.ne11; -// -// device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); -// device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); -// -// device const float * src1_ptr[F]; -// for (short j = 0; j < F; ++j) { -// src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); -// } -// -// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { -// const int i10 = i0%args.ne10; -// -// float res = src0_ptr[i0]; -// -//#pragma unroll -// for (short j = 0; j < F; ++j) { -// res += src1_ptr[j][i10]; -// } -// -// dst_ptr[i0] = res; -// } -//} -// -//typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; -// -//template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; -//template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; -//template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; -//template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; -//template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; -//template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; -//template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; -//template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; -// -//kernel void kernel_sub_fuse_1( -// constant ggml_metal_kargs_bin & args, -// device const char * src0, -// device const char * src1, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// ushort3 tpitg[[thread_position_in_threadgroup]], -// ushort3 ntg[[threads_per_threadgroup]]) { -// const int i03 = tgpig.z; -// const int i02 = tgpig.y; -// const int i01 = tgpig.x; -// -// const int i13 = i03%args.ne13; -// const int i12 = i02%args.ne12; -// const int i11 = i01%args.ne11; -// -// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; -// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; -// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; -// -// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { -// const int i10 = i0%args.ne10; -// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); -// } -//} -// -//kernel void kernel_mul_fuse_1( -// constant ggml_metal_kargs_bin & args, -// device const char * src0, -// device const char * src1, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// ushort3 tpitg[[thread_position_in_threadgroup]], -// ushort3 ntg[[threads_per_threadgroup]]) { -// const int i03 = tgpig.z; -// const int i02 = tgpig.y; -// const int i01 = tgpig.x; -// -// const int i13 = i03%args.ne13; -// const int i12 = i02%args.ne12; -// const int i11 = i01%args.ne11; -// -// device const float * src0_ptr = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); -// device const float * src1_ptr = (device const float *)(src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]); -// device float * dst_ptr = (device float *)(dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); -// -// if (args.ne10 == 1) { -// const float x = src1_ptr[0]; -// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { -// dst_ptr[i0] = src0_ptr[i0] * x; -// } -// } else { -// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { -// const int i10 = i0 % args.ne10; -// dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; -// } -// } -//} -// -//kernel void kernel_div_fuse_1( -// constant ggml_metal_kargs_bin & args, -// device const char * src0, -// device const char * src1, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// ushort3 tpitg[[thread_position_in_threadgroup]], -// ushort3 ntg[[threads_per_threadgroup]]) { -// const int i03 = tgpig.z; -// const int i02 = tgpig.y; -// const int i01 = tgpig.x; -// -// const int i13 = i03%args.ne13; -// const int i12 = i02%args.ne12; -// const int i11 = i01%args.ne11; -// -// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; -// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; -// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; -// -// if (args.ne10 == 1) { -// const float x = 1.0f / *((device float *)(src1_ptr)); -// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { -// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; -// } -// } else { -// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { -// const int i10 = i0%args.ne10; -// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); -// } -// } -//} - +// OP: 0 - add, 1 - sub, 2 - mul, 3 - div constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; +constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]]; template kernel void kernel_bin_fuse_impl( @@ -1052,136 +909,134 @@ kernel void kernel_bin_fuse_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - // OP: 0 - add, 1 - sub, 2 - mul, 3 - div - const short OP = FC_bin_op; - const short F = FC_bin_f; +#define FC_OP FC_bin_op +#define FC_F FC_bin_f +#define FC_RB FC_bin_rb - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x*ntg.y + tpitg.y; + if (FC_RB) { + // row broadcast + const uint i0 = tgpig.x; + const uint i1 = i0%args.ne10; - if (i01 >= args.ne01) { - return; - } + device const T0 * src0_row = (device const T0 *) (src0); + device T * dst_row = (device T *) (dst); - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_F == 1) { + device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]); - device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); - device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + if (FC_OP == 0) { + dst_row[i0] = src0_row[i0] + src1_row[i1]; + } - if (F == 1) { - device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + if (FC_OP == 1) { + dst_row[i0] = src0_row[i0] - src1_row[i1]; + } - if (args.ne10 == 1) { - T1 src1_cur = src1_ptr[0]; + if (FC_OP == 2) { + dst_row[i0] = src0_row[i0] * src1_row[i1]; + } - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - if (OP == 0) { - dst_ptr[i0] = src0_ptr[i0] + src1_cur; - } - - if (OP == 1) { - dst_ptr[i0] = src0_ptr[i0] - src1_cur; - } - - if (OP == 2) { - dst_ptr[i0] = src0_ptr[i0] * src1_cur; - } - - if (OP == 3) { - dst_ptr[i0] = src0_ptr[i0] / src1_cur; - } + if (FC_OP == 3) { + dst_row[i0] = src0_row[i0] / src1_row[i1]; } } else { + T0 res = src0_row[i0]; + + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + dst_row[i0] = res; + } + } else { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + if (i01 >= args.ne01) { + return; + } + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + + if (FC_F == 1) { + device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { const int i10 = i0%args.ne10; - if (OP == 0) { + if (FC_OP == 0) { dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; } - if (OP == 1) { + if (FC_OP == 1) { dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10]; } - if (OP == 2) { + if (FC_OP == 2) { dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; } - if (OP == 3) { + if (FC_OP == 3) { dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10]; } } - } - } else { - device const T1 * src1_ptr[8]; - FOR_UNROLL (short j = 0; j < F; ++j) { - src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - } - - if (args.ne10 == 1) { - T1 src1_cur[8]; - FOR_UNROLL (short j = 0; j < F; ++j) { - src1_cur[j] = src1_ptr[j][0]; - } - - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - T res = src0_ptr[i0]; - - if (OP == 0) { - FOR_UNROLL (short j = 0; j < F; ++j) { - res += src1_cur[j]; - } - } - - if (OP == 1) { - FOR_UNROLL (short j = 0; j < F; ++j) { - res -= src1_cur[j]; - } - } - - if (OP == 2) { - FOR_UNROLL (short j = 0; j < F; ++j) { - res *= src1_cur[j]; - } - } - - if (OP == 3) { - FOR_UNROLL (short j = 0; j < F; ++j) { - res /= src1_cur[j]; - } - } - - dst_ptr[i0] = res; - } } else { + device const T1 * src1_ptr[8]; + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { const int i10 = i0%args.ne10; T res = src0_ptr[i0]; - if (OP == 0) { - FOR_UNROLL (short j = 0; j < F; ++j) { + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { res += src1_ptr[j][i10]; } } - if (OP == 1) { - FOR_UNROLL (short j = 0; j < F; ++j) { + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { res -= src1_ptr[j][i10]; } } - if (OP == 2) { - FOR_UNROLL (short j = 0; j < F; ++j) { + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { res *= src1_ptr[j][i10]; } } - if (OP == 3) { - FOR_UNROLL (short j = 0; j < F; ++j) { + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { res /= src1_ptr[j][i10]; } } @@ -1190,6 +1045,10 @@ kernel void kernel_bin_fuse_impl( } } } + +#undef FC_OP +#undef FC_F +#undef FC_RB } typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t; @@ -1255,141 +1114,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; -// assumption: src1 is a row -// broadcast src1 into src0 -template -kernel void kernel_add_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res += ((device const float4 *) (src1 + args.o1[j]))[i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; - -template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; -template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; -template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; -template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; -template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>; -template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>; -template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>; -template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>; - -template -kernel void kernel_sub_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res -= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; - -template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; - -template -kernel void kernel_mul_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res *= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; - -template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; - -template -kernel void kernel_div_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res /= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; - -template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; - kernel void kernel_scale_f32( constant ggml_metal_kargs_scale & args, device const float * src0,