diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 4eb90cbbb3..84c3954337 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -940,6 +940,7 @@ struct ggml_cuda_graph { size_t num_nodes = 0; std::vector nodes; std::vector params; + int number_consecutive_updates = 0; std::vector ggml_graph_properties; #endif }; @@ -954,7 +955,9 @@ struct ggml_backend_cuda_context { #ifdef USE_CUDA_GRAPH bool cuda_graph_initialized = false; - bool disable_due_to_gpu_arch = false; + bool disable_graph_due_to_env = false; + bool disable_graph_due_to_gpu_arch = false; + bool disable_graph_due_to_too_many_updates = false; #endif explicit ggml_backend_cuda_context(int device) : diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9f6a0500aa..d93c86818c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2642,8 +2642,8 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility(const ggml_cgraph * cgraph, - bool use_cuda_graph) { +static bool check_node_graph_compatibility(const ggml_cgraph * cgraph) { + bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph @@ -2753,8 +2753,14 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return true; } -static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) { +static void update_cuda_graph_properties(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) { + cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]); + } +} +static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) { bool cuda_graph_update_required = false; if (cuda_graph->instance == nullptr) { @@ -2768,7 +2774,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg } // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { bool has_matching_properties = true; if (!cuda_graph_update_required) { @@ -2777,7 +2782,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg if (!has_matching_properties) { cuda_graph_update_required = true; } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]); } return cuda_graph_update_required; @@ -3057,22 +3061,14 @@ static void capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_g } } -static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph* cgraph) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - - ggml_cuda_set_device(cuda_ctx->device); - - static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - - ggml_cuda_graph * cuda_graph = new ggml_cuda_graph(); - - cuda_graph->cgraph = cgraph; - +static bool should_use_cuda_graph(ggml_backend_cuda_context * cuda_ctx, const struct ggml_cgraph * cgraph) { bool use_cuda_graph = true; if (!cuda_ctx->cuda_graph_initialized) { + cuda_ctx->disable_graph_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { - cuda_ctx->disable_due_to_gpu_arch = true; + cuda_ctx->disable_graph_due_to_gpu_arch = true; #ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); #endif @@ -3083,17 +3079,30 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, // or previous graph capture failure. // Also disable for multi-gpu for now. TO DO investigate - if (disable_cuda_graphs_due_to_env - || cuda_ctx->disable_due_to_gpu_arch) { + if (cuda_ctx->disable_graph_due_to_env || cuda_ctx->disable_graph_due_to_gpu_arch || + cuda_ctx->disable_graph_due_to_too_many_updates) { use_cuda_graph = false; } if (use_cuda_graph) { - use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); + use_cuda_graph = check_node_graph_compatibility(cgraph); } - if (use_cuda_graph) { + return use_cuda_graph; +} + +static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + ggml_cuda_set_device(cuda_ctx->device); + + ggml_cuda_graph * cuda_graph = new ggml_cuda_graph(); + + cuda_graph->cgraph = cgraph; + + if (should_use_cuda_graph(cuda_ctx, cgraph)) { capture_cuda_graph(cuda_ctx, cuda_graph, cgraph); + update_cuda_graph_properties(cuda_graph, cgraph); } return cuda_graph; @@ -3105,7 +3114,7 @@ static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backe GGML_UNUSED(backend); } -static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph* cgraph) { +static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_cuda_set_device(cuda_ctx->device); @@ -3114,15 +3123,28 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac cuda_graph->cgraph = cgraph; - if (!cuda_graph->graph) { - return; + bool use_cuda_graph = should_use_cuda_graph(cuda_ctx, cgraph); + bool cuda_graph_update_required = false; + + // check if we are doing a graph update + if (cuda_graph->instance == nullptr && use_cuda_graph // no graph -> graph + || cuda_graph->instance != nullptr && !use_cuda_graph // graph -> no graph + || use_cuda_graph && is_cuda_graph_update_required(cuda_graph, cgraph)) { // graph property mismatch + cuda_graph->number_consecutive_updates++; + if (cuda_graph->number_consecutive_updates >= 4) { + cuda_ctx->disable_graph_due_to_too_many_updates = true; + use_cuda_graph = false; + } else { + cuda_graph_update_required = true; + } + } else { + cuda_graph->number_consecutive_updates = 0; } - bool use_cuda_graph = true; - - use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); - - if (!use_cuda_graph) { + if (use_cuda_graph && cuda_graph_update_required) { + capture_cuda_graph(cuda_ctx, cuda_graph, cgraph); + update_cuda_graph_properties(cuda_graph, cgraph); + } else if (!use_cuda_graph) { if (cuda_graph->instance != nullptr) { CUDA_CHECK(cudaGraphExecDestroy(cuda_graph->instance)); } @@ -3132,10 +3154,6 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac cuda_graph->instance = nullptr; cuda_graph->graph = nullptr; } - - if (is_cuda_graph_update_required(cuda_graph, cgraph)) { - capture_cuda_graph(cuda_ctx, cuda_graph, cgraph); - } } static enum ggml_status ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 67bbb0b62d..6e91f236b9 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -34,11 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_g // CUDA_GRAPHS_DISABLED ((ncols > 65536) && ((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.disable_due_to_gpu_arch)) || + ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates)) || // CUDA_GRAPHS ENABLED ((ncols > 32768) && !((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.disable_due_to_gpu_arch))) { + ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH