diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 7339fe0c07..adb4d5f0cb 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -472,6 +472,36 @@ void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, } } +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { + GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); + + switch (n_fuse) { + case 2: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 3: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 4: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 5: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 6: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 7: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 8: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + default: + GGML_ASSERT(false && "Unsupported n_fuse value"); + } +} + void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 62bc950111..12624785b4 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -9,3 +9,4 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 841af0726b..8613d20b9f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3758,10 +3758,10 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } - if (node->op == GGML_OP_ADD) { + if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { int n_fuse = 0; ggml_op ops[8]; - std::fill(ops, ops + 8, GGML_OP_ADD); + std::fill(ops, ops + 8, node->op); for (; n_fuse <= 6; ++n_fuse){ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { @@ -3778,13 +3778,17 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud n_fuse++; if (n_fuse > 1) { - ggml_tensor fused_add_node; - memcpy(&fused_add_node, node, sizeof(ggml_tensor)); + ggml_tensor fused_node; + memcpy(&fused_node, node, sizeof(ggml_tensor)); for (int j = 0; j < n_fuse - 1; ++j) { - fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + fused_node.data = cgraph->nodes[i + n_fuse - 1]->data; + if (node->op == GGML_OP_ADD) { + ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse); + } else { + ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse); } - fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data; - ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse); i += n_fuse - 1; continue;