Fix CUDA graph update logic.

This commit is contained in:
Xiangyan Sun 2025-10-15 20:45:06 -07:00
parent 4bbe5b1e59
commit 3afbd9f327
3 changed files with 57 additions and 36 deletions

View File

@ -940,6 +940,7 @@ struct ggml_cuda_graph {
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
std::vector<cudaKernelNodeParams> params;
int number_consecutive_updates = 0;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
#endif
};
@ -954,7 +955,9 @@ struct ggml_backend_cuda_context {
#ifdef USE_CUDA_GRAPH
bool cuda_graph_initialized = false;
bool disable_due_to_gpu_arch = false;
bool disable_graph_due_to_env = false;
bool disable_graph_due_to_gpu_arch = false;
bool disable_graph_due_to_too_many_updates = false;
#endif
explicit ggml_backend_cuda_context(int device) :

View File

@ -2642,8 +2642,8 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
}
#ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility(const ggml_cgraph * cgraph,
bool use_cuda_graph) {
static bool check_node_graph_compatibility(const ggml_cgraph * cgraph) {
bool use_cuda_graph = true;
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
@ -2753,8 +2753,14 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
return true;
}
static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
// Refresh the per-node property records kept on `cuda_graph` from `cgraph`.
//
// The stored vector is resized to hold exactly one entry per node of `cgraph`,
// and each entry is then filled in by set_ggml_graph_node_properties() from the
// node at the same index.
//
//   cuda_graph - graph object whose `ggml_graph_properties` vector is rewritten
//   cgraph     - GGML compute graph supplying the nodes (only read here)
static void update_cuda_graph_properties(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
    auto & stored_props = cuda_graph->ggml_graph_properties;
    const int node_count = cgraph->n_nodes;

    // One slot per graph node; existing slots are overwritten below.
    stored_props.resize(node_count);

    for (int node_idx = 0; node_idx < node_count; ++node_idx) {
        auto * slot = &stored_props[node_idx];
        set_ggml_graph_node_properties(cgraph->nodes[node_idx], slot);
    }
}
static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
bool cuda_graph_update_required = false;
if (cuda_graph->instance == nullptr) {
@ -2768,7 +2774,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
@ -2777,7 +2782,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
if (!has_matching_properties) {
cuda_graph_update_required = true;
}
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]);
}
return cuda_graph_update_required;
@ -3057,22 +3061,14 @@ static void capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_g
}
}
static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph* cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
ggml_cuda_graph * cuda_graph = new ggml_cuda_graph();
cuda_graph->cgraph = cgraph;
static bool should_use_cuda_graph(ggml_backend_cuda_context * cuda_ctx, const struct ggml_cgraph * cgraph) {
bool use_cuda_graph = true;
if (!cuda_ctx->cuda_graph_initialized) {
cuda_ctx->disable_graph_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
cuda_ctx->disable_due_to_gpu_arch = true;
cuda_ctx->disable_graph_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
@ -3083,17 +3079,30 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
// or previous graph capture failure.
// Also disable for multi-gpu for now. TO DO investigate
if (disable_cuda_graphs_due_to_env
|| cuda_ctx->disable_due_to_gpu_arch) {
if (cuda_ctx->disable_graph_due_to_env || cuda_ctx->disable_graph_due_to_gpu_arch ||
cuda_ctx->disable_graph_due_to_too_many_updates) {
use_cuda_graph = false;
}
if (use_cuda_graph) {
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
use_cuda_graph = check_node_graph_compatibility(cgraph);
}
if (use_cuda_graph) {
return use_cuda_graph;
}
static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
ggml_cuda_graph * cuda_graph = new ggml_cuda_graph();
cuda_graph->cgraph = cgraph;
if (should_use_cuda_graph(cuda_ctx, cgraph)) {
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
update_cuda_graph_properties(cuda_graph, cgraph);
}
return cuda_graph;
@ -3105,7 +3114,7 @@ static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backe
GGML_UNUSED(backend);
}
static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph* cgraph) {
static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
ggml_cuda_set_device(cuda_ctx->device);
@ -3114,15 +3123,28 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
cuda_graph->cgraph = cgraph;
if (!cuda_graph->graph) {
return;
bool use_cuda_graph = should_use_cuda_graph(cuda_ctx, cgraph);
bool cuda_graph_update_required = false;
// check if we are doing a graph update
if (cuda_graph->instance == nullptr && use_cuda_graph // no graph -> graph
|| cuda_graph->instance != nullptr && !use_cuda_graph // graph -> no graph
|| use_cuda_graph && is_cuda_graph_update_required(cuda_graph, cgraph)) { // graph property mismatch
cuda_graph->number_consecutive_updates++;
if (cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->disable_graph_due_to_too_many_updates = true;
use_cuda_graph = false;
} else {
cuda_graph_update_required = true;
}
} else {
cuda_graph->number_consecutive_updates = 0;
}
bool use_cuda_graph = true;
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
if (!use_cuda_graph) {
if (use_cuda_graph && cuda_graph_update_required) {
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
update_cuda_graph_properties(cuda_graph, cgraph);
} else if (!use_cuda_graph) {
if (cuda_graph->instance != nullptr) {
CUDA_CHECK(cudaGraphExecDestroy(cuda_graph->instance));
}
@ -3132,10 +3154,6 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
cuda_graph->instance = nullptr;
cuda_graph->graph = nullptr;
}
if (is_cuda_graph_update_required(cuda_graph, cgraph)) {
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
}
}
static enum ggml_status ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {

View File

@ -34,11 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_g
// CUDA_GRAPHS_DISABLED
((ncols > 65536) &&
((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.disable_due_to_gpu_arch)) ||
ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates)) ||
// CUDA_GRAPHS ENABLED
((ncols > 32768) &&
!((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.disable_due_to_gpu_arch))) {
ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates))) {
#else
(ncols > 65536)) {
#endif // USE_CUDA_GRAPH