metal : avoid divisions in bin kernel (#20426)
* metal : avoid modulus in bin kernel when not broadcasting
* metal : fix capture_started flag
This commit is contained in:
parent 4cc6eb158c
commit e4cff0956b
|
|
@ -554,7 +554,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
|||
|
||||
// enter here only when capturing in order to wait for all computation to finish
|
||||
// otherwise, we leave the graph to compute asynchronously
|
||||
if (!use_capture && ctx->capture_started) {
|
||||
if (use_capture && ctx->capture_started) {
|
||||
// wait for completion and check status of each command buffer
|
||||
// needed to detect if the device ran out-of-memory for example (#1881)
|
||||
{
|
||||
|
|
@ -606,6 +606,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
|
|||
|
||||
[ctx->capture_scope endScope];
|
||||
[[MTLCaptureManager sharedCaptureManager] stopCapture];
|
||||
|
||||
ctx->capture_started = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1470,10 +1470,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
|
|||
|
||||
const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
|
||||
|
||||
const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0];
|
||||
const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
|
||||
|
||||
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
|
||||
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
|
|
@ -1482,6 +1483,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l
|
|||
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
|
||||
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
|
||||
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
|
||||
ggml_metal_cv_set_bool (cv, is_cb, FC_BIN + 3);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
|
|
|
|||
|
|
@ -3180,9 +3180,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|||
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
|
||||
|
||||
if (pipeline.cnt) {
|
||||
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
|
||||
} else {
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
|
|
|
|||
|
|
@ -1111,6 +1111,7 @@ template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_un
|
|||
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
|
||||
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
|
||||
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
|
||||
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
|
||||
|
||||
template <typename T0, typename T1, typename T>
|
||||
kernel void kernel_bin_fuse_impl(
|
||||
|
|
@ -1124,11 +1125,12 @@ kernel void kernel_bin_fuse_impl(
|
|||
#define FC_OP FC_bin_op
|
||||
#define FC_F FC_bin_f
|
||||
#define FC_RB FC_bin_rb
|
||||
#define FC_CB FC_bin_cb
|
||||
|
||||
if (FC_RB) {
|
||||
// row broadcast
|
||||
const uint i0 = tgpig.x;
|
||||
const uint i1 = i0%args.ne10;
|
||||
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
|
||||
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
|
||||
|
||||
device const T0 * src0_row = (device const T0 *) (src0);
|
||||
device T * dst_row = (device T *) (dst);
|
||||
|
|
@ -1200,7 +1202,7 @@ kernel void kernel_bin_fuse_impl(
|
|||
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
||||
|
||||
if (FC_OP == 0) {
|
||||
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
|
||||
|
|
@ -1225,7 +1227,7 @@ kernel void kernel_bin_fuse_impl(
|
|||
}
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
||||
|
||||
T res = src0_ptr[i0];
|
||||
|
||||
|
|
@ -1261,6 +1263,7 @@ kernel void kernel_bin_fuse_impl(
|
|||
#undef FC_OP
|
||||
#undef FC_F
|
||||
#undef FC_RB
|
||||
#undef FC_CB
|
||||
}
|
||||
|
||||
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
|
||||
|
|
|
|||
Loading…
Reference in New Issue