parent 703f9e32c4
commit 652d303b32
@@ -1090,7 +1090,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse, bool fuse_prev) {
     assert(op->op == GGML_OP_RMS_NORM);
 
     GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
@@ -1099,10 +1099,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_librar
     char base[256];
     char name[256];
 
+    const char * suffix = fuse_prev ? "_prev" : "";
+
     switch (n_fuse) {
-        case 1: snprintf(base, 256, "kernel_rms_norm_f32");         break;
-        case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32");     break;
-        case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32"); break;
+        case 1: snprintf(base, 256, "kernel_rms_norm%s_f32",         suffix); break;
+        case 2: snprintf(base, 256, "kernel_rms_norm%s_mul_f32",     suffix); break;
+        case 3: snprintf(base, 256, "kernel_rms_norm%s_mul_add_f32", suffix); break;
         default: GGML_ABORT("fatal error");
     }
 
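Note (not part of the commit): a minimal standalone C++ sketch of the kernel-name selection in the hunk above, assuming only the snprintf logic shown there; the rest of the real function is omitted.

#include <cstdio>
#include <string>

// Mirrors the base-name composition: n_fuse picks the fused variant,
// fuse_prev inserts the "_prev" suffix added by this change.
static std::string rms_norm_kernel_name(int n_fuse, bool fuse_prev) {
    const char * suffix = fuse_prev ? "_prev" : "";
    char base[256];
    switch (n_fuse) {
        case 1: snprintf(base, 256, "kernel_rms_norm%s_f32",         suffix); break;
        case 2: snprintf(base, 256, "kernel_rms_norm%s_mul_f32",     suffix); break;
        case 3: snprintf(base, 256, "kernel_rms_norm%s_mul_add_f32", suffix); break;
        default: return "";
    }
    return base;
}

int main() {
    // expected: kernel_rms_norm_mul_add_f32 and kernel_rms_norm_prev_mul_add_f32
    printf("%s\n", rms_norm_kernel_name(3, false).c_str());
    printf("%s\n", rms_norm_kernel_name(3, true).c_str());
}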
@@ -123,7 +123,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin           (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm      (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm      (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse, bool fuse_prev);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm    (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm          (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -193,6 +193,19 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
         case GGML_OP_MUL:
         case GGML_OP_DIV:
             {
+                if (node->op == GGML_OP_ADD) {
+                    const ggml_op fops[2] = { GGML_OP_ADD, GGML_OP_RMS_NORM };
+
+                    if (idx + 1 < ctx->idx_end &&
+                        nodes[1]->op == GGML_OP_RMS_NORM &&
+                        ggml_can_fuse(gf, idx, fops, 2) &&
+                        ggml_are_same_layout(node, node->src[0]) &&
+                        ggml_are_same_layout(node, node->src[1])) {
+                        n_fuse = ggml_metal_op_rms_norm(ctx, idx + 1, true) + 1;
+                        break;
+                    }
+                }
+
                 n_fuse = ggml_metal_op_bin(ctx, idx);
             } break;
         case GGML_OP_ADD_ID:
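Note (not part of the commit): a toy C++ model of the decision added in the hunk above. When an ADD node is followed by an RMS_NORM and the layouts match, the ADD is folded into the RMS_NORM dispatch at idx + 1 with fuse_prev = true; the extra + 1 on n_fuse presumably tells the encoder to also skip the ADD node itself. The toy types and helpers below are stand-ins, not ggml APIs.

#include <cstdio>

enum toy_op { TOY_ADD, TOY_RMS_NORM, TOY_MUL };

struct toy_node {
    toy_op op;
    bool   same_layout_as_srcs; // stand-in for ggml_are_same_layout(node, node->src[0/1])
};

// stand-in for ggml_can_fuse(gf, idx, fops, 2): ADD immediately followed by RMS_NORM
static bool toy_can_fuse(const toy_node * nodes, int idx) {
    return nodes[idx].op == TOY_ADD && nodes[idx + 1].op == TOY_RMS_NORM;
}

// mirrors the order of checks in the real diff: bounds, next op, fusability, layouts
static bool should_fuse_prev_add(const toy_node * nodes, int idx, int idx_end) {
    return idx + 1 < idx_end                 &&
           nodes[idx + 1].op == TOY_RMS_NORM &&
           toy_can_fuse(nodes, idx)          &&
           nodes[idx].same_layout_as_srcs;
}

int main() {
    toy_node nodes[] = { { TOY_ADD, true }, { TOY_RMS_NORM, true }, { TOY_MUL, true } };
    printf("fuse at idx 0: %d\n", should_fuse_prev_add(nodes, 0, 3)); // prints 1
    printf("fuse at idx 1: %d\n", should_fuse_prev_add(nodes, 1, 3)); // prints 0
}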
@@ -268,7 +281,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             } break;
         case GGML_OP_RMS_NORM:
             {
-                n_fuse = ggml_metal_op_rms_norm(ctx, idx);
+                n_fuse = ggml_metal_op_rms_norm(ctx, idx, false);
             } break;
         case GGML_OP_L2_NORM:
             {
@@ -2346,7 +2359,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
     return n_fuse;
 }
 
-int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
+int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx, bool fuse_prev) {
     ggml_cgraph * gf = ctx->gf;
     ggml_tensor * op = ggml_graph_node(gf, idx);
 
@@ -2372,6 +2385,17 @@ int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
     ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
     ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
 
+    ggml_metal_buffer_id bid_src0_0;
+    ggml_metal_buffer_id bid_src0_1;
+
+    if (fuse_prev) {
+        ggml_tensor * prev = ggml_graph_node(gf, idx - 1);
+        GGML_ASSERT(prev->op == GGML_OP_ADD);
+
+        bid_src0_0 = ggml_metal_get_buffer_id(prev->src[0]);
+        bid_src0_1 = ggml_metal_get_buffer_id(prev->src[1]);
+    }
+
     ggml_metal_kargs_rms_norm args = {
         /*.ne00   =*/ ne00,
         /*.ne00_4 =*/ ne00/4,
@@ -2459,7 +2483,7 @@ int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
         }
     }
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse);
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse, fuse_prev);
 
     int nth = 32; // SIMD width
 
@@ -2474,10 +2498,18 @@ int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, bid_src0,    1);
-    ggml_metal_encoder_set_buffer  (enc, bid_fuse[0], 2);
-    ggml_metal_encoder_set_buffer  (enc, bid_fuse[1], 3);
-    ggml_metal_encoder_set_buffer  (enc, bid_dst,     4);
+    if (fuse_prev) {
+        ggml_metal_encoder_set_buffer(enc, bid_src0_0,  1);
+        ggml_metal_encoder_set_buffer(enc, bid_src0_1,  2);
+        ggml_metal_encoder_set_buffer(enc, bid_fuse[0], 3);
+        ggml_metal_encoder_set_buffer(enc, bid_fuse[1], 4);
+        ggml_metal_encoder_set_buffer(enc, bid_dst,     5);
+    } else {
+        ggml_metal_encoder_set_buffer(enc, bid_src0,    1);
+        ggml_metal_encoder_set_buffer(enc, bid_fuse[0], 2);
+        ggml_metal_encoder_set_buffer(enc, bid_fuse[1], 3);
+        ggml_metal_encoder_set_buffer(enc, bid_dst,     4);
+    }
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
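Note (not part of the commit): a plain C++ sketch of the Metal argument-table layout implied by the branch above, assuming index 0 is the kargs struct set via set_bytes and the remaining slots follow the set_buffer calls shown in the hunk.

#include <cstdio>

int main() {
    // fuse_prev == true: the two ADD inputs come first, shifting the fuse and dst slots up by one
    const char * fused_prev[6] = { "args", "src0_0", "src0_1", "fuse[0]", "fuse[1]", "dst" };
    // fuse_prev == false: the original layout is unchanged
    const char * regular[5]    = { "args", "src0", "fuse[0]", "fuse[1]", "dst" };

    printf("fuse_prev == true:\n");
    for (int i = 0; i < 6; ++i) printf("  buffer(%d) = %s\n", i, fused_prev[i]);

    printf("fuse_prev == false:\n");
    for (int i = 0; i < 5; ++i) printf("  buffer(%d) = %s\n", i, regular[i]);
}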
@@ -60,7 +60,7 @@ int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_add_id          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_flash_attn_ext  (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_bin             (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_rms_norm        (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_rms_norm        (ggml_metal_op_t ctx, int idx, bool fuse_prev);
 int ggml_metal_op_l2_norm         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_group_norm      (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_norm            (ggml_metal_op_t ctx, int idx);
@@ -2631,6 +2631,83 @@ template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fus
 template [[host_name("kernel_rms_norm_mul_f32")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
 template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
 
+// TODO: same as kernel_rms_norm_fuse_impl but fuses the previous ADD operation
+//       can be simplified, but for now there does not seem to be benefit from this fusion so leaving it like this
+// F == 1 : rms_norm (no fuse)
+// F == 2 : rms_norm + mul
+// F == 3 : rms_norm + mul + add
+template <short F>
+kernel void kernel_rms_norm_prev_fuse_impl(
+        constant ggml_metal_kargs_rms_norm & args,
+        device const char * src0_0, // src0 = src0_0 + src0_1
+        device const char * src0_1,
+        device const char * src1_0,
+        device const char * src1_1,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    const int i01 = tgpig.x;
+    const int i02 = tgpig.y;
+    const int i03 = tgpig.z;
+
+    device const float4 * x0 = (device const float4 *) (src0_0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+    device const float4 * x1 = (device const float4 *) (src0_1 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
+
+    device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
+    device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
+
+    float sumf = 0.0f;
+
+    // parallel sum
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+        const float4 x = x0[i00] + x1[i00];
+        sumf += dot(x, x);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float mean  = sumf/args.ne00;
+    const float scale = 1.0f/sqrt(mean + args.eps);
+
+    device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
+    for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
+        const float4 x = x0[i00] + x1[i00];
+        if (F == 1) {
+            y[i00] = (x*scale);
+        }
+        if (F == 2) {
+            y[i00] = (x*scale)*f0[i00];
+        }
+        if (F == 3) {
+            y[i00] = (x*scale)*f0[i00] + f1[i00];
+        }
+    }
+}
+
+typedef decltype(kernel_rms_norm_prev_fuse_impl<1>) kernel_rms_norm_prev_fuse_t;
+
+template [[host_name("kernel_rms_norm_prev_f32")]]         kernel kernel_rms_norm_prev_fuse_t kernel_rms_norm_prev_fuse_impl<1>;
+template [[host_name("kernel_rms_norm_prev_mul_f32")]]     kernel kernel_rms_norm_prev_fuse_t kernel_rms_norm_prev_fuse_impl<2>;
+template [[host_name("kernel_rms_norm_prev_mul_add_f32")]] kernel kernel_rms_norm_prev_fuse_t kernel_rms_norm_prev_fuse_impl<3>;
+
 kernel void kernel_l2_norm_f32(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,
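Note (not part of the commit): a hedged CPU reference in plain C++ for the math computed by the new kernel, assuming the same per-row formula: sum the squares of (src0_0 + src0_1), normalize by 1/sqrt(mean + eps), then optionally multiply and add for F == 2 / F == 3.

#include <cmath>
#include <cstdio>
#include <vector>

// F == 1 : rms_norm only, F == 2 : + mul, F == 3 : + mul + add
static std::vector<float> rms_norm_prev_fuse_ref(
        const std::vector<float> & x0, const std::vector<float> & x1,
        const std::vector<float> & f0, const std::vector<float> & f1,
        float eps, int F) {
    const size_t n = x0.size();

    // sum of squares of the fused previous ADD
    double sumf = 0.0;
    for (size_t i = 0; i < n; ++i) {
        const float x = x0[i] + x1[i];
        sumf += (double) x*x;
    }

    const float mean  = (float)(sumf/n);
    const float scale = 1.0f/std::sqrt(mean + eps);

    std::vector<float> y(n);
    for (size_t i = 0; i < n; ++i) {
        const float x = x0[i] + x1[i];
        float v = x*scale;
        if (F >= 2) v *= f0[i];
        if (F == 3) v += f1[i];
        y[i] = v;
    }
    return y;
}

int main() {
    std::vector<float> x0 = {1, 2, 3, 4}, x1 = {0.5f, 0.5f, 0.5f, 0.5f};
    std::vector<float> f0 = {1, 1, 1, 1}, f1 = {0, 0, 0, 0};
    for (float v : rms_norm_prev_fuse_ref(x0, x1, f0, f1, 1e-6f, 3)) printf("%f\n", v);
}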