CUDA: fuse muls (#21665)
This commit is contained in:
parent
d132f22fc9
commit
e34f042154
|
|
@ -472,6 +472,36 @@ void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
|
||||
GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
|
||||
|
||||
switch (n_fuse) {
|
||||
case 2:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_mul, 2>(ctx, dst);
|
||||
break;
|
||||
case 3:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_mul, 3>(ctx, dst);
|
||||
break;
|
||||
case 4:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_mul, 4>(ctx, dst);
|
||||
break;
|
||||
case 5:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_mul, 5>(ctx, dst);
|
||||
break;
|
||||
case 6:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_mul, 6>(ctx, dst);
|
||||
break;
|
||||
case 7:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_mul, 7>(ctx, dst);
|
||||
break;
|
||||
case 8:
|
||||
ggml_cuda_op_fused_binbcast_impl<op_mul, 8>(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "Unsupported n_fuse value");
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
|
|
|
|||
|
|
@ -9,3 +9,4 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
|
||||
void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
|
||||
|
|
|
|||
|
|
@ -3758,10 +3758,10 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||
continue;
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
|
||||
int n_fuse = 0;
|
||||
ggml_op ops[8];
|
||||
std::fill(ops, ops + 8, GGML_OP_ADD);
|
||||
std::fill(ops, ops + 8, node->op);
|
||||
|
||||
for (; n_fuse <= 6; ++n_fuse){
|
||||
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
|
||||
|
|
@ -3778,13 +3778,17 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
|||
n_fuse++;
|
||||
|
||||
if (n_fuse > 1) {
|
||||
ggml_tensor fused_add_node;
|
||||
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
|
||||
ggml_tensor fused_node;
|
||||
memcpy(&fused_node, node, sizeof(ggml_tensor));
|
||||
for (int j = 0; j < n_fuse - 1; ++j) {
|
||||
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||
fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||
}
|
||||
fused_node.data = cgraph->nodes[i + n_fuse - 1]->data;
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse);
|
||||
} else {
|
||||
ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse);
|
||||
}
|
||||
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
|
||||
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
|
||||
i += n_fuse - 1;
|
||||
|
||||
continue;
|
||||
|
|
|
|||
Loading…
Reference in New Issue