From c7af376c298b7d09c280233548668ba6fcc17deb Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 30 Nov 2025 08:17:55 +0800 Subject: [PATCH] CUDA: add stream-based concurrency (#16991) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: add stream-based concurrency * HIP: fix hipStreamWaitEvent define and nodiscard warnings * ggml-cuda: fix fusion inside stream * ggml-cuda: fix bug w.r.t first stream launch * ggml-cuda: format * ggml-cuda: improve assert message * ggml-cuda: use lambda instead of duplicating code * ggml-cuda: add some more comments * ggml-cuda: add more detailed comments about concurrency * ggml-cuda: rename + remove unused var * ggml-cuda: fix condition for stream launch * ggml-cuda: address review comments, add destructor * common.cuh: add is_valid for concurrent events * common.cuh: make comment better * update comment Co-authored-by: Johannes Gäßler * update comment Co-authored-by: Johannes Gäßler * common.cuh: fix lower_bound condition + remove join_node data from write_ranges * ggml-cuda: fix overlap condition + shadowing parameter --------- Co-authored-by: Carl Philipp Klemm Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 170 ++++++++++++++++- ggml/src/ggml-cuda/ggml-cuda.cu | 311 ++++++++++++++++++++++++++++++- ggml/src/ggml-cuda/vendors/hip.h | 2 +- 3 files changed, 469 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 0b10e5f6ae..611341deb0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -21,10 +21,12 @@ #include "ggml-common.h" #include +#include #include #include #include #include +#include #include #if defined(GGML_USE_HIP) @@ -980,6 +982,154 @@ struct ggml_cuda_graph { #endif }; +struct ggml_cuda_concurrent_event { + std::vector join_events; + cudaEvent_t fork_event = nullptr; + + int n_streams = 0; + std::unordered_map stream_mapping; + + const ggml_tensor * join_node; + + ggml_cuda_concurrent_event() = default; + + ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete; + ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete; + + explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) { + join_events.resize(n_streams); + + for (size_t i = 0; i < join_events.size(); ++i) { + CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming)); + } + + CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming)); + } + + ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept + : join_events(std::move(other.join_events)) + , fork_event(other.fork_event) + , n_streams(other.n_streams) + , stream_mapping(std::move(other.stream_mapping)) + , join_node(other.join_node) { + other.fork_event = nullptr; + } + + // 1. check if any branches write to overlapping memory ranges (except the join node) + // 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event + // we assume all nodes have the same buffer + bool is_valid() const { + std::vector>> write_ranges; + write_ranges.resize(n_streams); + + // get join_node's memory range to exclude from overlap checking. + // multiple nodes can use join_node's buffer; we synchronize on the join node. + const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node; + const int64_t join_start = (int64_t) join_t->data; + const int64_t join_end = join_start + ggml_nbytes(join_t); + + for (const auto & [tensor, stream] : stream_mapping) { + const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor; + const int64_t t_start = (int64_t) t->data; + const int64_t t_end = t_start + ggml_nbytes(t); + + // skip tensors that overlap with join_node's buffer. + if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) { + continue; + } + + // concurrent streams begin from 1 + write_ranges[stream - 1].emplace_back(t_start, t_end); + } + + for (int i = 0; i < n_streams; ++i) { + // sorts first by start then by end of write range + std::sort(write_ranges[i].begin(), write_ranges[i].end()); + } + + bool writes_overlap = false; + bool dependent_srcs = false; + for (const auto & [tensor, stream] : stream_mapping) { + const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor; + const int64_t t_start = (int64_t) t->data; + const int64_t t_end = t_start + ggml_nbytes(t); + + // skip tensors that overlap with join_node's buffer + if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) { + continue; + } + + // check if this buffer's write data overlaps with another stream's + std::pair data_range = std::make_pair(t_start, t_end); + for (int i = 0; i < n_streams; ++i) { + if (i == stream - 1) { + continue; + } + auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range); + + if (it != write_ranges[i].end()) { + const std::pair & other = *it; + + // std::lower_bound returns the first element where other >= data_range (lexicographically). + // This guarantees other.first >= data_range.first. + // Therefore, overlap occurs iff other.first < data_range.second + // (i.e., the other range starts before this range ends). + if (other.first < data_range.second) { + GGML_LOG_DEBUG("Writes overlap for %s", tensor->name); + writes_overlap = true; + break; + } + } + } + + //check if all srcs are either in branch or don't have a branch + for (int i = 0; i < GGML_MAX_SRC; ++i) { + if (!tensor->src[i]) { + continue; + } + + auto it = stream_mapping.find(tensor->src[i]); + + if (it == stream_mapping.end()) { + continue; + } + + if (it->second != stream) { + dependent_srcs = true; + break; + } + } + + if (dependent_srcs || writes_overlap) { + break; + } + } + + return !writes_overlap && !dependent_srcs; + } + + ~ggml_cuda_concurrent_event() { + if (fork_event != nullptr) { + CUDA_CHECK(cudaEventDestroy(fork_event)); + } + for (cudaEvent_t e : join_events) { + if (e != nullptr) { + CUDA_CHECK(cudaEventDestroy(e)); + } + } + } +}; + +struct ggml_cuda_stream_context { + std::vector original_nodes; + std::unordered_map concurrent_events; + + void reset() { + original_nodes.clear(); + concurrent_events.clear(); + } +}; + struct ggml_backend_cuda_context { int device; std::string name; @@ -990,11 +1140,15 @@ struct ggml_backend_cuda_context { std::unique_ptr cuda_graph; + int curr_stream_no = 0; + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { } + ggml_cuda_stream_context concurrent_stream_context; + ~ggml_backend_cuda_context(); cudaStream_t stream(int device, int stream) { @@ -1005,9 +1159,9 @@ struct ggml_backend_cuda_context { return streams[device][stream]; } - cudaStream_t stream() { - return stream(device, 0); - } + cudaStream_t stream() { return stream(device, curr_stream_no); } + + ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; } cublasHandle_t cublas_handle(int device) { if (cublas_handles[device] == nullptr) { @@ -1023,15 +1177,15 @@ struct ggml_backend_cuda_context { } // pool - std::unique_ptr pools[GGML_CUDA_MAX_DEVICES]; + std::unique_ptr pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; - static std::unique_ptr new_pool_for_device(int device); + static std::unique_ptr new_pool_for_device(int device, int stream_no); ggml_cuda_pool & pool(int device) { - if (pools[device] == nullptr) { - pools[device] = new_pool_for_device(device); + if (pools[device][curr_stream_no] == nullptr) { + pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no); } - return *pools[device]; + return *pools[device][curr_stream_no]; } ggml_cuda_pool & pool() { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a844a3d99a..fa7e1e13a7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -522,7 +522,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { }; #endif // defined(GGML_USE_VMM) -std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { +std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device, + [[maybe_unused]] int stream_no) { #if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); @@ -3200,18 +3201,83 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); + bool is_concurrent_event_active = false; + ggml_cuda_concurrent_event * concurrent_event = nullptr; + bool should_launch_concurrent_events = false; + + const auto try_launch_concurrent_event = [&](const ggml_tensor * node) { + if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) { + concurrent_event = &stream_ctx.concurrent_events[node]; + + is_concurrent_event_active = true; + + GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); + + cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 + GGML_ASSERT(cuda_ctx->curr_stream_no == 0); + CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); + + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); + CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); + } + } + }; + while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. if (!use_cuda_graph || cuda_graph_update_required) { - [[maybe_unused]] int prev_i = 0; + if (stream_ctx.concurrent_events.size() > 0) { + should_launch_concurrent_events = true; + for (const auto & [tensor, event] : stream_ctx.concurrent_events) { + should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid(); + } + } + if (should_launch_concurrent_events) { + //Restore the original graph to enable fusion within the streams + cgraph->nodes = const_cast(stream_ctx.original_nodes.data()); + cgraph->n_nodes = (int) stream_ctx.original_nodes.size(); + } + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + if (is_concurrent_event_active) { + GGML_ASSERT(concurrent_event); + + if (node == concurrent_event->join_node) { + cuda_ctx->curr_stream_no = 0; + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + // Wait on join events of forked streams in the main stream + CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1], + cuda_ctx->stream(cuda_ctx->device, i))); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1])); + } + + is_concurrent_event_active = false; + concurrent_event = nullptr; + } else { + GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end()); + cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); + } + } else if (i - prev_i > 1) { + //the previous node was fused + const ggml_tensor * prev_node = cgraph->nodes[i - 1]; + try_launch_concurrent_event(prev_node); + + if (is_concurrent_event_active) { + cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); + } + } + prev_i = i; + #ifdef GGML_CUDA_DEBUG const int nodes_fused = i - prev_i - 1; - prev_i = i; if (nodes_fused > 0) { GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused); } @@ -3221,6 +3287,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + + // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { @@ -3513,13 +3581,17 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } #else GGML_UNUSED(integrated); -#endif // NDEBUG +#endif // NDEBUG bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); if (!ok) { GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); + + if (!is_concurrent_event_active) { + try_launch_concurrent_event(node); + } } } @@ -3659,6 +3731,235 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev } } +static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + + static bool enable_graph_optimization = [] { + const char * env = getenv("GGML_CUDA_GRAPH_OPT"); + return env != nullptr && atoi(env) == 1; + }(); + + if (!enable_graph_optimization) { + return; + } + + GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend"); + GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes); + + ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); + stream_context.reset(); + + // number of out-degrees for a particular node + std::unordered_map fan_out; + // reverse mapping of node to index in the cgraph + std::unordered_map node_indices; + + const auto & is_noop = [](const ggml_tensor * node) -> bool { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || + node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor * src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { + const ggml_tensor * node = cgraph->nodes[node_idx]; + node_indices[node] = node_idx; + + if (is_noop(node)) { + continue; + } + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx]; + //TODO: check why nrows > 1 fails + if (node && !is_noop(node) && ggml_nrows(node) <= 1) { + fan_out[src] += 1; + } + } + } + + // Target Q, K, V for concurrency + // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else): + // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm") + // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn") + // 3. account for all branches from the fork to the join + // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details) + // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams + // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030 + + const int min_fan_out = 3; + const int max_fan_out = 3; + + // store {fork_idx, join_idx} + std::vector> concurrent_node_ranges; + + // save the original nodes + std::vector original_nodes; + original_nodes.reserve(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; ++i) { + original_nodes.push_back(cgraph->nodes[i]); + } + cuda_ctx->stream_context().original_nodes = std::move(original_nodes); + + for (const auto & [root_node, count] : fan_out) { + if (count >= min_fan_out && count <= max_fan_out) { + const int root_node_idx = node_indices[root_node]; + + bool is_part_of_event = false; + for (const auto & [start, end] : concurrent_node_ranges) { + if (root_node_idx >= start && root_node_idx <= end) { + is_part_of_event = true; + } + } + + if (is_part_of_event) { + continue; + } + + std::vector> nodes_per_branch; + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * node = cgraph->nodes[i]; + if (!is_noop(node) && depends_on(node, root_node)) { + nodes_per_branch.push_back({ node }); + } + } + + GGML_ASSERT(nodes_per_branch.size() == (size_t) count); + + //find the join point + const ggml_tensor * join_node = nullptr; + + const auto & belongs_to_branch = [&](const ggml_tensor * node, + const std::vector & branch) -> bool { + for (const ggml_tensor * n : branch) { + if (depends_on(node, n)) { + return true; + } + } + return false; + }; + + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * curr_node = cgraph->nodes[i]; + + int num_joins = 0; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) { + num_joins++; + } + } + + if (num_joins >= 2) { + join_node = curr_node; + break; + } + + bool found_branch = false; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + std::vector & branch_vec = nodes_per_branch[branch_idx]; + if (belongs_to_branch(curr_node, branch_vec)) { + //continue accumulating + if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) { + branch_vec.push_back(curr_node); + } + found_branch = true; + } + } + + if (!found_branch && is_noop(curr_node)) { + // we can put it in any branch because it will be ignored + nodes_per_branch[0].push_back({ curr_node }); + } + } + + if (join_node) { + //Create ggml_cuda_concurrent_event + ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size()); + concurrent_event.join_node = join_node; + + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + for (const ggml_tensor * n : nodes_per_branch[branch_idx]) { + concurrent_event.stream_mapping[n] = branch_idx + 1; + } + } + + int fork_node_idx = node_indices[root_node]; + int join_node_idx = node_indices[join_node]; + + int current_branch_idx = 0; + int current_node_idx = fork_node_idx + 1; + const int n_branches = nodes_per_branch.size(); + + int total_branch_nodes = 0; + for (std::vector branch_nodes : nodes_per_branch) { + total_branch_nodes += branch_nodes.size(); + } + + // there are other nodes in the middle which are unaccounted for + // usually (cpy) nodes, then ignore this fork + if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) { + GGML_LOG_DEBUG( + "Skipping %s because the number of nodes in the middle is not equal to the total number of " + "branch nodes %d != %d\n", + root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes); + continue; + } + + std::unordered_map & concurrent_events = cuda_ctx->stream_context().concurrent_events; + GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end()); + concurrent_events.emplace(root_node, std::move(concurrent_event)); + GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node); + concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx); + + // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them + // example transformation: + // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] -> + // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn] + while (current_node_idx < join_node_idx) { + std::vector & branch_nodes = nodes_per_branch[current_branch_idx]; + + bool has_node = false; + for (std::vector branch_node : nodes_per_branch) { + has_node |= branch_node.size() > 0; + } + + GGML_ASSERT(has_node); + + if (branch_nodes.empty()) { + current_branch_idx = (current_branch_idx + 1) % n_branches; + continue; + } + + cgraph->nodes[current_node_idx] = const_cast(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + + // append all empty nodes + while (!branch_nodes.empty() && is_noop(branch_nodes.front())) { + cgraph->nodes[current_node_idx] = const_cast(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + } + + current_branch_idx = (current_branch_idx + 1) % n_branches; + } + } + } + } +} + static const ggml_backend_i ggml_backend_cuda_interface = { /* .get_name = */ ggml_backend_cuda_get_name, /* .free = */ ggml_backend_cuda_free, @@ -3673,7 +3974,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, - /* .graph_optimize = */ NULL, + /* .graph_optimize = */ ggml_backend_cuda_graph_optimize, }; static ggml_guid_t ggml_backend_cuda_guid() { diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 890c103649..b7d6edf7fc 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -105,7 +105,7 @@ #define cudaStreamNonBlocking hipStreamNonBlocking #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize -#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaStreamWaitEvent hipStreamWaitEvent #define cudaGraphExec_t hipGraphExec_t #define cudaGraphNode_t hipGraphNode_t #define cudaKernelNodeParams hipKernelNodeParams