From e4cff0956b0dd7196cd8d4717daf02329cdeb870 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 12 Mar 2026 09:42:40 +0200 Subject: [PATCH] metal : avoid divisions in bin kernel (#20426) * metal : avoid modulus in bin kernel when not broadcasting * metal : fix capture_started flag --- ggml/src/ggml-metal/ggml-metal-context.m | 4 +++- ggml/src/ggml-metal/ggml-metal-device.cpp | 4 +++- ggml/src/ggml-metal/ggml-metal-ops.cpp | 4 +--- ggml/src/ggml-metal/ggml-metal.metal | 11 +++++++---- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 855fd1adae..32d97cd5d0 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -554,7 +554,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * // enter here only when capturing in order to wait for all computation to finish // otherwise, we leave the graph to compute asynchronously - if (!use_capture && ctx->capture_started) { + if (use_capture && ctx->capture_started) { // wait for completion and check status of each command buffer // needed to detect if the device ran out-of-memory for example (#1881) { @@ -606,6 +606,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * [ctx->capture_scope endScope]; [[MTLCaptureManager sharedCaptureManager] stopCapture]; + + ctx->capture_started = false; } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 15ae2e517d..72ad876d5e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1470,10 +1470,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0]; const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536; snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); - snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb); + snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -1482,6 +1483,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2); + ggml_metal_cv_set_bool (cv, is_cb, FC_BIN + 3); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 306dbcf366..c0bcad392b 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3180,9 +3180,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_dst, 3); if (pipeline.cnt) { - const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1); } else { const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0b77d5349b..24a3092af2 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1111,6 +1111,7 @@ template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_un constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]]; +constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]]; template kernel void kernel_bin_fuse_impl( @@ -1124,11 +1125,12 @@ kernel void kernel_bin_fuse_impl( #define FC_OP FC_bin_op #define FC_F FC_bin_f #define FC_RB FC_bin_rb +#define FC_CB FC_bin_cb if (FC_RB) { // row broadcast - const uint i0 = tgpig.x; - const uint i1 = i0%args.ne10; + const uint i0 = tgpig.y*args.ne00 + tgpig.x; + const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x; device const T0 * src0_row = (device const T0 *) (src0); device T * dst_row = (device T *) (dst); @@ -1200,7 +1202,7 @@ kernel void kernel_bin_fuse_impl( device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + const int i10 = FC_CB ? i0%args.ne10 : i0; if (FC_OP == 0) { dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; @@ -1225,7 +1227,7 @@ kernel void kernel_bin_fuse_impl( } for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + const int i10 = FC_CB ? i0%args.ne10 : i0; T res = src0_ptr[i0]; @@ -1261,6 +1263,7 @@ kernel void kernel_bin_fuse_impl( #undef FC_OP #undef FC_F #undef FC_RB +#undef FC_CB } typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t;