diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index b4ca9c5dd6..3db7f12629 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1067,8 +1067,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_REPEAT: case GGML_OP_CONV_TRANSPOSE_1D: return true; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index c04e9fc7ff..3d5db0b79f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -620,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); const size_t pnb1 = ((const int32_t *) op->op_params)[0]; const size_t pnb2 = ((const int32_t *) op->op_params)[1]; @@ -671,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_bin args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, /*.nb00 =*/ nb00, /*.nb01 =*/ pnb1, /*.nb02 =*/ pnb2, @@ -687,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.nb11 =*/ nb11, /*.nb12 =*/ nb12, /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, /*.nb0 =*/ nb0, /*.nb1 =*/ pnb1, /*.nb2 =*/ pnb2, @@ -707,7 +707,13 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + int nth = 1; + + while (2*nth < args.ne0 && nth < nth_max) { + nth *= 2; + } ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);