Fix CUDA graph update logic.
commit 3afbd9f327 (parent 4bbe5b1e59)
@@ -940,6 +940,7 @@ struct ggml_cuda_graph {
     size_t num_nodes = 0;
     std::vector<cudaGraphNode_t> nodes;
     std::vector<cudaKernelNodeParams> params;
+    int number_consecutive_updates = 0;
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
 #endif
 };
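Note: the new number_consecutive_updates counter drives a simple hysteresis. If the captured graph has to be rebuilt on too many consecutive evaluations, graph capture is likely costing more than it saves and gets switched off for the context. Below is a minimal, self-contained sketch of that policy; the threshold of 4 matches the plan_update hunk further down, but the struct and function names here are illustrative, not the backend's.

#include <cstdio>

// Illustrative only: a stripped-down model of the consecutive-update hysteresis.
struct graph_update_tracker {
    int  number_consecutive_updates       = 0;
    bool disabled_due_to_too_many_updates = false;

    // Call once per evaluation; 'needs_rebuild' means the captured graph no longer matches.
    bool should_keep_using_graphs(bool needs_rebuild, int threshold = 4) {
        if (needs_rebuild) {
            if (++number_consecutive_updates >= threshold) {
                disabled_due_to_too_many_updates = true; // give up on graphs from now on
            }
        } else {
            number_consecutive_updates = 0; // a stable evaluation resets the streak
        }
        return !disabled_due_to_too_many_updates;
    }
};

int main() {
    graph_update_tracker t;
    // Three rebuilds in a row are tolerated; the fourth disables graph use.
    for (int i = 0; i < 5; i++) {
        printf("eval %d: use graphs = %d\n", i, t.should_keep_using_graphs(/*needs_rebuild=*/true));
    }
}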
@@ -954,7 +955,9 @@ struct ggml_backend_cuda_context {

 #ifdef USE_CUDA_GRAPH
     bool cuda_graph_initialized = false;
-    bool disable_due_to_gpu_arch = false;
+    bool disable_graph_due_to_env = false;
+    bool disable_graph_due_to_gpu_arch = false;
+    bool disable_graph_due_to_too_many_updates = false;
 #endif

     explicit ggml_backend_cuda_context(int device) :
@@ -2642,8 +2642,8 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }

 #ifdef USE_CUDA_GRAPH
-static bool check_node_graph_compatibility(const ggml_cgraph * cgraph,
-                                           bool use_cuda_graph) {
+static bool check_node_graph_compatibility(const ggml_cgraph * cgraph) {
+    bool use_cuda_graph = true;

     // Loop over nodes in GGML graph to obtain info needed for CUDA graph

@@ -2753,8 +2753,14 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     return true;
 }

-static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
+static void update_cuda_graph_properties(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
+    cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]);
+    }
+}
+
+static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const ggml_cgraph * cgraph) {
     bool cuda_graph_update_required = false;

     if (cuda_graph->instance == nullptr) {
@@ -2768,7 +2774,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
     }

     // Loop over nodes in GGML graph to determine if CUDA graph update is required
-    // and store properties to allow this comparison for the next token
     for (int i = 0; i < cgraph->n_nodes; i++) {
         bool has_matching_properties = true;
         if (!cuda_graph_update_required) {
@@ -2777,7 +2782,6 @@ static bool is_cuda_graph_update_required(ggml_cuda_graph * cuda_graph, const gg
         if (!has_matching_properties) {
             cuda_graph_update_required = true;
         }
-        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph->ggml_graph_properties[i]);
     }

     return cuda_graph_update_required;
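Note: these three hunks make is_cuda_graph_update_required a side-effect-free comparison. Storing the node properties is split out into update_cuda_graph_properties and only happens after a capture actually takes place, so a rebuild that is detected as required but then skipped (for example because graphs just got disabled) is still flagged again on the next evaluation. A minimal model of the two roles, using stand-in types rather than ggml_graph_node_properties:

#include <vector>

// Stand-in per-node properties, only to show compare vs. snapshot as separate steps.
struct node_props { int op; long ne0; };

static bool props_match(const node_props & a, const node_props & b) {
    return a.op == b.op && a.ne0 == b.ne0;
}

// Compare only, no side effects (mirrors is_cuda_graph_update_required after this change).
static bool update_required(const std::vector<node_props> & stored,
                            const std::vector<node_props> & current) {
    if (stored.size() != current.size()) {
        return true;
    }
    for (size_t i = 0; i < stored.size(); i++) {
        if (!props_match(stored[i], current[i])) {
            return true;
        }
    }
    return false;
}

// Snapshot only (mirrors update_cuda_graph_properties), called right after a successful capture.
static void snapshot_properties(std::vector<node_props> & stored,
                                const std::vector<node_props> & current) {
    stored = current;
}

int main() {
    std::vector<node_props> stored;
    std::vector<node_props> current = { {1, 512}, {2, 512} };
    if (update_required(stored, current)) {  // first evaluation: nothing stored yet -> rebuild
        snapshot_properties(stored, current);
    }
    return update_required(stored, current); // now 0: properties match, no rebuild needed
}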
@@ -3057,22 +3061,14 @@ static void capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_g
     }
 }

-static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph* cgraph) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    ggml_cuda_set_device(cuda_ctx->device);
-
-    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
-
-    ggml_cuda_graph * cuda_graph = new ggml_cuda_graph();
-
-    cuda_graph->cgraph = cgraph;
-
+static bool should_use_cuda_graph(ggml_backend_cuda_context * cuda_ctx, const struct ggml_cgraph * cgraph) {
     bool use_cuda_graph = true;

     if (!cuda_ctx->cuda_graph_initialized) {
+        cuda_ctx->disable_graph_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            cuda_ctx->disable_due_to_gpu_arch = true;
+            cuda_ctx->disable_graph_due_to_gpu_arch = true;
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
 #endif
@@ -3083,17 +3079,30 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
     // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
     // or previous graph capture failure.
    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->disable_due_to_gpu_arch) {
+    if (cuda_ctx->disable_graph_due_to_env || cuda_ctx->disable_graph_due_to_gpu_arch ||
+        cuda_ctx->disable_graph_due_to_too_many_updates) {
         use_cuda_graph = false;
     }

     if (use_cuda_graph) {
-        use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
+        use_cuda_graph = check_node_graph_compatibility(cgraph);
     }

-    if (use_cuda_graph) {
+    return use_cuda_graph;
+}
+
+static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
+    ggml_cuda_graph * cuda_graph = new ggml_cuda_graph();
+
+    cuda_graph->cgraph = cgraph;
+
+    if (should_use_cuda_graph(cuda_ctx, cgraph)) {
         capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
+        update_cuda_graph_properties(cuda_graph, cgraph);
     }

     return cuda_graph;
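Note: the gating logic is lifted out of ggml_backend_cuda_graph_plan_create into should_use_cuda_graph so that plan creation and plan update (next hunks) share exactly the same decision. The one-time checks also move from a function-local static onto the context, which lets other code, such as the mean kernel heuristic at the end of this diff, read them. The decision boils down to a predicate over the three disable flags plus the per-graph node compatibility check; a compact, illustrative restatement, not the backend code itself:

// Illustrative restatement of the gate in should_use_cuda_graph.
struct graph_disable_flags {
    bool due_to_env              = false; // GGML_CUDA_DISABLE_GRAPHS was set
    bool due_to_gpu_arch         = false; // device older than Ampere
    bool due_to_too_many_updates = false; // hysteresis tripped in plan_update
};

static bool graphs_allowed(const graph_disable_flags & f, bool nodes_compatible) {
    if (f.due_to_env || f.due_to_gpu_arch || f.due_to_too_many_updates) {
        return false;
    }
    return nodes_compatible; // result of check_node_graph_compatibility(cgraph)
}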
@@ -3114,15 +3123,28 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac

     cuda_graph->cgraph = cgraph;

-    if (!cuda_graph->graph) {
-        return;
+    bool use_cuda_graph = should_use_cuda_graph(cuda_ctx, cgraph);
+    bool cuda_graph_update_required = false;
+
+    // check if we are doing a graph update
+    if (cuda_graph->instance == nullptr && use_cuda_graph                                 // no graph -> graph
+        || cuda_graph->instance != nullptr && !use_cuda_graph                             // graph -> no graph
+        || use_cuda_graph && is_cuda_graph_update_required(cuda_graph, cgraph)) {         // graph property mismatch
+        cuda_graph->number_consecutive_updates++;
+        if (cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->disable_graph_due_to_too_many_updates = true;
+            use_cuda_graph = false;
+        } else {
+            cuda_graph_update_required = true;
+        }
+    } else {
+        cuda_graph->number_consecutive_updates = 0;
     }

-    bool use_cuda_graph = true;
-
-    use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
-
-    if (!use_cuda_graph) {
+    if (use_cuda_graph && cuda_graph_update_required) {
+        capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
+        update_cuda_graph_properties(cuda_graph, cgraph);
+    } else if (!use_cuda_graph) {
         if (cuda_graph->instance != nullptr) {
             CUDA_CHECK(cudaGraphExecDestroy(cuda_graph->instance));
         }
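Note: plan_update now treats three situations as a graph update: capturing a graph where none existed, dropping a graph that is no longer usable, and a property mismatch against the stored snapshot. Every such transition bumps number_consecutive_updates; four in a row trip disable_graph_due_to_too_many_updates, while a stable evaluation resets the streak. The trigger condition on its own, as a small pure function (illustrative, not the backend code):

#include <cassert>

// Illustrative only: the three conditions that count as a "graph update" in plan_update.
static bool is_graph_transition(bool has_instance, bool use_cuda_graph, bool props_changed) {
    return (!has_instance &&  use_cuda_graph)          // no graph -> graph
        || ( has_instance && !use_cuda_graph)          // graph -> no graph
        || ( use_cuda_graph && props_changed);         // captured graph no longer matches the cgraph
}

int main() {
    assert( is_graph_transition(false, true,  false)); // first capture
    assert( is_graph_transition(true,  false, false)); // graphs just got disabled
    assert( is_graph_transition(true,  true,  true )); // shapes/params changed -> re-capture
    assert(!is_graph_transition(true,  true,  false)); // steady state: reuse the instance
    return 0;
}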
@@ -3132,10 +3154,6 @@ static void ggml_backend_cuda_graph_plan_update(ggml_backend_t backend, ggml_bac
         cuda_graph->instance = nullptr;
         cuda_graph->graph = nullptr;
     }
-
-    if (is_cuda_graph_update_required(cuda_graph, cgraph)) {
-        capture_cuda_graph(cuda_ctx, cuda_graph, cgraph);
-    }
 }

 static enum ggml_status ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
@@ -34,11 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_g
         // CUDA_GRAPHS_DISABLED
         ((ncols > 65536) &&
          ((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-          ctx.disable_due_to_gpu_arch)) ||
+          ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates)) ||
         // CUDA_GRAPHS ENABLED
         ((ncols > 32768) &&
          !((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-           ctx.disable_due_to_gpu_arch))) {
+           ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates))) {
 #else
         (ncols > 65536)) {
 #endif // USE_CUDA_GRAPH
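Note: the launch heuristic in ggml_cuda_op_mean now has to spell out all three disable reasons, and the same disjunction already appears twice in this one condition. If it keeps spreading, a small helper on the context would keep call sites readable; this is a hypothetical follow-up, not part of this commit, and it assumes the common ggml-cuda header that declares ggml_backend_cuda_context is in scope:

// Hypothetical follow-up, not part of this commit: fold the repeated disjunction into one query.
// Assumes ggml_backend_cuda_context (with the three flags added above) is visible here.
static inline bool ggml_cuda_graphs_disabled(const ggml_backend_cuda_context & ctx) {
    return ctx.disable_graph_due_to_env ||
           ctx.disable_graph_due_to_gpu_arch ||
           ctx.disable_graph_due_to_too_many_updates;
}

Both branches of the heuristic above could then end in a single ggml_cuda_graphs_disabled(ctx) term instead of repeating the three flags.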