vulkan: sort graph to allow more parallel execution (#15850)

* vulkan: sort graph to allow more parallel execution

Add a backend proc to allow the backend to modify the graph. The
vulkan implementation looks at which nodes depend on each other
and greedily reorders them to group together nodes that don't
depend on each other. It only reorders the nodes; it doesn't change
the contents of any of them.

With #15489, this reduces the number of synchronizations needed.

* call optimize_graph per-split
Jeff Bolz 2025-09-08 13:10:07 -05:00 committed by GitHub
parent 0a16bf52e6
commit e68aa10d8f
13 changed files with 154 additions and 0 deletions
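
For intuition, here is a minimal standalone sketch of the grouping idea described in the commit message: walk a topologically ordered node list and greedily pull forward every node whose inputs have already been executed, so each group contains only mutually independent nodes that could run without a barrier between them. The ToyNode/deps types and the node names are hypothetical, not the ggml API, and the real pass in the ggml-vulkan.cpp hunk further down is more conservative (a 20-node lookahead window, an exception for the RMS_NORM + MUL fusion pattern, and a separate pass for view nodes).

// Illustrative sketch only (hypothetical types, not the ggml API).
#include <cstdio>
#include <vector>

struct ToyNode {
    const char      *name;
    std::vector<int> deps; // indices of earlier nodes whose outputs this node reads
};

int main() {
    // Q/K/V projections read the same input but not each other's outputs,
    // so they can be grouped; attn depends on all three branches.
    std::vector<ToyNode> nodes = {
        {"matmul_q", {}}, {"norm_q", {0}},
        {"matmul_k", {}}, {"norm_k", {2}},
        {"matmul_v", {}}, {"attn", {1, 3, 4}},
    };

    std::vector<bool> done(nodes.size(), false);
    size_t first_undone = 0;
    while (first_undone < nodes.size()) {
        // collect every not-yet-executed node whose dependencies have all executed
        std::vector<size_t> group;
        for (size_t j = first_undone; j < nodes.size(); ++j) {
            if (done[j]) {
                continue;
            }
            bool ready = true;
            for (int d : nodes[j].deps) {
                if (!done[d]) { ready = false; break; }
            }
            if (ready) {
                group.push_back(j);
            }
        }
        printf("group:");
        for (size_t j : group) {
            printf(" %s", nodes[j].name);
            done[j] = true;
        }
        printf("\n");
        while (first_undone < nodes.size() && done[first_undone]) {
            first_undone++;
        }
    }
    return 0;
}

Compiled as a normal C++ program, this prints three groups: the three mat-muls, then the two norms, then attn. That kind of grouping is what lets the backend batch work between synchronization points.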

@@ -114,6 +114,9 @@ extern "C" {
void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
// wait for an event on a different stream
void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
// (optional) sort/optimize the nodes in the graph
void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
};
struct ggml_backend {

@@ -463,6 +463,13 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
backend->iface.event_wait(backend, event);
}
static void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_ASSERT(backend);
if (backend->iface.optimize_graph != NULL) {
backend->iface.optimize_graph(backend, cgraph);
}
}
// Backend device
const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
@@ -1298,6 +1305,10 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
struct ggml_backend_sched_split * split = &sched->splits[i];
split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
// Optimize this split of the graph. This needs to happen before we make graph_copy,
// so they are in sync.
ggml_backend_optimize_graph(sched->backends[split->backend_id], &split->graph);
// add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
for (int j = 0; j < split->n_inputs; j++) {
assert(graph_copy->size > (graph_copy->n_nodes + 1));

@@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = {
/* .graph_compute = */ ggml_backend_blas_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_blas_guid(void) {

@@ -2690,6 +2690,7 @@ static const ggml_backend_i ggml_backend_cann_interface = {
/* .graph_compute = */ ggml_backend_cann_graph_compute,
/* .event_record = */ ggml_backend_cann_event_record,
/* .event_wait = */ ggml_backend_cann_event_wait,
/* .optimize_graph = */ NULL,
};
/**

@@ -190,6 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
/* .graph_compute = */ ggml_backend_cpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_cpu_guid(void) {

@@ -3135,6 +3135,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
/* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_cuda_guid() {

@@ -6275,6 +6275,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
/* .graph_compute = */ ggml_backend_metal_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_metal_guid(void) {

@@ -2838,6 +2838,7 @@ static ggml_backend_i ggml_backend_opencl_i = {
/* .graph_compute = */ ggml_backend_opencl_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
};
ggml_backend_t ggml_backend_opencl_init(void) {

@@ -795,6 +795,7 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
};
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {

@@ -4063,6 +4063,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
/* .graph_compute = */ ggml_backend_sycl_graph_compute,
/* .event_record = */ ggml_backend_sycl_event_record,
/* .event_wait = */ ggml_backend_sycl_event_wait,
/* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_sycl_guid() {

@@ -583,6 +583,7 @@ struct vk_device_struct {
bool disable_fusion;
bool disable_host_visible_vidmem;
bool allow_sysmem_fallback;
bool disable_optimize_graph;
#ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr<vk_memory_logger> memory_logger;
@@ -3592,6 +3593,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH");
device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr;
bool fp16_storage = false;
bool fp16_compute = false;
bool maintenance4_support = false;
@@ -11853,6 +11857,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
UNUSED(backend);
}
// Sort the graph for improved parallelism.
static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph)
{
VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
if (ctx->device->disable_optimize_graph) {
return;
}
auto const &is_empty = [](ggml_tensor * node) -> bool {
return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
};
auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool {
for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
if (dst->src[s] == src) {
return true;
}
}
// implicit dependency if they view the same tensor
const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;
const ggml_tensor *src2 = src->view_src ? src->view_src : src;
if (dst2 == src2) {
return true;
}
return false;
};
// This function tries to reorder the graph to allow nodes to run in parallel.
// This helps with small batches, but for large batches it's a slowdown, probably
// due to cache contention. So only reorder if the majority of nodes have few rows.
int num_small_nodes = 0;
int num_counted_nodes = 0;
for (int i = 0; i < graph->n_nodes; ++i) {
if (!is_empty(graph->nodes[i]) &&
graph->nodes[i]->op != GGML_OP_SET_ROWS) {
if (ggml_nrows(graph->nodes[i]) <= 8) {
num_small_nodes++;
}
num_counted_nodes++;
}
}
if (num_small_nodes < num_counted_nodes / 2) {
return;
}
std::vector<ggml_tensor *> new_order;
std::vector<bool> used(graph->n_nodes, false);
int first_unused = 0;
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
// First, grab the next unused node.
current_set.push_back(first_unused);
// Loop through the next N nodes. Grab any that don't depend on other nodes that
// haven't already been run. Nodes that have already been run have used[i] set
// to true. Allow nodes that depend on the previous node if it's a fusion pattern
// that we support (e.g. RMS_NORM + MUL).
// This first pass only grabs "real" (non-view) nodes. The second pass grabs view nodes.
// The goal is to not interleave real and view nodes in a way that breaks fusion.
const int NUM_TO_CHECK = 20;
for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
if (used[j]) {
continue;
}
if (is_empty(graph->nodes[j])) {
continue;
}
bool ok = true;
for (int c = first_unused; c < j; ++c) {
if (!used[c] &&
is_src_of(graph->nodes[j], graph->nodes[c]) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) {
ok = false;
break;
}
}
if (ok) {
current_set.push_back(j);
}
}
// Second pass grabs view nodes.
// Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).
if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
if (used[j]) {
continue;
}
if (!is_empty(graph->nodes[j])) {
continue;
}
bool ok = true;
for (int c = first_unused; c < j; ++c) {
bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();
// skip views whose srcs haven't been processed.
if (!used[c] &&
is_src_of(graph->nodes[j], graph->nodes[c]) &&
!c_in_current_set) {
ok = false;
break;
}
}
if (ok) {
current_set.push_back(j);
}
}
}
// Push the current set into new_order
for (auto c : current_set) {
new_order.push_back(graph->nodes[c]);
used[c] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
}
// Replace the graph with the new order.
for (int i = 0; i < graph->n_nodes; ++i) {
graph->nodes[i] = new_order[i];
}
}
// TODO: enable async and synchronize
static ggml_backend_i ggml_backend_vk_interface = {
/* .get_name = */ ggml_backend_vk_name,
@@ -11868,6 +11997,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
/* .graph_compute = */ ggml_backend_vk_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ ggml_vk_optimize_graph,
};
static ggml_guid_t ggml_backend_vk_guid() {

@@ -665,6 +665,7 @@ static ggml_backend_i ggml_backend_webgpu_i = {
/* .graph_compute = */ ggml_backend_webgpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
};
/* End GGML Backend Interface */

@@ -586,6 +586,7 @@ static ggml_backend_i ggml_backend_zdnn_i = {
/* .graph_compute = */ ggml_backend_zdnn_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_zdnn_guid(void) {