metal : fix ACC op

This commit is contained in:
Georgi Gerganov 2026-02-08 11:06:54 +02:00
parent 292f6908cd
commit 15a484dee6
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 18 additions and 12 deletions

View File

@@ -1058,8 +1058,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
case GGML_OP_FILL:

View File

@@ -628,8 +628,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -679,10 +679,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
}
ggml_metal_kargs_bin args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.ne00 =*/ ne10,
/*.ne01 =*/ ne11,
/*.ne02 =*/ ne12,
/*.ne03 =*/ ne13,
/*.nb00 =*/ nb00,
/*.nb01 =*/ pnb1,
/*.nb02 =*/ pnb2,
@@ -695,10 +695,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.ne0 =*/ ne10,
/*.ne1 =*/ ne11,
/*.ne2 =*/ ne12,
/*.ne3 =*/ ne13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
@@ -715,7 +715,13 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = 1;
while (2*nth < args.ne0 && nth < nth_max) {
nth *= 2;
}
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);