metal : fuse NORM + MUL + ADD, support non-multiples of 4 (#16220)

* metal : fuse NORM + MUL + ADD

* metal : support norms of non-multiple of 4

* cont : fix comment [no ci]
Georgi Gerganov 2025-09-25 11:30:16 +03:00 committed by GitHub
parent 4ea00794b8
commit dfcd53f7ec
9 changed files with 206 additions and 232 deletions
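For reference, the fused NORM/RMS_NORM + MUL + ADD path added here computes the normalization and the trailing elementwise multiply and add in a single pass over each row. The C++ sketch below is an illustrative host-side reference of that math, not the Metal kernel itself; the function name and the row-wise layout are assumptions made only for this example. For RMS_NORM the mean subtraction is dropped and the scale uses mean(x*x) instead of the variance.

#include <cmath>
#include <cstddef>

// Reference for the fused NORM + MUL + ADD math (illustrative only, not the Metal kernel).
static void norm_mul_add_row(const float * x, const float * b, const float * c,
                             float * y, size_t n, float eps) {
    float mean = 0.0f;
    for (size_t i = 0; i < n; ++i) mean += x[i];
    mean /= n;

    float var = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        const float d = x[i] - mean;
        var += d*d;
    }
    var /= n;

    const float scale = 1.0f/std::sqrt(var + eps);
    for (size_t i = 0; i < n; ++i) {
        y[i] = (x[i] - mean)*scale*b[i] + c[i]; // MUL by b and ADD of c fused into the same loop
    }
}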


@@ -383,6 +383,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
// fuse only ops that start with these operations
// can be expanded when needed
if (node.op() == GGML_OP_ADD ||
node.op() == GGML_OP_NORM ||
node.op() == GGML_OP_RMS_NORM) {
ops[0] = node.op();
@@ -392,6 +393,7 @@ void ggml_graph_optimize(ggml_cgraph * gf) {
// can be expanded when needed
if (gf->nodes[f]->op != GGML_OP_ADD &&
gf->nodes[f]->op != GGML_OP_MUL &&
gf->nodes[f]->op != GGML_OP_NORM &&
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
break;
}
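This hunk lets a fusion chain start at GGML_OP_NORM in addition to GGML_OP_RMS_NORM. As an illustration of the op pattern that now qualifies, here is a sketch assuming the public ggml C API; the helper name is hypothetical:

#include "ggml.h"

// Builds the NORM -> MUL -> ADD chain that the Metal backend can now fuse
// into a single kernel dispatch (the same chain was already fused for RMS_NORM).
static struct ggml_tensor * build_norm_mul_add(struct ggml_context * ctx,
                                               struct ggml_tensor * x,
                                               struct ggml_tensor * w,
                                               struct ggml_tensor * b,
                                               float eps) {
    struct ggml_tensor * cur = ggml_norm(ctx, x, eps); // chain starts at GGML_OP_NORM
    cur = ggml_mul(ctx, cur, w);                       // fused: NORM + MUL
    cur = ggml_add(ctx, cur, b);                       // fused: NORM + MUL + ADD
    return cur;
}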


@@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
assert(op->op == GGML_OP_RMS_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
char base[256];
char name[256];
switch (n_fuse) {
case 1: snprintf(base, 256, "kernel_rms_norm_f32"); break;
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32"); break;
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break;
default: GGML_ABORT("fatal error");
}
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
if (res) {
return res;
}
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);
@@ -1167,16 +1137,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
return res;
}
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_NORM);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_norm_f32");
const char * suffix = "";
if (op->ne[0] % 4 == 0) {
suffix = "_4";
}
switch (op->op) {
case GGML_OP_NORM:
switch (n_fuse) {
case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break;
case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break;
case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
default: GGML_ABORT("fatal error");
} break;
case GGML_OP_RMS_NORM:
switch (n_fuse) {
case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break;
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break;
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
}
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);


@@ -123,10 +123,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);


@@ -661,13 +661,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
case GGML_OP_ROPE:
return true;
case GGML_OP_IM2COL:


@@ -428,16 +428,11 @@ typedef struct {
uint64_t nb1;
} ggml_metal_kargs_mul_mv_id;
// NORM
// RMS_NORM
typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} ggml_metal_kargs_norm;
typedef struct {
int32_t ne00;
int32_t ne00_4;
int32_t ne00_t;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
@@ -448,7 +443,7 @@ typedef struct {
uint64_t nbf1[3];
uint64_t nbf2[3];
uint64_t nbf3[3];
} ggml_metal_kargs_rms_norm;
} ggml_metal_kargs_norm;
typedef struct {
int32_t ne00;


@@ -266,10 +266,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_set_rows(ctx, idx);
} break;
case GGML_OP_RMS_NORM:
{
n_fuse = ggml_metal_op_rms_norm(ctx, idx);
} break;
case GGML_OP_L2_NORM:
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
@@ -279,6 +275,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
n_fuse = ggml_metal_op_group_norm(ctx, idx);
} break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
{
n_fuse = ggml_metal_op_norm(ctx, idx);
} break;
@@ -2346,146 +2343,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
return n_fuse;
}
int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int idx_end = ctx->idx_end;
const bool use_fusion = ctx->use_fusion;
const int debug_fusion = ctx->debug_fusion;
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_rms_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
/*.nef1 =*/ { ne01 },
/*.nef2 =*/ { ne02 },
/*.nef3 =*/ { ne03 },
/*.nbf1 =*/ { nb01 },
/*.nbf2 =*/ { nb02 },
/*.nbf3 =*/ { nb03 },
};
ggml_op fops[8];
int n_fuse = 1;
ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
// d[0] = rms_norm(a)
// d[1] = mul(d[0], b)
// d[2] = add(d[1], c)
if (use_fusion) {
fops[0] = GGML_OP_RMS_NORM;
fops[1] = GGML_OP_MUL;
fops[2] = GGML_OP_ADD;
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
break;
}
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
break;
}
if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
break;
}
if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
break;
}
if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
break;
}
//ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
}
++n_fuse;
if (debug_fusion > 1 && n_fuse > 1) {
if (n_fuse == 2) {
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
}
if (n_fuse == 3) {
GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
}
}
}
if (n_fuse > 1) {
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
for (int i = 1; i < n_fuse; ++i) {
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
ggml_metal_op_concurrency_reset(ctx);
break;
}
}
}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse);
int nth = 32; // SIMD width
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return n_fuse;
}
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
@@ -2594,6 +2451,14 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int idx_end = ctx->idx_end;
const bool use_fusion = ctx->use_fusion;
const int debug_fusion = ctx->debug_fusion;
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
@@ -2602,37 +2467,121 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
float eps;
memcpy(&eps, op->op_params, sizeof(float));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
/*.nef1 =*/ { ne01 },
/*.nef2 =*/ { ne02 },
/*.nef3 =*/ { ne03 },
/*.nbf1 =*/ { nb01 },
/*.nbf2 =*/ { nb02 },
/*.nbf3 =*/ { nb03 },
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op);
ggml_op fops[8];
int n_fuse = 1;
ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
// d[0] = norm(a)
// d[1] = mul(d[0], b)
// d[2] = add(d[1], c)
if (use_fusion) {
fops[0] = op->op;
fops[1] = GGML_OP_MUL;
fops[2] = GGML_OP_ADD;
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
break;
}
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
break;
}
if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
break;
}
if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
break;
}
if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
break;
}
//ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
}
++n_fuse;
if (debug_fusion > 1 && n_fuse > 1) {
if (n_fuse == 2) {
GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
}
if (n_fuse == 3) {
GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
}
}
}
if (n_fuse > 1) {
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
for (int i = 1; i < n_fuse; ++i) {
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
ggml_metal_op_concurrency_reset(ctx);
break;
}
}
}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
int nth = 32; // SIMD width
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
nth = std::min(nth, args.ne00_t);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
const int64_t nrows = ggml_nrows(op->src[0]);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1;
return n_fuse;
}
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {


@@ -60,7 +60,6 @@ int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rms_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);


@@ -66,6 +66,10 @@ static inline float e8m0_to_fp32(uint8_t x) {
return as_type<float>(bits);
}
static inline float dot(float x, float y) {
return x*y;
}
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -2493,30 +2497,43 @@ kernel void kernel_argmax_f32(
dst_i32[tgpig] = arg_val;
}
kernel void kernel_norm_f32(
// F == 1 : norm (no fuse)
// F == 2 : norm + mul
// F == 3 : norm + mul + add
template <typename T, short F>
kernel void kernel_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
float4 sumf4(0.0f);
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
T sumft(0.0f);
float sumf = 0.0f;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
sumf4 += x[i00];
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumft += x[i00];
}
sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
sumf = dot(sumft, T(1.0f));
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -2532,10 +2549,10 @@ kernel void kernel_norm_f32(
const float mean = sumf/args.ne00;
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
sumf = 0.0f;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
y[i00] = x[i00] - mean;
sumf += dot(y[i00], y[i00]);
}
@@ -2555,17 +2572,35 @@ kernel void kernel_norm_f32(
const float variance = sumf/args.ne00;
const float scale = 1.0f/sqrt(variance + args.eps);
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
y[i00] = y[i00] * scale;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (y[i00]*scale);
}
if (F == 2) {
y[i00] = (y[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
template <short F>
template <typename T, short F>
kernel void kernel_rms_norm_fuse_impl(
constant ggml_metal_kargs_rms_norm & args,
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
@@ -2584,15 +2619,15 @@ kernel void kernel_rms_norm_fuse_impl(
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
@@ -2611,8 +2646,8 @@ kernel void kernel_rms_norm_fuse_impl(
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (x[i00]*scale);
}
@@ -2625,11 +2660,15 @@ kernel void kernel_rms_norm_fuse_impl(
}
}
typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
kernel void kernel_l2_norm_f32(
constant ggml_metal_kargs_l2_norm & args,


@@ -6117,7 +6117,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));