CUDA: add stream-based concurrency (#16991)

* CUDA: add stream-based concurrency

* HIP: fix hipStreamWaitEvent define and nodiscard warnings

* ggml-cuda: fix fusion inside stream

* ggml-cuda: fix bug w.r.t first stream launch

* ggml-cuda: format

* ggml-cuda: improve assert message

* ggml-cuda: use lambda instead of duplicating code

* ggml-cuda: add some more comments

* ggml-cuda: add more detailed comments about concurrency

* ggml-cuda: rename + remove unused var

* ggml-cuda: fix condition for stream launch

* ggml-cuda: address review comments, add destructor

* common.cuh: add is_valid for concurrent events

* common.cuh: make comment better

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* common.cuh: fix lower_bound condition + remove join_node data from write_ranges

* ggml-cuda: fix overlap condition + shadowing parameter

---------

Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Aman Gupta 2025-11-30 08:17:55 +08:00 committed by GitHub
parent 00425e2ed1
commit c7af376c29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 469 additions and 14 deletions

View File

@ -21,10 +21,12 @@
#include "ggml-common.h"
#include <array>
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>
#if defined(GGML_USE_HIP)
@ -980,6 +982,154 @@ struct ggml_cuda_graph {
#endif
};
struct ggml_cuda_concurrent_event {
std::vector<cudaEvent_t> join_events;
cudaEvent_t fork_event = nullptr;
int n_streams = 0;
std::unordered_map<const ggml_tensor *, int> stream_mapping;
const ggml_tensor * join_node;
ggml_cuda_concurrent_event() = default;
ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
join_events.resize(n_streams);
for (size_t i = 0; i < join_events.size(); ++i) {
CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
}
CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
}
ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
: join_events(std::move(other.join_events))
, fork_event(other.fork_event)
, n_streams(other.n_streams)
, stream_mapping(std::move(other.stream_mapping))
, join_node(other.join_node) {
other.fork_event = nullptr;
}
// 1. check if any branches write to overlapping memory ranges (except the join node)
// 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
// we assume all nodes have the same buffer
bool is_valid() const {
std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
write_ranges.resize(n_streams);
// get join_node's memory range to exclude from overlap checking.
// multiple nodes can use join_node's buffer; we synchronize on the join node.
const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
const int64_t join_start = (int64_t) join_t->data;
const int64_t join_end = join_start + ggml_nbytes(join_t);
for (const auto & [tensor, stream] : stream_mapping) {
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
const int64_t t_start = (int64_t) t->data;
const int64_t t_end = t_start + ggml_nbytes(t);
// skip tensors that overlap with join_node's buffer.
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
continue;
}
// concurrent streams begin from 1
write_ranges[stream - 1].emplace_back(t_start, t_end);
}
for (int i = 0; i < n_streams; ++i) {
// sorts first by start then by end of write range
std::sort(write_ranges[i].begin(), write_ranges[i].end());
}
bool writes_overlap = false;
bool dependent_srcs = false;
for (const auto & [tensor, stream] : stream_mapping) {
const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
const int64_t t_start = (int64_t) t->data;
const int64_t t_end = t_start + ggml_nbytes(t);
// skip tensors that overlap with join_node's buffer
if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
continue;
}
// check if this buffer's write data overlaps with another stream's
std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
for (int i = 0; i < n_streams; ++i) {
if (i == stream - 1) {
continue;
}
auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
if (it != write_ranges[i].end()) {
const std::pair<int64_t, int64_t> & other = *it;
// std::lower_bound returns the first element where other >= data_range (lexicographically).
// This guarantees other.first >= data_range.first.
// Therefore, overlap occurs iff other.first < data_range.second
// (i.e., the other range starts before this range ends).
if (other.first < data_range.second) {
GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
writes_overlap = true;
break;
}
}
}
//check if all srcs are either in branch or don't have a branch
for (int i = 0; i < GGML_MAX_SRC; ++i) {
if (!tensor->src[i]) {
continue;
}
auto it = stream_mapping.find(tensor->src[i]);
if (it == stream_mapping.end()) {
continue;
}
if (it->second != stream) {
dependent_srcs = true;
break;
}
}
if (dependent_srcs || writes_overlap) {
break;
}
}
return !writes_overlap && !dependent_srcs;
}
~ggml_cuda_concurrent_event() {
if (fork_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(fork_event));
}
for (cudaEvent_t e : join_events) {
if (e != nullptr) {
CUDA_CHECK(cudaEventDestroy(e));
}
}
}
};
struct ggml_cuda_stream_context {
std::vector<const ggml_tensor *> original_nodes;
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
void reset() {
original_nodes.clear();
concurrent_events.clear();
}
};
struct ggml_backend_cuda_context {
int device;
std::string name;
@ -990,11 +1140,15 @@ struct ggml_backend_cuda_context {
std::unique_ptr<ggml_cuda_graph> cuda_graph;
int curr_stream_no = 0;
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
}
ggml_cuda_stream_context concurrent_stream_context;
~ggml_backend_cuda_context();
cudaStream_t stream(int device, int stream) {
@ -1005,9 +1159,9 @@ struct ggml_backend_cuda_context {
return streams[device][stream];
}
cudaStream_t stream() {
return stream(device, 0);
}
cudaStream_t stream() { return stream(device, curr_stream_no); }
ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
cublasHandle_t cublas_handle(int device) {
if (cublas_handles[device] == nullptr) {
@ -1023,15 +1177,15 @@ struct ggml_backend_cuda_context {
}
// pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
ggml_cuda_pool & pool(int device) {
if (pools[device] == nullptr) {
pools[device] = new_pool_for_device(device);
if (pools[device][curr_stream_no] == nullptr) {
pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
}
return *pools[device];
return *pools[device][curr_stream_no];
}
ggml_cuda_pool & pool() {

View File

@ -522,7 +522,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
};
#endif // defined(GGML_USE_VMM)
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
[[maybe_unused]] int stream_no) {
#if defined(GGML_USE_VMM)
if (ggml_cuda_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
@ -3200,18 +3201,83 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
// flag used to determine whether it is an integrated_gpu
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
bool is_concurrent_event_active = false;
ggml_cuda_concurrent_event * concurrent_event = nullptr;
bool should_launch_concurrent_events = false;
const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
concurrent_event = &stream_ctx.concurrent_events[node];
is_concurrent_event_active = true;
GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
}
}
};
while (!graph_evaluated_or_captured) {
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch.
if (!use_cuda_graph || cuda_graph_update_required) {
[[maybe_unused]] int prev_i = 0;
if (stream_ctx.concurrent_events.size() > 0) {
should_launch_concurrent_events = true;
for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
}
}
if (should_launch_concurrent_events) {
//Restore the original graph to enable fusion within the streams
cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
}
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (is_concurrent_event_active) {
GGML_ASSERT(concurrent_event);
if (node == concurrent_event->join_node) {
cuda_ctx->curr_stream_no = 0;
for (int i = 1; i <= concurrent_event->n_streams; ++i) {
// Wait on join events of forked streams in the main stream
CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
cuda_ctx->stream(cuda_ctx->device, i)));
CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
}
is_concurrent_event_active = false;
concurrent_event = nullptr;
} else {
GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
}
} else if (i - prev_i > 1) {
//the previous node was fused
const ggml_tensor * prev_node = cgraph->nodes[i - 1];
try_launch_concurrent_event(prev_node);
if (is_concurrent_event_active) {
cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
}
}
prev_i = i;
#ifdef GGML_CUDA_DEBUG
const int nodes_fused = i - prev_i - 1;
prev_i = i;
if (nodes_fused > 0) {
GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
}
@ -3221,6 +3287,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
continue;
}
// start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
@ -3520,6 +3588,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
if (!is_concurrent_event_active) {
try_launch_concurrent_event(node);
}
}
}
@ -3659,6 +3731,235 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
}
}
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
static bool enable_graph_optimization = [] {
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
return env != nullptr && atoi(env) == 1;
}();
if (!enable_graph_optimization) {
return;
}
GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
stream_context.reset();
// number of out-degrees for a particular node
std::unordered_map<const ggml_tensor *, int> fan_out;
// reverse mapping of node to index in the cgraph
std::unordered_map<const ggml_tensor *, int> node_indices;
const auto & is_noop = [](const ggml_tensor * node) -> bool {
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
};
const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
if (dst->src[s] == src) {
return true;
}
}
// implicit dependency if they view the same tensor
const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
const ggml_tensor * src2 = src->view_src ? src->view_src : src;
if (dst2 == src2) {
return true;
}
return false;
};
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
const ggml_tensor * node = cgraph->nodes[node_idx];
node_indices[node] = node_idx;
if (is_noop(node)) {
continue;
}
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
//TODO: check why nrows > 1 fails
if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
fan_out[src] += 1;
}
}
}
// Target Q, K, V for concurrency
// this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
// 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
// 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn")
// 3. account for all branches from the fork to the join
// 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
// 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
// See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
const int min_fan_out = 3;
const int max_fan_out = 3;
// store {fork_idx, join_idx}
std::vector<std::pair<int, int>> concurrent_node_ranges;
// save the original nodes
std::vector<const ggml_tensor *> original_nodes;
original_nodes.reserve(cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; ++i) {
original_nodes.push_back(cgraph->nodes[i]);
}
cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
for (const auto & [root_node, count] : fan_out) {
if (count >= min_fan_out && count <= max_fan_out) {
const int root_node_idx = node_indices[root_node];
bool is_part_of_event = false;
for (const auto & [start, end] : concurrent_node_ranges) {
if (root_node_idx >= start && root_node_idx <= end) {
is_part_of_event = true;
}
}
if (is_part_of_event) {
continue;
}
std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
const ggml_tensor * node = cgraph->nodes[i];
if (!is_noop(node) && depends_on(node, root_node)) {
nodes_per_branch.push_back({ node });
}
}
GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
//find the join point
const ggml_tensor * join_node = nullptr;
const auto & belongs_to_branch = [&](const ggml_tensor * node,
const std::vector<const ggml_tensor *> & branch) -> bool {
for (const ggml_tensor * n : branch) {
if (depends_on(node, n)) {
return true;
}
}
return false;
};
for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
const ggml_tensor * curr_node = cgraph->nodes[i];
int num_joins = 0;
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
num_joins++;
}
}
if (num_joins >= 2) {
join_node = curr_node;
break;
}
bool found_branch = false;
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
if (belongs_to_branch(curr_node, branch_vec)) {
//continue accumulating
if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
branch_vec.push_back(curr_node);
}
found_branch = true;
}
}
if (!found_branch && is_noop(curr_node)) {
// we can put it in any branch because it will be ignored
nodes_per_branch[0].push_back({ curr_node });
}
}
if (join_node) {
//Create ggml_cuda_concurrent_event
ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
concurrent_event.join_node = join_node;
for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
concurrent_event.stream_mapping[n] = branch_idx + 1;
}
}
int fork_node_idx = node_indices[root_node];
int join_node_idx = node_indices[join_node];
int current_branch_idx = 0;
int current_node_idx = fork_node_idx + 1;
const int n_branches = nodes_per_branch.size();
int total_branch_nodes = 0;
for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
total_branch_nodes += branch_nodes.size();
}
// there are other nodes in the middle which are unaccounted for
// usually (cpy) nodes, then ignore this fork
if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
GGML_LOG_DEBUG(
"Skipping %s because the number of nodes in the middle is not equal to the total number of "
"branch nodes %d != %d\n",
root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
continue;
}
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
concurrent_events.emplace(root_node, std::move(concurrent_event));
GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
// interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
// example transformation:
// [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
// [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
while (current_node_idx < join_node_idx) {
std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
bool has_node = false;
for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
has_node |= branch_node.size() > 0;
}
GGML_ASSERT(has_node);
if (branch_nodes.empty()) {
current_branch_idx = (current_branch_idx + 1) % n_branches;
continue;
}
cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
current_node_idx++;
branch_nodes.erase(branch_nodes.begin());
// append all empty nodes
while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
current_node_idx++;
branch_nodes.erase(branch_nodes.begin());
}
current_branch_idx = (current_branch_idx + 1) % n_branches;
}
}
}
}
}
static const ggml_backend_i ggml_backend_cuda_interface = {
/* .get_name = */ ggml_backend_cuda_get_name,
/* .free = */ ggml_backend_cuda_free,
@ -3673,7 +3974,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
/* .graph_optimize = */ NULL,
/* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
};
static ggml_guid_t ggml_backend_cuda_guid() {

View File

@ -105,7 +105,7 @@
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStreamWaitEvent hipStreamWaitEvent
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams