diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 43280644e4..a3256d59dd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1124,6 +1124,7 @@ struct ggml_tensor_extra_gpu { struct ggml_cuda_graph_node_properties { void * node_data; ggml_op node_op; + enum ggml_type node_type; int32_t flags; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index cfcffde8a2..e9e9592eba 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2920,6 +2920,7 @@ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); props->node_data = node->data; props->node_op = node->op; + props->node_type = node->type; props->flags = node->flags; for (int i = 0; i < GGML_MAX_DIMS; i++) { props->ne[i] = node->ne[i]; @@ -2944,6 +2945,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ return false; } + if (node->type != props->node_type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != props->ne[i]) { return false;