From c5ce4bc227592afb2ec87aa4efce2d0ac0482c51 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 8 Apr 2026 09:05:51 +0800 Subject: [PATCH] CUDA: make cuda graphs props check faster (#21472) * CUDA: compute fast hash instead of expensive props check * use seen node * use memcpy --- ggml/src/ggml-cuda/common.cuh | 21 +----- ggml/src/ggml-cuda/ggml-cuda.cu | 113 ++------------------------------ 2 files changed, 6 insertions(+), 128 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 1c9233b4fc..a2960e5ae3 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1157,19 +1157,6 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_cuda_graph_node_properties { - void * node_data; - ggml_op node_op; - enum ggml_type node_type; - int32_t flags; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - void * src_data[GGML_MAX_SRC]; - int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; -}; - -static_assert(std::is_trivial::value, "ggml_cuda_graph_node_properties must be trivial"); - struct ggml_cuda_graph { #ifdef USE_CUDA_GRAPH ~ggml_cuda_graph() { @@ -1186,13 +1173,7 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool warmup_complete = false; - std::vector props; - - // these are extra tensors (inputs) that participate in the ggml graph but are not nodes - // they properties also have to match in order to be able to safely reuse a CUDA graph - // ref: https://github.com/ggml-org/llama.cpp/pull/18583 - // ref: https://github.com/ggml-org/llama.cpp/pull/19165 - std::vector extra; + std::vector nodes_copy; bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 25b904b7dc..b21196bb4f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -82,7 +82,6 @@ 
#include #include #include -#include static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); @@ -2969,74 +2968,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { return use_cuda_graph; } -static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { - memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); - props->node_data = node->data; - props->node_op = node->op; - props->node_type = node->type; - props->flags = node->flags; - for (int i = 0; i < GGML_MAX_DIMS; i++) { - props->ne[i] = node->ne[i]; - props->nb[i] = node->nb[i]; - } - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (!node->src[i]) { - continue; - } - - props->src_data[i] = node->src[i]->data; - } - memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); -} - -static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { - if (node->data != props->node_data && node->op != GGML_OP_VIEW) { - return false; - } - - if (node->op != props->node_op) { - return false; - } - - if (node->type != props->node_type) { - return false; - } - - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != props->ne[i]) { - return false; - } - if (node->nb[i] != props->nb[i]) { - return false; - } - } - - if (node->op != GGML_OP_VIEW) { - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (!node->src[i]) { - if (props->src_data[i] != nullptr) { - return false; - } - continue; - } - - if (node->src[i]->data != props->src_data[i]) { - return false; - } - } - } - - if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { - return false; - } - - if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) { - return false; - } - - return true; -} - static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { return cgraph->nodes[0]; } @@ -3048,52 +2979,18 @@ static bool 
ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); // Check if the graph size has changed - if (graph->props.size() != (size_t)cgraph->n_nodes) { + if ((int)graph->nodes_copy.size() != cgraph->n_nodes) { res = true; - graph->props.resize(cgraph->n_nodes); + graph->nodes_copy.resize(cgraph->n_nodes); } - // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token - std::unordered_set seen_node; - std::vector srcs_extra; for (int i = 0; i < cgraph->n_nodes; i++) { - bool props_match = true; - - seen_node.insert(cgraph->nodes[i]); - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); - } - if (!props_match) { - res = true; - } - ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); - - for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { - ggml_tensor * src = cgraph->nodes[i]->src[src_idx]; - if (src && seen_node.find(src) == seen_node.end()) { - srcs_extra.push_back(src); + if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { + res = true; } } - } - - if (graph->extra.size() != (size_t) srcs_extra.size()) { - res = true; - graph->extra.resize(srcs_extra.size()); - } - - for (size_t i = 0; i < srcs_extra.size(); ++i) { - bool props_match = true; - - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]); - } - - if (!props_match) { - res = true; - } - ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]); + memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)); } return res;