This commit is contained in:
Georgi Gerganov 2026-02-06 14:09:42 +02:00
parent 2c1a62aff0
commit 21f6a37348
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
5 changed files with 115 additions and 383 deletions

View File

@ -1412,8 +1412,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, n_fuse);
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
@ -1421,13 +1423,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.c4 = is_c4;
res.c4 = is_c4;
res.cnt = is_rb;
return res;
}

View File

@ -55,6 +55,7 @@ struct ggml_metal_pipeline_with_params {
size_t smem;
bool c4;
bool cnt;
};
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);

View File

@ -346,11 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nsg =*/ 0,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
/*.c4 =*/ false,
/*.cnt =*/ false,
};
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
@ -363,11 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nsg =*/ 0,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
/*.c4 =*/ false,
/*.cnt =*/ false,
};
[lib->lock lock];

View File

@ -3014,21 +3014,22 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
if (pipeline.cnt) {
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
int nth = 1;
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
while (4*nth < args.ne0 && nth < nth_max) {
nth *= 2;
int nth = 1;
while (2*nth < args.ne0 && nth < nth_max) {
nth *= 2;
}
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
}
int nb = 1;
while (4*nb < ne01 && nth*nb < nth_max) {
nb *= 2;
}
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nb - 1)/nb, ne02, ne03, nth, nb, 1);
return n_fuse;
}

View File

