ggml-cuda: reorder only relevant nodes (#17639)
This commit is contained in:
parent
7b6d745364
commit
ed32089927
|
|
@ -989,6 +989,10 @@ struct ggml_cuda_concurrent_event {
|
|||
int n_streams = 0;
|
||||
std::unordered_map<const ggml_tensor *, int> stream_mapping;
|
||||
|
||||
// Original order of nodes in this concurrent region (before interleaving)
|
||||
// Used to restore grouping for fusion within streams
|
||||
std::vector<const ggml_tensor *> original_order;
|
||||
|
||||
const ggml_tensor * join_node;
|
||||
|
||||
ggml_cuda_concurrent_event() = default;
|
||||
|
|
@ -1011,6 +1015,7 @@ struct ggml_cuda_concurrent_event {
|
|||
, fork_event(other.fork_event)
|
||||
, n_streams(other.n_streams)
|
||||
, stream_mapping(std::move(other.stream_mapping))
|
||||
, original_order(std::move(other.original_order))
|
||||
, join_node(other.join_node) {
|
||||
other.fork_event = nullptr;
|
||||
}
|
||||
|
|
@ -1121,11 +1126,9 @@ struct ggml_cuda_concurrent_event {
|
|||
};
|
||||
|
||||
struct ggml_cuda_stream_context {
|
||||
std::vector<const ggml_tensor *> original_nodes;
|
||||
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
|
||||
|
||||
void reset() {
|
||||
original_nodes.clear();
|
||||
concurrent_events.clear();
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -3238,9 +3238,56 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||
}
|
||||
}
|
||||
if (should_launch_concurrent_events) {
|
||||
//Restore the original graph to enable fusion within the streams
|
||||
cgraph->nodes = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
|
||||
cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
|
||||
// Restore original node order within each concurrent region to enable fusion within streams
|
||||
|
||||
std::unordered_map<const ggml_tensor *, int> node_to_idx;
|
||||
node_to_idx.reserve(cgraph->n_nodes);
|
||||
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
||||
node_to_idx[cgraph->nodes[i]] = i;
|
||||
}
|
||||
|
||||
for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
|
||||
// Find positions of all nodes from this event in the current graph
|
||||
std::vector<int> positions;
|
||||
positions.reserve(event.original_order.size());
|
||||
|
||||
bool all_found = true;
|
||||
for (const ggml_tensor * orig_node : event.original_order) {
|
||||
auto it = node_to_idx.find(orig_node);
|
||||
if (it != node_to_idx.end()) {
|
||||
positions.push_back(it->second);
|
||||
} else {
|
||||
all_found = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!all_found || positions.size() != event.original_order.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Sort positions to get contiguous range
|
||||
std::vector<int> sorted_positions = positions;
|
||||
std::sort(sorted_positions.begin(), sorted_positions.end());
|
||||
|
||||
bool is_contiguous = true;
|
||||
for (size_t i = 1; i < sorted_positions.size(); ++i) {
|
||||
if (sorted_positions[i] != sorted_positions[i-1] + 1) {
|
||||
is_contiguous = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_contiguous) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Restore original order at the sorted positions
|
||||
int start_pos = sorted_positions[0];
|
||||
for (size_t i = 0; i < event.original_order.size(); ++i) {
|
||||
cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
|
|
@ -3805,14 +3852,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
|
|||
// store {fork_idx, join_idx}
|
||||
std::vector<std::pair<int, int>> concurrent_node_ranges;
|
||||
|
||||
// save the original nodes
|
||||
std::vector<const ggml_tensor *> original_nodes;
|
||||
original_nodes.reserve(cgraph->n_nodes);
|
||||
for (int i = 0; i < cgraph->n_nodes; ++i) {
|
||||
original_nodes.push_back(cgraph->nodes[i]);
|
||||
}
|
||||
cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
|
||||
|
||||
for (const auto & [root_node, count] : fan_out) {
|
||||
if (count >= min_fan_out && count <= max_fan_out) {
|
||||
const int root_node_idx = node_indices[root_node];
|
||||
|
|
@ -3917,6 +3956,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
|
|||
continue;
|
||||
}
|
||||
|
||||
// Save the original order of nodes in this region before interleaving
|
||||
// This is used later to restore grouping for fusion within streams
|
||||
concurrent_event.original_order.reserve(total_branch_nodes);
|
||||
for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
|
||||
concurrent_event.original_order.push_back(cgraph->nodes[i]);
|
||||
}
|
||||
|
||||
std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
|
||||
GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
|
||||
concurrent_events.emplace(root_node, std::move(concurrent_event));
|
||||
|
|
|
|||
Loading…
Reference in New Issue