From 2c1a62aff0884dab2ad5f79509d3aaae1f6668b8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Feb 2026 11:33:21 +0200 Subject: [PATCH] metal : refactor bin kernels --- ggml/src/ggml-metal/ggml-metal-device.cpp | 75 ++++- ggml/src/ggml-metal/ggml-metal-device.h | 5 +- ggml/src/ggml-metal/ggml-metal-device.m | 4 +- ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 44 ++- ggml/src/ggml-metal/ggml-metal.metal | 385 +++++++++++++++------- 6 files changed, 355 insertions(+), 159 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 6af0dd88d5..236e2d8be5 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1392,34 +1392,73 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v GGML_UNUSED(op); } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( - ggml_metal_library_t lib, - ggml_op op, - int32_t n_fuse, - bool row) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { char base[256]; char name[256]; - const char * op_str = "undefined"; - switch (op) { - case GGML_OP_ADD: op_str = "add"; break; - case GGML_OP_SUB: op_str = "sub"; break; - case GGML_OP_MUL: op_str = "mul"; break; - case GGML_OP_DIV: op_str = "div"; break; + int op_num = -1; + + switch (op->op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; default: GGML_ABORT("fatal error"); }; - if (row) { - snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); - } else { - snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); - } + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t1_str = ggml_type_name(op->src[1]->type); + const char * t_str = ggml_type_name(op->type); - 
snprintf(name, 256, "%s", base); + const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, n_fuse); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.c4 = is_c4; + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) { + char base[256]; + char name[256]; + + int op_num = -1; + + switch (op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32"); + snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } return res; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 84dcec3083..6ad4fade01 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -53,6 +53,8 @@ struct ggml_metal_pipeline_with_params { int nr1; size_t smem; + + bool c4; }; int ggml_metal_pipeline_max_theads_per_threadgroup(struct 
ggml_metal_pipeline_with_params pipeline); @@ -134,7 +136,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse ); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c8e737d418..d683c75791 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -350,6 +350,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta /*.nr1 =*/ 0, /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, }; res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); @@ -366,6 +367,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_ /*.nr1 =*/ 0, /*.nsg =*/ 0, 
/*.smem =*/ 0, + /*.c4 =*/ false, }; [lib->lock lock]; @@ -1054,7 +1056,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: case GGML_OP_REPEAT: case GGML_OP_SCALE: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 7f73cb97bb..77bb403c15 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -80,6 +80,7 @@ #define FC_SSM_CONV 900 #define FC_SOLVE_TRI 1000 #define FC_COUNT_EQUAL 1100 +#define FC_BIN 1200 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e0ed6c7805..f0f0e1a0c2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -707,7 +707,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2895,8 +2895,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); - bool bcast_row = false; - ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); @@ -2990,18 +2988,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { struct ggml_metal_pipeline_with_params pipeline; - if (ggml_nelements(op->src[1]) == ne10 && 
ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true); - - bcast_row = true; - } else { - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false); - } + pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse); if (n_fuse > 1) { bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); @@ -3015,26 +3002,33 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } } + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne10 = ne10/4; + args.ne0 = ne0/4; + } + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, bid_src0, 1); ggml_metal_encoder_set_buffer (enc, bid_src1, 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 3); - if (bcast_row) { - const int64_t n = ggml_nelements(op)/4; + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - } else { - int nth = 32; + int nth = 1; - while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { - nth *= 2; - } - - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + while (4*nth < args.ne0 && nth < nth_max) { + nth *= 2; } + int nb = 1; + while (4*nb < ne01 && nth*nb < nth_max) { + nb *= 2; + } + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nb - 1)/nb, ne02, ne03, nth, nb, 1); + return n_fuse; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 612a42a1ea..dac292ae7b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -895,11 +895,156 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, subtraction, multiplication and division of 
two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -template -kernel void kernel_add_fuse_impl( +//// general-purpose kernel for addition, subtraction, multiplication and division of two tensors +//// pros: works for non-contiguous tensors, supports broadcast across all dims +//// cons: not very efficient +//template +//kernel void kernel_add_fuse_impl( +// constant ggml_metal_kargs_bin & args, +// device const char * src0, +// device const char * src1, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// ushort3 tpitg[[thread_position_in_threadgroup]], +// ushort3 ntg[[threads_per_threadgroup]]) { +// const int i03 = tgpig.z; +// const int i02 = tgpig.y; +// const int i01 = tgpig.x; +// +// const int i13 = i03%args.ne13; +// const int i12 = i02%args.ne12; +// const int i11 = i01%args.ne11; +// +// device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); +// device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); +// +// device const float * src1_ptr[F]; +// for (short j = 0; j < F; ++j) { +// src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); +// } +// +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// const int i10 = i0%args.ne10; +// +// float res = src0_ptr[i0]; +// +//#pragma unroll +// for (short j = 0; j < F; ++j) { +// res += src1_ptr[j][i10]; +// } +// +// dst_ptr[i0] = res; +// } +//} +// +//typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; +// +//template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; +//template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; +//template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; +//template [[host_name("kernel_add_fuse_4")]] 
kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; +//template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; +//template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; +//template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; +//template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; +// +//kernel void kernel_sub_fuse_1( +// constant ggml_metal_kargs_bin & args, +// device const char * src0, +// device const char * src1, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// ushort3 tpitg[[thread_position_in_threadgroup]], +// ushort3 ntg[[threads_per_threadgroup]]) { +// const int i03 = tgpig.z; +// const int i02 = tgpig.y; +// const int i01 = tgpig.x; +// +// const int i13 = i03%args.ne13; +// const int i12 = i02%args.ne12; +// const int i11 = i01%args.ne11; +// +// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; +// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; +// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; +// +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// const int i10 = i0%args.ne10; +// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); +// } +//} +// +//kernel void kernel_mul_fuse_1( +// constant ggml_metal_kargs_bin & args, +// device const char * src0, +// device const char * src1, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// ushort3 tpitg[[thread_position_in_threadgroup]], +// ushort3 ntg[[threads_per_threadgroup]]) { +// const int i03 = tgpig.z; +// const int i02 = tgpig.y; +// const int i01 = tgpig.x; +// +// const int i13 = i03%args.ne13; +// const int i12 = i02%args.ne12; +// const int i11 = i01%args.ne11; 
+// +// device const float * src0_ptr = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); +// device const float * src1_ptr = (device const float *)(src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]); +// device float * dst_ptr = (device float *)(dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); +// +// if (args.ne10 == 1) { +// const float x = src1_ptr[0]; +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// dst_ptr[i0] = src0_ptr[i0] * x; +// } +// } else { +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// const int i10 = i0 % args.ne10; +// dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; +// } +// } +//} +// +//kernel void kernel_div_fuse_1( +// constant ggml_metal_kargs_bin & args, +// device const char * src0, +// device const char * src1, +// device char * dst, +// uint3 tgpig[[threadgroup_position_in_grid]], +// ushort3 tpitg[[thread_position_in_threadgroup]], +// ushort3 ntg[[threads_per_threadgroup]]) { +// const int i03 = tgpig.z; +// const int i02 = tgpig.y; +// const int i01 = tgpig.x; +// +// const int i13 = i03%args.ne13; +// const int i12 = i02%args.ne12; +// const int i11 = i01%args.ne11; +// +// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; +// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; +// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; +// +// if (args.ne10 == 1) { +// const float x = 1.0f / *((device float *)(src1_ptr)); +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; +// } +// } else { +// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { +// const int i10 = i0%args.ne10; +// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); 
+// } +// } +//} + +constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; +constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; + +template <typename T0, typename T1, typename T> +kernel void kernel_bin_fuse_impl( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -907,138 +1052,150 @@ kernel void kernel_add_fuse_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { + // OP: 0 - add, 1 - sub, 2 - mul, 3 - div + const short OP = FC_bin_op; + const short F = FC_bin_f; + const int i03 = tgpig.z; const int i02 = tgpig.y; - const int i01 = tgpig.x; + const int i01 = tgpig.x*ntg.y + tpitg.y; + + if (i01 >= args.ne01) { + return; + } const int i13 = i03%args.ne13; const int i12 = i02%args.ne12; const int i11 = i01%args.ne11; - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); - device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); - device const float * src1_ptr[F]; - for (short j = 0; j < F; ++j) { - src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - } + if (F == 1) { + device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + if (args.ne10 == 1) { + T1 src1_cur = src1_ptr[0]; - float res = src0_ptr[i0]; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (OP == 0) { + dst_ptr[i0] = src0_ptr[i0] + src1_cur; + } -#pragma unroll - for (short j = 0; j < F; ++j) { - res += src1_ptr[j][i10]; - } + if (OP == 1) { + dst_ptr[i0] = 
src0_ptr[i0] - src1_cur; + } - dst_ptr[i0] = res; - } -} + if (OP == 2) { + dst_ptr[i0] = src0_ptr[i0] * src1_cur; + } -typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; + if (OP == 3) { + dst_ptr[i0] = src0_ptr[i0] / src1_cur; + } + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; -template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; -template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; -template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; -template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; -template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; -template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; -template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; -template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; + if (OP == 0) { + dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; + } -kernel void kernel_sub_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (OP == 1) { + dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10]; + } - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (OP == 2) { + dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr 
= dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); - } -} - -kernel void kernel_mul_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; - - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; - - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - - if (args.ne10 == 1) { - const float x = *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + if (OP == 3) { + dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10]; + } + } } } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + device const T1 * src1_ptr[8]; + FOR_UNROLL (short j = 0; j < F; ++j) { + src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } + + if (args.ne10 == 1) { + T1 src1_cur[8]; + FOR_UNROLL (short j = 0; j < F; ++j) { + src1_cur[j] = src1_ptr[j][0]; + } + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + T res = src0_ptr[i0]; + + if (OP 
== 0) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res += src1_cur[j]; + } + } + + if (OP == 1) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res -= src1_cur[j]; + } + } + + if (OP == 2) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res *= src1_cur[j]; + } + } + + if (OP == 3) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res /= src1_cur[j]; + } + } + + dst_ptr[i0] = res; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + + T res = src0_ptr[i0]; + + if (OP == 0) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res += src1_ptr[j][i10]; + } + } + + if (OP == 1) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res -= src1_ptr[j][i10]; + } + } + + if (OP == 2) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res *= src1_ptr[j][i10]; + } + } + + if (OP == 3) { + FOR_UNROLL (short j = 0; j < F; ++j) { + res /= src1_ptr[j][i10]; + } + } + + dst_ptr[i0] = res; + } } } } -kernel void kernel_div_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t; - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; - - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - - if (args.ne10 == 1) { - const float x = 1.0f / *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } - } else { - for (int i0 = 
tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); - } - } -} +template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>; +template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>; kernel void kernel_add_id( constant ggml_metal_kargs_add_id & args, @@ -1057,7 +1214,7 @@ kernel void kernel_add_id( const size_t nb1 = args.ne0 * sizeof(float); const size_t nb2 = args.ne1 * nb1; - device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); + device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);