metal : fuse add + rms

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-09-18 16:28:49 +03:00
parent 703f9e32c4
commit 652d303b32
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
5 changed files with 124 additions and 13 deletions

View File

@ -1090,7 +1090,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse, bool fuse_prev) {
assert(op->op == GGML_OP_RMS_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
@ -1099,10 +1099,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_librar
char base[256];
char name[256];
const char * suffix = fuse_prev ? "_prev" : "";
switch (n_fuse) {
case 1: snprintf(base, 256, "kernel_rms_norm_f32"); break;
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32"); break;
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break;
case 1: snprintf(base, 256, "kernel_rms_norm%s_f32", suffix); break;
case 2: snprintf(base, 256, "kernel_rms_norm%s_mul_f32", suffix); break;
case 3: snprintf(base, 256, "kernel_rms_norm%s_mul_add_f32", suffix); break;
default: GGML_ABORT("fatal error");
}

View File

@ -123,7 +123,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse, bool fuse_prev);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);

View File

@ -193,6 +193,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
case GGML_OP_MUL:
case GGML_OP_DIV:
{
if (node->op == GGML_OP_ADD) {
const ggml_op fops[2] = { GGML_OP_ADD, GGML_OP_RMS_NORM };
if (idx + 1 < ctx->idx_end &&
nodes[1]->op == GGML_OP_RMS_NORM &&
ggml_can_fuse(gf, idx, fops, 2) &&
ggml_are_same_layout(node, node->src[0]) &&
ggml_are_same_layout(node, node->src[1])) {
n_fuse = ggml_metal_op_rms_norm(ctx, idx + 1, true) + 1;
break;
}
}
n_fuse = ggml_metal_op_bin(ctx, idx);
} break;
case GGML_OP_ADD_ID:
@ -268,7 +281,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
} break;
case GGML_OP_RMS_NORM:
{
n_fuse = ggml_metal_op_rms_norm(ctx, idx);
n_fuse = ggml_metal_op_rms_norm(ctx, idx, false);
} break;
case GGML_OP_L2_NORM:
{
@ -2346,7 +2359,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
return n_fuse;
}
int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx, bool fuse_prev) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
@ -2372,6 +2385,17 @@ int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_buffer_id bid_src0_0;
ggml_metal_buffer_id bid_src0_1;
if (fuse_prev) {
ggml_tensor * prev = ggml_graph_node(gf, idx - 1);
GGML_ASSERT(prev->op == GGML_OP_ADD);
bid_src0_0 = ggml_metal_get_buffer_id(prev->src[0]);
bid_src0_1 = ggml_metal_get_buffer_id(prev->src[1]);
}
ggml_metal_kargs_rms_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
@ -2459,7 +2483,7 @@ int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
}
}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse, fuse_prev);
int nth = 32; // SIMD width
@ -2474,10 +2498,18 @@ int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
if (fuse_prev) {
ggml_metal_encoder_set_buffer (enc, bid_src0_0, 1);
ggml_metal_encoder_set_buffer (enc, bid_src0_1, 2);
ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 3);
ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 4);
ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
} else {
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
}
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

View File

@ -60,7 +60,7 @@ int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rms_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rms_norm (ggml_metal_op_t ctx, int idx, bool fuse_prev);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);

View File

@ -2631,6 +2631,83 @@ template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fus
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
// TODO: same as kernel_rms_norm_fuse_impl but fuses the previous ADD operation
// can be simplified, but for now there does not seem to be benefit from this fusion so leaving it like this
// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
template <short F>
kernel void kernel_rms_norm_prev_fuse_impl(
constant ggml_metal_kargs_rms_norm & args,
device const char * src0_0, // src0 = src0_0 + src1_0
device const char * src0_1,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const float4 * x0 = (device const float4 *) (src0_0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const float4 * x1 = (device const float4 *) (src0_1 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
const float4 x = x0[i00] + x1[i00];
sumf += dot(x, x);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
const float4 x = x0[i00] + x1[i00];
if (F == 1) {
y[i00] = (x*scale);
}
if (F == 2) {
y[i00] = (x*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (x*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_rms_norm_prev_fuse_impl<1>) kernel_rms_norm_prev_fuse_t;
template [[host_name("kernel_rms_norm_prev_f32")]] kernel kernel_rms_norm_prev_fuse_t kernel_rms_norm_prev_fuse_impl<1>;
template [[host_name("kernel_rms_norm_prev_mul_f32")]] kernel kernel_rms_norm_prev_fuse_t kernel_rms_norm_prev_fuse_impl<2>;
template [[host_name("kernel_rms_norm_prev_mul_add_f32")]] kernel kernel_rms_norm_prev_fuse_t kernel_rms_norm_prev_fuse_impl<3>;
kernel void kernel_l2_norm_f32(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,