metal : fix ACC op (#19427)
This commit is contained in:
parent
c7db95f106
commit
6e473fb384
|
|
@ -1067,8 +1067,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
case GGML_OP_ADD_ID:
|
case GGML_OP_ADD_ID:
|
||||||
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
|
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||||
return true;
|
return true;
|
||||||
|
|
|
||||||
|
|
@ -620,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||||
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
||||||
|
|
||||||
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
||||||
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
||||||
|
|
@ -671,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_kargs_bin args = {
|
ggml_metal_kargs_bin args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne10,
|
||||||
/*.ne01 =*/ ne01,
|
/*.ne01 =*/ ne11,
|
||||||
/*.ne02 =*/ ne02,
|
/*.ne02 =*/ ne12,
|
||||||
/*.ne03 =*/ ne03,
|
/*.ne03 =*/ ne13,
|
||||||
/*.nb00 =*/ nb00,
|
/*.nb00 =*/ nb00,
|
||||||
/*.nb01 =*/ pnb1,
|
/*.nb01 =*/ pnb1,
|
||||||
/*.nb02 =*/ pnb2,
|
/*.nb02 =*/ pnb2,
|
||||||
|
|
@ -687,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||||
/*.nb11 =*/ nb11,
|
/*.nb11 =*/ nb11,
|
||||||
/*.nb12 =*/ nb12,
|
/*.nb12 =*/ nb12,
|
||||||
/*.nb13 =*/ nb13,
|
/*.nb13 =*/ nb13,
|
||||||
/*.ne0 =*/ ne0,
|
/*.ne0 =*/ ne10,
|
||||||
/*.ne1 =*/ ne1,
|
/*.ne1 =*/ ne11,
|
||||||
/*.ne2 =*/ ne2,
|
/*.ne2 =*/ ne12,
|
||||||
/*.ne3 =*/ ne3,
|
/*.ne3 =*/ ne13,
|
||||||
/*.nb0 =*/ nb0,
|
/*.nb0 =*/ nb0,
|
||||||
/*.nb1 =*/ pnb1,
|
/*.nb1 =*/ pnb1,
|
||||||
/*.nb2 =*/ pnb2,
|
/*.nb2 =*/ pnb2,
|
||||||
|
|
@ -707,7 +707,13 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||||
|
|
||||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||||
|
|
||||||
|
int nth = 1;
|
||||||
|
|
||||||
|
while (2*nth < args.ne0 && nth < nth_max) {
|
||||||
|
nth *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue