This commit is contained in:
Xiangyan Sun 2025-12-16 23:06:53 -05:00 committed by GitHub
commit 732004e64e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 658 additions and 492 deletions

View File

@ -93,8 +93,11 @@ extern "C" {
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
GGML_API bool ggml_backend_supports_graph_plan(ggml_backend_t backend);
GGML_API bool ggml_backend_supports_graph_plan_update(ggml_backend_t backend);
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);
GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);

View File

@ -332,6 +332,18 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
backend->iface.synchronize(backend);
}
bool ggml_backend_supports_graph_plan(ggml_backend_t backend) {
GGML_ASSERT(backend);
return (bool) backend->iface.graph_plan_create;
}
bool ggml_backend_supports_graph_plan_update(ggml_backend_t backend) {
GGML_ASSERT(backend);
return (bool) backend->iface.graph_plan_update;
}
ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_create != NULL);
@ -346,6 +358,13 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
backend->iface.graph_plan_free(backend, plan);
}
void ggml_backend_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph* cgraph) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_update != NULL);
backend->iface.graph_plan_update(backend, plan, cgraph);
}
enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
@ -680,6 +699,11 @@ struct ggml_backend_sched_split {
struct ggml_cgraph graph;
};
struct ggml_backend_sched_plan {
int backend_id;
ggml_backend_graph_plan_t plan;
};
struct ggml_backend_sched {
bool is_reset; // true if the scheduler has been reset since the last graph split
bool is_alloc;
@ -709,6 +733,12 @@ struct ggml_backend_sched {
int n_splits;
int splits_capacity;
// graph plans
struct ggml_backend_sched_plan * plans;
int n_plans;
int plans_capacity;
bool plan_needs_update;
// pipeline parallelism support
int n_copies;
int cur_copy;
@ -919,6 +949,16 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
}
}
static void ggml_backend_sched_free_plans(ggml_backend_sched_t sched) {
for (int i = 0; i < sched->n_plans; i++) {
ggml_backend_t backend = sched->backends[sched->plans[i].backend_id];
if (ggml_backend_supports_graph_plan(backend)) {
ggml_backend_graph_plan_free(backend, sched->plans[i].plan);
}
}
sched->n_plans = 0;
}
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset splits
@ -1386,6 +1426,7 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
assert(graph_copy->size > graph_copy->n_leafs);
graph_copy->leafs[graph_copy->n_leafs++] = leaf;
}
sched->plan_needs_update = true;
}
static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
@ -1440,6 +1481,62 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
return true;
}
static void ggml_backend_sched_update_plans(ggml_backend_sched_t sched) {
// create graph plans
if (sched->plan_needs_update) {
bool create_new_plans;
if (sched->n_plans == sched->n_splits) {
create_new_plans = false;
for (int i = 0; i < sched->n_splits; i++) {
if (sched->splits[i].backend_id != sched->plans[i].backend_id) {
create_new_plans = true;
break;
}
}
} else {
create_new_plans = true;
}
if (create_new_plans) {
// free previous and recreate new plans
ggml_backend_sched_free_plans(sched);
if (sched->plans_capacity < sched->n_splits) {
while (sched->plans_capacity < sched->n_splits) {
sched->plans_capacity *= 2;
}
sched->plans = (ggml_backend_sched_plan *) realloc(
sched->plans, sched->plans_capacity * sizeof(struct ggml_backend_sched_plan));
GGML_ASSERT(sched->plans);
}
sched->n_plans = sched->n_splits;
for (int i = 0; i < sched->n_splits; i++) {
ggml_backend_t backend = sched->backends[sched->splits[i].backend_id];
sched->plans[i].backend_id = sched->splits[i].backend_id;
if (ggml_backend_supports_graph_plan(backend)) {
sched->plans[i].plan = ggml_backend_graph_plan_create(backend, &sched->splits[i].graph);
} else {
sched->plans[i].plan = nullptr;
}
}
} else {
// update existing plans
for (int i = 0; i < sched->n_splits; i++) {
ggml_backend_t backend = sched->backends[sched->splits[i].backend_id];
if (ggml_backend_supports_graph_plan(backend)) {
if (ggml_backend_supports_graph_plan_update(backend)) {
ggml_backend_graph_plan_update(backend, sched->plans[i].plan, &sched->splits[i].graph);
} else {
ggml_backend_graph_plan_free(backend, sched->plans[i].plan);
sched->plans[i].plan = ggml_backend_graph_plan_create(backend, &sched->splits[i].graph);
}
} else {
sched->plans[i].plan = nullptr;
}
}
}
sched->plan_needs_update = false;
}
}
static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
GGML_ASSERT(sched);
struct ggml_backend_sched_split * splits = sched->splits;
@ -1448,6 +1545,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
std::vector<int32_t> ids;
std::vector<ggml_bitset_t> used_ids;
ggml_backend_sched_update_plans(sched);
for (int split_id = 0; split_id < sched->n_splits; split_id++) {
struct ggml_backend_sched_split * split = &splits[split_id];
int split_backend_id = split->backend_id;
@ -1577,7 +1676,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
if (!sched->callback_eval) {
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
enum ggml_status ec;
if (ggml_backend_supports_graph_plan(split_backend) && sched->plans[split_id].plan) {
ec = ggml_backend_graph_plan_compute(split_backend, sched->plans[split_id].plan);
} else {
ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
}
if (ec != GGML_STATUS_SUCCESS) {
return ec;
}
@ -1675,6 +1779,10 @@ ggml_backend_sched_t ggml_backend_sched_new(
sched->splits = (ggml_backend_sched_split *) calloc(initial_splits_capacity, sizeof(sched->splits[0]));
sched->splits_capacity = initial_splits_capacity;
const int initial_plans_capacity = 16;
sched->plans = (ggml_backend_sched_plan *) calloc(initial_plans_capacity, sizeof(sched->plans[0]));
sched->plans_capacity = initial_plans_capacity;
for (int b = 0; b < n_backends; b++) {
sched->backends[b] = backends[b];
sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
@ -1708,6 +1816,8 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
ggml_free(sched->ctx);
ggml_hash_set_free(&sched->hash_set);
free(sched->splits);
ggml_backend_sched_free_plans(sched);
free(sched->plans);
free(sched->hv_tensor_backend_ids);
free(sched->hv_tensor_copies);
free(sched->node_backend_ids);

View File

@ -1021,13 +1021,12 @@ struct ggml_cuda_graph {
}
cudaGraph_t graph = nullptr;
cudaGraphExec_t instance = nullptr;
const ggml_cgraph * cgraph;
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
std::vector<cudaKernelNodeParams> params;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0;
int number_consecutive_computes = 0;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
#endif
};
@ -1191,7 +1190,12 @@ struct ggml_backend_cuda_context {
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
std::unique_ptr<ggml_cuda_graph> cuda_graph;
#ifdef USE_CUDA_GRAPH
bool cuda_graph_initialized = false;
bool disable_graph_due_to_env = false;
bool disable_graph_due_to_gpu_arch = false;
bool disable_graph_due_to_too_many_updates = false;
#endif
int curr_stream_no = 0;

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,7 @@ template <typename T> __global__ void divide_by_count(T * result, size_t count)
*result /= static_cast<T>(count);
}
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
@ -33,14 +33,12 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
#ifdef USE_CUDA_GRAPH
// CUDA_GRAPHS_DISABLED
((ncols > 65536) &&
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates)) ||
// CUDA_GRAPHS ENABLED
((ncols > 32768) &&
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
!((cuda_graph && cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.disable_graph_due_to_env || ctx.disable_graph_due_to_gpu_arch || ctx.disable_graph_due_to_too_many_updates))) {
#else
(ncols > 65536)) {
#endif // USE_CUDA_GRAPH

View File

@ -1,3 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_cuda_graph * cuda_graph, ggml_tensor * dst);