Fix CUDA graph update logic.
This commit is contained in:
parent
4bbe5b1e59
commit
3afbd9f327
|
|
@ -940,6 +940,7 @@ struct ggml_cuda_graph {
|
|||
size_t num_nodes = 0;
|
||||
std::vector<cudaGraphNode_t> nodes;
|
||||
std::vector<cudaKernelNodeParams> params;
|
||||
int number_consecutive_updates = 0;
|
||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
||||
#endif
|
||||
};
|
||||
|
|
@ -954,7 +955,9 @@ struct ggml_backend_cuda_context {
|
|||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
bool cuda_graph_initialized = false;
|
||||
bool disable_due_to_gpu_arch = false;
|
||||
bool disable_graph_due_to_env = false;
|
||||
bool disable_graph_due_to_gpu_arch = false;
|
||||
bool disable_graph_due_to_too_many_updates = false;
|
||||
#endif
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
|
|
|
|||
|
|
@ -2642,8 +2642,8 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
|||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static bool check_node_graph_compatibility(const ggml_cgraph * cgraph,
|
||||
bool use_cuda_graph) {
|
||||
static bool check_node_graph_compatibility(const ggml_cgraph * cgraph) {
|
||||
bool use_cuda_graph = true;
|
||||
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
|
||||
|
|
@ -2753,8 +2753,14 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
|
||||
// Record the properties of every node in `cgraph` inside `cuda_graph`.
//
// The stored vector `cuda_graph->ggml_graph_properties` is resized to hold
// exactly `cgraph->n_nodes` entries, and each entry is then filled in by
// set_ggml_graph_node_properties() from the node at the same index.
//
// NOTE(review): presumably these stored entries are what a later
// is_cuda_graph_update_required() call compares against - confirm with that
// function's full body.
static void update_cuda_graph_properties(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
    const int node_count = cgraph->n_nodes;
    auto & stored_props = cuda_graph->ggml_graph_properties;

    // Make room for one property record per graph node before filling them in.
    stored_props.resize(node_count);

    int idx = 0;
    while (idx < node_count) {
        set_ggml_graph_node_properties(cgraph->nodes[idx], &stored_props[idx]);
        ++idx;
    }
}
|
||||
|
||||
static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
|
||||
bool cuda_graph_update_required = false;
|
||||
|
||||
if (cuda_graph->instance == nullptr) {
|
||||
|
|
@ -2768,7 +2774,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
|
|||
}
|
||||
|
||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||
// and store properties to allow this comparison for the next token
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool has_matching_properties = true;
|
||||
if (!cuda_graph_update_required) {
|
||||
|
|
@ -2777,7 +2782,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
|
|||
if (!has_matching_properties) {
|
||||
cuda_graph_update_required = true;
|
||||
}
|
||||
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]);
|
||||
}
|
||||
|
||||
return cuda_graph_update_required;
|
||||
|
|
@ -3057,22 +3061,14 @@ static void capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_g
|
|||
}
|
||||
}
|
||||
|
||||
static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph* cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||
|
||||
ggml_cuda_set_device(cuda_ctx->device);
|
||||
|
||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
|
||||
ggml_cuda_graph * cuda_graph = new ggml_cuda_graph();
|
||||
|
||||
cuda_graph->cgraph = cgraph;
|
||||
|
||||
static bool should_use_cuda_graph(ggml_backend_cuda_context * cuda_ctx, const struct ggml_cgraph * cgraph) {
|
||||
bool use_cuda_graph = true;
|
||||
|
||||
if (!cuda_ctx->cuda_graph_initialized) {
|
||||
cuda_ctx->disable_graph_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
|
||||
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
|
||||
cuda_ctx->disable_due_to_gpu_arch = true;
|
||||
cuda_ctx->disable_graph_due_to_gpu_arch = true;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
|
||||
#endif
|
||||
|
|
@ -3083,17 +3079,30 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
|
|||
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
|
||||
// or previous graph capture failure.
|
||||
// Also disable for multi-gpu for now. TO DO investigate
|
||||
if (disable_cuda_graphs_due_to_env
|
||||
|| cuda_ctx->disable_due_to_gpu_arch) {
|
||||
if (cuda_ctx->disable_graph_due_to_env || cuda_ctx->disable_graph_due_to_gpu_arch ||
|
||||
cuda_ctx->disable_graph_due_to_too_many_updates) {
|
||||
use_cuda_graph = false;
|
||||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
|
||||
use_cuda_graph = check_node_graph_compatibility(cgraph);
|
||||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
return use_cuda_graph;
|
||||
}
|
||||
|
||||
static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||
|
||||
ggml_cuda_set_device(cuda_ctx->device);
|
||||
|
||||
ggml_cuda_graph * cuda_graph = new ggml_cuda_graph();
|
||||
|
||||
cuda_graph->cgraph = cgraph;
|
||||
|
||||
if (should_use_cuda_graph(cuda_ctx, cgraph)) {
|
||||
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
|
||||
update_cuda_graph_properties(cuda_graph, cgraph);
|
||||
}
|
||||
|
||||
return cuda_graph;
|
||||
|
|
@ -3105,7 +3114,7 @@ static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backe
|
|||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph* cgraph) {
|
||||
static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
|
||||
ggml_cuda_set_device(cuda_ctx->device);
|
||||
|
|
@ -3114,15 +3123,28 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
|
|||
|
||||
cuda_graph->cgraph = cgraph;
|
||||
|
||||
if (!cuda_graph->graph) {
|
||||
return;
|
||||
bool use_cuda_graph = should_use_cuda_graph(cuda_ctx, cgraph);
|
||||
bool cuda_graph_update_required = false;
|
||||
|
||||
// check if we are doing a graph update
|
||||
if (cuda_graph->instance == nullptr && use_cuda_graph // no graph -> graph
|
||||
|| cuda_graph->instance != nullptr && !use_cuda_graph // graph -> no graph
|
||||
|| use_cuda_graph && is_cuda_graph_update_required(cuda_graph, cgraph)) { // graph property mismatch
|
||||
cuda_graph->number_consecutive_updates++;
|
||||
if (cuda_graph->number_consecutive_updates >= 4) {
|
||||
cuda_ctx->disable_graph_due_to_too_many_updates = true;
|
||||
use_cuda_graph = false;
|
||||
} else {
|
||||
cuda_graph_update_required = true;
|
||||
}
|
||||
} else {
|
||||
cuda_graph->number_consecutive_updates = 0;
|
||||
}
|
||||
|
||||
bool use_cuda_graph = true;
|
||||
|
||||
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
|
||||
|
||||
if (!use_cuda_graph) {
|
||||
if (use_cuda_graph && cuda_graph_update_required) {
|
||||
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
|
||||
update_cuda_graph_properties(cuda_graph, cgraph);
|
||||
} else if (!use_cuda_graph) {
|
||||
if (cuda_graph->instance != nullptr) {
|
||||
CUDA_CHECK(cudaGraphExecDestroy(cuda_graph->instance));
|
||||
}
|
||||
|
|
@ -3132,10 +3154,6 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
|
|||
cuda_graph->instance = nullptr;
|
||||
cuda_graph->graph = nullptr;
|
||||
}
|
||||
|
||||
if (is_cuda_graph_update_required(cuda_graph, cgraph)) {
|
||||
capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
|
||||
}
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
|
||||
|
|
|
|||
|
|
@ -34,11 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_g
|
|||
// CUDA_GRAPHS_DISABLED
|
||||
((ncols > 65536) &&
|
||||
((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.disable_due_to_gpu_arch)) ||
|
||||
ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates)) ||
|
||||
// CUDA_GRAPHS ENABLED
|
||||
((ncols > 32768) &&
|
||||
!((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.disable_due_to_gpu_arch))) {
|
||||
ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates))) {
|
||||
#else
|
||||
(ncols > 65536)) {
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
|
|
|||
Loading…
Reference in New Issue