From 652d303b3270fc8a4682fd4aa4181678c14e9f4a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 18 Sep 2025 16:28:49 +0300
Subject: [PATCH] metal : fuse add + rms
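
Fuse the preceding GGML_OP_ADD into the RMS_NORM kernels when the node
layouts allow it. Combined with the existing RMS_NORM + MUL (+ ADD)
fusion, the new "_prev" kernel variants compute, in pseudocode (w and c
being the optional mul/add operands):

    y = rms_norm(a + b)           // F == 1
    y = rms_norm(a + b)*w         // F == 2
    y = rms_norm(a + b)*w + c     // F == 3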

ggml-ci
---
 ggml/src/ggml-metal/ggml-metal-device.cpp | 10 +--
 ggml/src/ggml-metal/ggml-metal-device.h   |  2 +-
 ggml/src/ggml-metal/ggml-metal-ops.cpp    | 46 +++++++++++---
 ggml/src/ggml-metal/ggml-metal-ops.h      |  2 +-
 ggml/src/ggml-metal/ggml-metal.metal      | 77 +++++++++++++++++++++++
 5 files changed, 124 insertions(+), 13 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index fe015afc54..11958f002d 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1090,7 +1090,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse, bool fuse_prev) {
     assert(op->op == GGML_OP_RMS_NORM);
 
     GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
@@ -1099,10 +1099,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_librar
     char base[256];
     char name[256];
 
+    const char * suffix = fuse_prev ? "_prev" : "";
+
     switch (n_fuse) {
-        case 1: snprintf(base, 256, "kernel_rms_norm_f32"); break;
-        case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32"); break;
-        case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break;
+        case 1: snprintf(base, 256, "kernel_rms_norm%s_f32", suffix); break;
+        case 2: snprintf(base, 256, "kernel_rms_norm%s_mul_f32", suffix); break;
+        case 3: snprintf(base, 256, "kernel_rms_norm%s_mul_add_f32", suffix); break;
         default: GGML_ABORT("fatal error");
     }
 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 044d6953f6..2ea01ff251 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -123,7 +123,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id         (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm          (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm          (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse, bool fuse_prev);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 04665b3d6d..d042101f72 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -193,6 +193,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
         case GGML_OP_MUL:
         case GGML_OP_DIV:
             {
+                if (node->op == GGML_OP_ADD) {
+                    const ggml_op fops[2] = { GGML_OP_ADD, GGML_OP_RMS_NORM };
+
+                    if (idx + 1 < ctx->idx_end &&
+                        nodes[1]->op == GGML_OP_RMS_NORM &&
+                        ggml_can_fuse(gf, idx, fops, 2) &&
+                        ggml_are_same_layout(node, node->src[0]) &&
+                        ggml_are_same_layout(node, node->src[1])) {
+                        n_fuse = ggml_metal_op_rms_norm(ctx, idx + 1, true) + 1;
+                        break;
+                    }
+                }
+
                 n_fuse = ggml_metal_op_bin(ctx, idx);
             } break;
         case GGML_OP_ADD_ID:
@@ -268,7 +281,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             } break;
         case GGML_OP_RMS_NORM:
            {
-                n_fuse = ggml_metal_op_rms_norm(ctx, idx);
+                n_fuse = ggml_metal_op_rms_norm(ctx, idx, false);
            } break;
        case GGML_OP_L2_NORM:
            {
@@ -2346,7 +2359,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
     return n_fuse;
 }
 
-int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
+int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx, bool fuse_prev) {
     ggml_cgraph * gf = ctx->gf;
 
     ggml_tensor * op = ggml_graph_node(gf, idx);
@@ -2372,6 +2385,17 @@ int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
     ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
     ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
 
+    ggml_metal_buffer_id bid_src0_0;
+    ggml_metal_buffer_id bid_src0_1;
+
+    if (fuse_prev) {
+        ggml_tensor * prev = ggml_graph_node(gf, idx - 1);
+        GGML_ASSERT(prev->op == GGML_OP_ADD);
+
+        bid_src0_0 = ggml_metal_get_buffer_id(prev->src[0]);
+        bid_src0_1 = ggml_metal_get_buffer_id(prev->src[1]);
+    }
+
     ggml_metal_kargs_rms_norm args = {
         /*.ne00   =*/ ne00,
         /*.ne00_4 =*/ ne00/4,
@@ -2459,7 +2483,7 @@ int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
         }
     }
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse);
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse, fuse_prev);
 
     int nth = 32; // SIMD width
 
@@ -2474,10 +2498,18 @@ int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, bid_src0,    1);
-    ggml_metal_encoder_set_buffer  (enc, bid_fuse[0], 2);
-    ggml_metal_encoder_set_buffer  (enc, bid_fuse[1], 3);
-    ggml_metal_encoder_set_buffer  (enc, bid_dst,     4);
+    if (fuse_prev) {
+        ggml_metal_encoder_set_buffer  (enc, bid_src0_0,  1);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0_1,  2);
+        ggml_metal_encoder_set_buffer  (enc, bid_fuse[0], 3);
+        ggml_metal_encoder_set_buffer  (enc, bid_fuse[1], 4);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,     5);
+    } else {
+        ggml_metal_encoder_set_buffer  (enc, bid_src0,    1);
+        ggml_metal_encoder_set_buffer  (enc, bid_fuse[0], 2);
+        ggml_metal_encoder_set_buffer  (enc, bid_fuse[1], 3);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,     4);
+    }
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index b620de164d..d073e31d62 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -60,7 +60,7 @@ int ggml_metal_op_mul_mat_id     (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_add_id         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_bin            (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_rms_norm       (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_rms_norm       (ggml_metal_op_t ctx, int idx, bool fuse_prev);
 int ggml_metal_op_l2_norm        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_group_norm     (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_norm           (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index c7d97ba70b..bf28f1ee4b 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2631,6 +2631,83 @@ template [[host_name("kernel_rms_norm_f32")]]         kernel kernel_rms_norm_fus
 template [[host_name("kernel_rms_norm_mul_f32")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
 template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
 
+// TODO: same as kernel_rms_norm_fuse_impl, but fuses the previous ADD operation
+// can be simplified, but for now there does not seem to be a benefit from this fusion, so leaving it like this
+// F == 1 : rms_norm (no fuse)
+// F == 2 : rms_norm + mul
+// F == 3 : rms_norm + mul + add
+template <short F>
+kernel void kernel_rms_norm_prev_fuse_impl(
+        constant ggml_metal_kargs_rms_norm & args,
+        device const char * src0_0, // src0 = src0_0 + src0_1
+        device const char * src0_1,
+        device const char * src1_0,
+        device const char * src1_1,
+        device char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    const int i01 = tgpig.x;
+    const int i02 = tgpig.y;
+    const int i03 = tgpig.z;
+
+    device const float4 * x0 = (device const float4 *) (src0_0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+    device const float4 * x1 = (device const float4 *) (src0_1 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+    device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+    device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
+
+    float sumf = 0.0f;
+
+    // parallel sum
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+        const float4 x = x0[i00] + x1[i00];
+        sumf += dot(x, x);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float mean = sumf/args.ne00;
+    const float scale = 1.0f/sqrt(mean + args.eps);
+
+    device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+        const float4 x = x0[i00] + x1[i00];
+        if (F == 1) {
+            y[i00] = (x*scale);
+        }
+        if (F == 2) {
+            y[i00] = (x*scale)*f0[i00];
+        }
+        if (F == 3) {
+            y[i00] = (x*scale)*f0[i00] + f1[i00];
+        }
+    }
+}
+
+typedef decltype(kernel_rms_norm_prev_fuse_impl<1>) kernel_rms_norm_prev_fuse_t;
+
+template [[host_name("kernel_rms_norm_prev_f32")]]         kernel kernel_rms_norm_prev_fuse_t kernel_rms_norm_prev_fuse_impl<1>;
+template [[host_name("kernel_rms_norm_prev_mul_f32")]]     kernel kernel_rms_norm_prev_fuse_t kernel_rms_norm_prev_fuse_impl<2>;
+template [[host_name("kernel_rms_norm_prev_mul_add_f32")]] kernel kernel_rms_norm_prev_fuse_t kernel_rms_norm_prev_fuse_impl<3>;
+
 kernel void kernel_l2_norm_f32(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,