@ -895,153 +895,10 @@ enum ggml_sort_order {
GGML_SORT_ORDER_DESC,
};
//// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
//// pros: works for non-contiguous tensors, supports broadcast across all dims
//// cons: not very efficient
//template <int F>
//kernel void kernel_add_fuse_impl(
// constant ggml_metal_kargs_bin & args,
// device const char * src0,
// device const char * src1,
// device char * dst,
// uint3 tgpig[[threadgroup_position_in_grid]],
// ushort3 tpitg[[thread_position_in_threadgroup]],
// ushort3 ntg[[threads_per_threadgroup]]) {
// const int i03 = tgpig.z;
// const int i02 = tgpig.y;
// const int i01 = tgpig.x;
//
// const int i13 = i03%args.ne13;
// const int i12 = i02%args.ne12;
// const int i11 = i01%args.ne11;
//
// device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
// device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
//
// device const float * src1_ptr[F];
// for (short j = 0; j < F; ++j) {
// src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
// }
//
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// const int i10 = i0%args.ne10;
//
// float res = src0_ptr[i0];
//
//#pragma unroll
// for (short j = 0; j < F; ++j) {
// res += src1_ptr[j][i10];
// }
//
// dst_ptr[i0] = res;
// }
//}
//
//typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
//
//template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
//template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
//template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
//template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
//template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
//template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
//template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
//template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
//
//kernel void kernel_sub_fuse_1(
// constant ggml_metal_kargs_bin & args,
// device const char * src0,
// device const char * src1,
// device char * dst,
// uint3 tgpig[[threadgroup_position_in_grid]],
// ushort3 tpitg[[thread_position_in_threadgroup]],
// ushort3 ntg[[threads_per_threadgroup]]) {
// const int i03 = tgpig.z;
// const int i02 = tgpig.y;
// const int i01 = tgpig.x;
//
// const int i13 = i03%args.ne13;
// const int i12 = i02%args.ne12;
// const int i11 = i01%args.ne11;
//
// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
//
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// const int i10 = i0%args.ne10;
// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
// }
//}
//
//kernel void kernel_mul_fuse_1(
// constant ggml_metal_kargs_bin & args,
// device const char * src0,
// device const char * src1,
// device char * dst,
// uint3 tgpig[[threadgroup_position_in_grid]],
// ushort3 tpitg[[thread_position_in_threadgroup]],
// ushort3 ntg[[threads_per_threadgroup]]) {
// const int i03 = tgpig.z;
// const int i02 = tgpig.y;
// const int i01 = tgpig.x;
//
// const int i13 = i03%args.ne13;
// const int i12 = i02%args.ne12;
// const int i11 = i01%args.ne11;
//
// device const float * src0_ptr = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
// device const float * src1_ptr = (device const float *)(src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]);
// device float * dst_ptr = (device float *)(dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
//
// if (args.ne10 == 1) {
// const float x = src1_ptr[0];
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// dst_ptr[i0] = src0_ptr[i0] * x;
// }
// } else {
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// const int i10 = i0 % args.ne10;
// dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
// }
// }
//}
//
//kernel void kernel_div_fuse_1(
// constant ggml_metal_kargs_bin & args,
// device const char * src0,
// device const char * src1,
// device char * dst,
// uint3 tgpig[[threadgroup_position_in_grid]],
// ushort3 tpitg[[thread_position_in_threadgroup]],
// ushort3 ntg[[threads_per_threadgroup]]) {
// const int i03 = tgpig.z;
// const int i02 = tgpig.y;
// const int i01 = tgpig.x;
//
// const int i13 = i03%args.ne13;
// const int i12 = i02%args.ne12;
// const int i11 = i01%args.ne11;
//
// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
//
// if (args.ne10 == 1) {
// const float x = 1.0f / *((device float *)(src1_ptr));
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
// }
// } else {
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// const int i10 = i0%args.ne10;
// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
// }
// }
//}
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
@ -1052,136 +909,134 @@ kernel void kernel_bin_fuse_impl(
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
const short OP = FC_bin_op;
const short F = FC_bin_f;
#define FC_OP FC_bin_op
#define FC_F FC_bin_f
#define FC_RB FC_bin_rb
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x*ntg.y + tpitg.y;
if (FC_RB) {
// row broadcast
const uint i0 = tgpig.x;
const uint i1 = i0%args.ne10;
if (i01 >= args.ne01) {
return;
}
device const T0 * src0_row = (device const T0 *) (src0);
device T * dst_row = (device T *) (dst);
const int i13 = i03%args.ne13;
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
if (FC_F == 1) {
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
if (FC_OP == 0) {
dst_row[i0] = src0_row[i0] + src1_row[i1];
}
if (F == 1) {
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
if (FC_OP == 1) {
dst_row[i0] = src0_row[i0] - src1_row[i1];
}
if (args.ne10 == 1) {
T1 src1_cur = src1_ptr[0];
if (FC_OP == 2) {
dst_row[i0] = src0_row[i0] * src1_row[i1];
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_cur;
}
if (OP == 1) {
dst_ptr[i0] = src0_ptr[i0] - src1_cur;
}
if (OP == 2) {
dst_ptr[i0] = src0_ptr[i0] * src1_cur;
}
if (OP == 3) {
dst_ptr[i0] = src0_ptr[i0] / src1_cur;
}
if (FC_OP == 3) {
dst_row[i0] = src0_row[i0] / src1_row[i1];
}
} else {
T0 res = src0_row[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
dst_row[i0] = res;
}
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (i01 >= args.ne01) {
return;
}
const int i13 = i03%args.ne13;
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
if (FC_F == 1) {
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
if (OP == 0) {
if (FC_OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
}
if (OP == 1) {
if (FC_OP == 1) {
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
}
if (OP == 2) {
if (FC_OP == 2) {
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
}
if (OP == 3) {
if (FC_OP == 3) {
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
}
}
}
} else {
device const T1 * src1_ptr[8];
FOR_UNROLL (short j = 0; j < F; ++j) {
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
}
if (args.ne10 == 1) {
T1 src1_cur[8];
FOR_UNROLL (short j = 0; j < F; ++j) {
src1_cur[j] = src1_ptr[j][0];
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
T res = src0_ptr[i0];
if (OP == 0) {
FOR_UNROLL (short j = 0; j < F; ++j) {
res += src1_cur[j];
}
}
if (OP == 1) {
FOR_UNROLL (short j = 0; j < F; ++j) {
res -= src1_cur[j];
}
}
if (OP == 2) {
FOR_UNROLL (short j = 0; j < F; ++j) {
res *= src1_cur[j];
}
}
if (OP == 3) {
FOR_UNROLL (short j = 0; j < F; ++j) {
res /= src1_cur[j];
}
}
dst_ptr[i0] = res;
}
} else {
device const T1 * src1_ptr[8];
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = i0%args.ne10;
T res = src0_ptr[i0];
if (OP == 0) {
FOR_UNROLL (short j = 0; j < F; ++j) {
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += src1_ptr[j][i10];
}
}
if (OP == 1) {
FOR_UNROLL (short j = 0; j < F; ++j) {
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= src1_ptr[j][i10];
}
}
if (OP == 2) {
FOR_UNROLL (short j = 0; j < F; ++j) {
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= src1_ptr[j][i10];
}
}
if (OP == 3) {
FOR_UNROLL (short j = 0; j < F; ++j) {
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= src1_ptr[j][i10];
}
}
@ -1190,6 +1045,10 @@ kernel void kernel_bin_fuse_impl(
}
}
}
#undef FC_OP
#undef FC_F
#undef FC_RB
}
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
@ -1255,141 +1114,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
// assumption: src1 is a row
// broadcast src1 into src0
template <short F>
kernel void kernel_add_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tpig[[thread_position_in_grid]]) {
const uint nb = args.ne00/4;
const uint i = tpig % nb;
device const float4 * src0_row = (device const float4 *) (src0);
device float4 * dst_row = (device float4 *) (dst);
float4 res = src0_row[tpig];
#pragma unroll(F)
for (short j = 0; j < F; ++j) {
res += ((device const float4 *) (src1 + args.o1[j]))[i];
}
dst_row[tpig] = res;
}
typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
template <short F>
kernel void kernel_sub_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tpig[[thread_position_in_grid]]) {
const uint nb = args.ne00/4;
const uint i = tpig % nb;
device const float4 * src0_row = (device const float4 *) (src0);
device float4 * dst_row = (device float4 *) (dst);
device const float4 * src1_row[F];
for (short j = 0; j < F; ++j) {
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
}
float4 res = src0_row[tpig];
#pragma unroll(F)
for (short j = 0; j < F; ++j) {
res -= src1_row[j][i];
}
dst_row[tpig] = res;
}
typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
template <short F>
kernel void kernel_mul_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tpig[[thread_position_in_grid]]) {
const uint nb = args.ne00/4;
const uint i = tpig % nb;
device const float4 * src0_row = (device const float4 *) (src0);
device float4 * dst_row = (device float4 *) (dst);
device const float4 * src1_row[F];
for (short j = 0; j < F; ++j) {
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
}
float4 res = src0_row[tpig];
#pragma unroll(F)
for (short j = 0; j < F; ++j) {
res *= src1_row[j][i];
}
dst_row[tpig] = res;
}
typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
template <short F>
kernel void kernel_div_row_c4_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tpig[[thread_position_in_grid]]) {
const uint nb = args.ne00/4;
const uint i = tpig % nb;
device const float4 * src0_row = (device const float4 *) (src0);
device float4 * dst_row = (device float4 *) (dst);
device const float4 * src1_row[F];
for (short j = 0; j < F; ++j) {
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
}
float4 res = src0_row[tpig];
#pragma unroll(F)
for (short j = 0; j < F; ++j) {
res /= src1_row[j][i];
}
dst_row[tpig] = res;
}
typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
kernel void kernel_scale_f32(
constant ggml_metal_kargs_scale & args,
device const float * src0,