metal : refactor bin kernels
This commit is contained in:
parent
7fcf1ef45d
commit
2c1a62aff0
|
|
@ -1392,34 +1392,73 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
|
|||
GGML_UNUSED(op);
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
|
||||
ggml_metal_library_t lib,
|
||||
ggml_op op,
|
||||
int32_t n_fuse,
|
||||
bool row) {
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const char * op_str = "undefined";
|
||||
switch (op) {
|
||||
case GGML_OP_ADD: op_str = "add"; break;
|
||||
case GGML_OP_SUB: op_str = "sub"; break;
|
||||
case GGML_OP_MUL: op_str = "mul"; break;
|
||||
case GGML_OP_DIV: op_str = "div"; break;
|
||||
int op_num = -1;
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_ADD: op_num = 0; break;
|
||||
case GGML_OP_SUB: op_num = 1; break;
|
||||
case GGML_OP_MUL: op_num = 2; break;
|
||||
case GGML_OP_DIV: op_num = 3; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
if (row) {
|
||||
snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
|
||||
} else {
|
||||
snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
|
||||
}
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t1_str = ggml_type_name(op->src[1]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
|
||||
|
||||
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, n_fuse);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
|
||||
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
int op_num = -1;
|
||||
|
||||
switch (op) {
|
||||
case GGML_OP_ADD: op_num = 0; break;
|
||||
case GGML_OP_SUB: op_num = 1; break;
|
||||
case GGML_OP_MUL: op_num = 2; break;
|
||||
case GGML_OP_DIV: op_num = 3; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
|
||||
snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int32(cv, op_num, FC_BIN + 0);
|
||||
ggml_metal_cv_set_int32(cv, 1, FC_BIN + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
return res;
|
||||
|
|
|
|||
|
|
@ -53,6 +53,8 @@ struct ggml_metal_pipeline_with_params {
|
|||
int nr1;
|
||||
|
||||
size_t smem;
|
||||
|
||||
bool c4;
|
||||
};
|
||||
|
||||
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
|
||||
|
|
@ -134,7 +136,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort
|
|||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
|
||||
|
|
|
|||
|
|
@ -350,6 +350,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
|
|||
/*.nr1 =*/ 0,
|
||||
/*.nsg =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
/*.c4 =*/ false,
|
||||
};
|
||||
|
||||
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
|
||||
|
|
@ -366,6 +367,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
|
|||
/*.nr1 =*/ 0,
|
||||
/*.nsg =*/ 0,
|
||||
/*.smem =*/ 0,
|
||||
/*.c4 =*/ false,
|
||||
};
|
||||
|
||||
[lib->lock lock];
|
||||
|
|
@ -1054,7 +1056,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_ADD_ID:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_SCALE:
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@
|
|||
#define FC_SSM_CONV 900
|
||||
#define FC_SOLVE_TRI 1000
|
||||
#define FC_COUNT_EQUAL 1100
|
||||
#define FC_BIN 1200
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
||||
|
|
|
|||
|
|
@ -707,7 +707,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
|||
/*.o1 =*/ { 0 },
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
|
||||
auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
|
|
@ -2895,8 +2895,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
||||
|
||||
bool bcast_row = false;
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
|
@ -2990,18 +2988,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||
|
||||
struct ggml_metal_pipeline_with_params pipeline;
|
||||
|
||||
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
|
||||
// src1 is a row
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
|
||||
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
|
||||
|
||||
bcast_row = true;
|
||||
} else {
|
||||
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
|
||||
}
|
||||
pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
|
||||
|
||||
if (n_fuse > 1) {
|
||||
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
||||
|
|
@ -3015,26 +3002,33 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||
}
|
||||
}
|
||||
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne10 = ne10/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
||||
|
||||
if (bcast_row) {
|
||||
const int64_t n = ggml_nelements(op)/4;
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
} else {
|
||||
int nth = 32;
|
||||
int nth = 1;
|
||||
|
||||
while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||
while (4*nth < args.ne0 && nth < nth_max) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
int nb = 1;
|
||||
while (4*nb < ne01 && nth*nb < nth_max) {
|
||||
nb *= 2;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nb - 1)/nb, ne02, ne03, nth, nb, 1);
|
||||
|
||||
return n_fuse;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -895,11 +895,156 @@ enum ggml_sort_order {
|
|||
GGML_SORT_ORDER_DESC,
|
||||
};
|
||||
|
||||
// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
|
||||
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
||||
// cons: not very efficient
|
||||
template <int F>
|
||||
kernel void kernel_add_fuse_impl(
|
||||
//// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
|
||||
//// pros: works for non-contiguous tensors, supports broadcast across all dims
|
||||
//// cons: not very efficient
|
||||
//template <int F>
|
||||
//kernel void kernel_add_fuse_impl(
|
||||
// constant ggml_metal_kargs_bin & args,
|
||||
// device const char * src0,
|
||||
// device const char * src1,
|
||||
// device char * dst,
|
||||
// uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
// ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
// ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
// const int i03 = tgpig.z;
|
||||
// const int i02 = tgpig.y;
|
||||
// const int i01 = tgpig.x;
|
||||
//
|
||||
// const int i13 = i03%args.ne13;
|
||||
// const int i12 = i02%args.ne12;
|
||||
// const int i11 = i01%args.ne11;
|
||||
//
|
||||
// device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
||||
// device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
||||
//
|
||||
// device const float * src1_ptr[F];
|
||||
// for (short j = 0; j < F; ++j) {
|
||||
// src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||
// }
|
||||
//
|
||||
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
// const int i10 = i0%args.ne10;
|
||||
//
|
||||
// float res = src0_ptr[i0];
|
||||
//
|
||||
//#pragma unroll
|
||||
// for (short j = 0; j < F; ++j) {
|
||||
// res += src1_ptr[j][i10];
|
||||
// }
|
||||
//
|
||||
// dst_ptr[i0] = res;
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
|
||||
//
|
||||
//template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
|
||||
//template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
|
||||
//template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
|
||||
//template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
|
||||
//template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
|
||||
//template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
|
||||
//template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
|
||||
//template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
|
||||
//
|
||||
//kernel void kernel_sub_fuse_1(
|
||||
// constant ggml_metal_kargs_bin & args,
|
||||
// device const char * src0,
|
||||
// device const char * src1,
|
||||
// device char * dst,
|
||||
// uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
// ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
// ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
// const int i03 = tgpig.z;
|
||||
// const int i02 = tgpig.y;
|
||||
// const int i01 = tgpig.x;
|
||||
//
|
||||
// const int i13 = i03%args.ne13;
|
||||
// const int i12 = i02%args.ne12;
|
||||
// const int i11 = i01%args.ne11;
|
||||
//
|
||||
// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
||||
// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
||||
// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
//
|
||||
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
// const int i10 = i0%args.ne10;
|
||||
// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//kernel void kernel_mul_fuse_1(
|
||||
// constant ggml_metal_kargs_bin & args,
|
||||
// device const char * src0,
|
||||
// device const char * src1,
|
||||
// device char * dst,
|
||||
// uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
// ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
// ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
// const int i03 = tgpig.z;
|
||||
// const int i02 = tgpig.y;
|
||||
// const int i01 = tgpig.x;
|
||||
//
|
||||
// const int i13 = i03%args.ne13;
|
||||
// const int i12 = i02%args.ne12;
|
||||
// const int i11 = i01%args.ne11;
|
||||
//
|
||||
// device const float * src0_ptr = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
||||
// device const float * src1_ptr = (device const float *)(src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]);
|
||||
// device float * dst_ptr = (device float *)(dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
||||
//
|
||||
// if (args.ne10 == 1) {
|
||||
// const float x = src1_ptr[0];
|
||||
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
// dst_ptr[i0] = src0_ptr[i0] * x;
|
||||
// }
|
||||
// } else {
|
||||
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
// const int i10 = i0 % args.ne10;
|
||||
// dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//kernel void kernel_div_fuse_1(
|
||||
// constant ggml_metal_kargs_bin & args,
|
||||
// device const char * src0,
|
||||
// device const char * src1,
|
||||
// device char * dst,
|
||||
// uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
// ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
// ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
// const int i03 = tgpig.z;
|
||||
// const int i02 = tgpig.y;
|
||||
// const int i01 = tgpig.x;
|
||||
//
|
||||
// const int i13 = i03%args.ne13;
|
||||
// const int i12 = i02%args.ne12;
|
||||
// const int i11 = i01%args.ne11;
|
||||
//
|
||||
// device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
||||
// device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
||||
// device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
//
|
||||
// if (args.ne10 == 1) {
|
||||
// const float x = 1.0f / *((device float *)(src1_ptr));
|
||||
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
||||
// }
|
||||
// } else {
|
||||
// for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
// const int i10 = i0%args.ne10;
|
||||
// *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
|
||||
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
|
||||
|
||||
template <typename T0, typename T1, typename T>
|
||||
kernel void kernel_bin_fuse_impl(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
|
|
@ -907,138 +1052,150 @@ kernel void kernel_add_fuse_impl(
|
|||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
|
||||
const short OP = FC_bin_op;
|
||||
const short F = FC_bin_f;
|
||||
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
const int i01 = tgpig.x*ntg.y + tpitg.y;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
|
||||
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
||||
device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
||||
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
||||
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
||||
|
||||
device const float * src1_ptr[F];
|
||||
for (short j = 0; j < F; ++j) {
|
||||
src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||
}
|
||||
if (F == 1) {
|
||||
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
if (args.ne10 == 1) {
|
||||
T1 src1_cur = src1_ptr[0];
|
||||
|
||||
float res = src0_ptr[i0];
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
if (OP == 0) {
|
||||
dst_ptr[i0] = src0_ptr[i0] + src1_cur;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (short j = 0; j < F; ++j) {
|
||||
res += src1_ptr[j][i10];
|
||||
}
|
||||
if (OP == 1) {
|
||||
dst_ptr[i0] = src0_ptr[i0] - src1_cur;
|
||||
}
|
||||
|
||||
dst_ptr[i0] = res;
|
||||
}
|
||||
}
|
||||
if (OP == 2) {
|
||||
dst_ptr[i0] = src0_ptr[i0] * src1_cur;
|
||||
}
|
||||
|
||||
typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
|
||||
if (OP == 3) {
|
||||
dst_ptr[i0] = src0_ptr[i0] / src1_cur;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
|
||||
template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
|
||||
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
|
||||
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
|
||||
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
|
||||
template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
|
||||
template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
|
||||
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
|
||||
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
|
||||
if (OP == 0) {
|
||||
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
|
||||
}
|
||||
|
||||
kernel void kernel_sub_fuse_1(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
if (OP == 1) {
|
||||
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
|
||||
}
|
||||
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
if (OP == 2) {
|
||||
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
|
||||
}
|
||||
|
||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_fuse_1(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
|
||||
if (args.ne10 == 1) {
|
||||
const float x = *((device float *)(src1_ptr));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
||||
if (OP == 3) {
|
||||
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
|
||||
device const T1 * src1_ptr[8];
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||
}
|
||||
|
||||
if (args.ne10 == 1) {
|
||||
T1 src1_cur[8];
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
src1_cur[j] = src1_ptr[j][0];
|
||||
}
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
T res = src0_ptr[i0];
|
||||
|
||||
if (OP == 0) {
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
res += src1_cur[j];
|
||||
}
|
||||
}
|
||||
|
||||
if (OP == 1) {
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
res -= src1_cur[j];
|
||||
}
|
||||
}
|
||||
|
||||
if (OP == 2) {
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
res *= src1_cur[j];
|
||||
}
|
||||
}
|
||||
|
||||
if (OP == 3) {
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
res /= src1_cur[j];
|
||||
}
|
||||
}
|
||||
|
||||
dst_ptr[i0] = res;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
|
||||
T res = src0_ptr[i0];
|
||||
|
||||
if (OP == 0) {
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
res += src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
if (OP == 1) {
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
res -= src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
if (OP == 2) {
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
res *= src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
if (OP == 3) {
|
||||
FOR_UNROLL (short j = 0; j < F; ++j) {
|
||||
res /= src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
dst_ptr[i0] = res;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_div_fuse_1(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
|
||||
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
|
||||
if (args.ne10 == 1) {
|
||||
const float x = 1.0f / *((device float *)(src1_ptr));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
}
|
||||
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
|
||||
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
|
||||
|
||||
kernel void kernel_add_id(
|
||||
constant ggml_metal_kargs_add_id & args,
|
||||
|
|
@ -1057,7 +1214,7 @@ kernel void kernel_add_id(
|
|||
const size_t nb1 = args.ne0 * sizeof(float);
|
||||
const size_t nb2 = args.ne1 * nb1;
|
||||
|
||||
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
|
||||
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
|
||||
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
|
||||
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue