From 87e12c60cde6515282f0872beb42dea6f50b877b Mon Sep 17 00:00:00 2001 From: hipudding Date: Tue, 3 Feb 2026 06:32:50 +0000 Subject: [PATCH] cann: fix multi-stream execution with memory-based dependency tracking - Replace tensor-pointer-based dependency tracking with memory-address-based tracking - Use std::map to track pending writes per stream - Implement smart stream selection: - No dependencies: round-robin distribution - Single dependency: execute on same stream (avoid sync overhead) - Multiple dependencies: sync all streams - Add WAW (Write-After-Write) hazard detection - Fix output corruption issue when using multi-stream execution Enable with: GGML_CANN_MULTI_STREAM=1 --- ggml/src/ggml-cann/common.h | 13 ++ ggml/src/ggml-cann/ggml-cann.cpp | 272 +++++++++++++++++++++++++++++-- 2 files changed, 269 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 0120f0dfd1..8feddbe680 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -46,6 +46,7 @@ #define MATRIX_ROW_PADDING 512 #define GGML_CANN_MAX_STREAMS 8 +#define GGML_CANN_NUM_COMPUTE_STREAMS 4 // Number of streams for parallel compute /** * @brief Handles CANN-related errors by printing an error message and @@ -564,6 +565,12 @@ struct ggml_backend_cann_context { aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */ + // Multi-stream parallel execution support + bool multi_stream_enabled = false; /**< Whether multi-stream execution is enabled. */ + int current_stream_idx = 0; /**< Current stream index for round-robin scheduling. */ + aclrtEvent stream_events[GGML_CANN_NUM_COMPUTE_STREAMS] = { nullptr }; /**< Events for stream synchronization. */ + std::vector unsynced_nodes; /**< Nodes that have been executed but not synced. */ + /** * @brief Constructor for initializing the context with a given device. * @param device Device ID. 
@@ -592,6 +599,12 @@ struct ggml_backend_cann_context { ACL_CHECK(aclrtDestroyStream(streams[i])); } } + // Clean up multi-stream events + for (int i = 0; i < GGML_CANN_NUM_COMPUTE_STREAMS; ++i) { + if (stream_events[i] != nullptr) { + ACL_CHECK(aclrtDestroyEvent(stream_events[i])); + } + } } /** diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 8f202d83f5..d32ebe9b1e 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -2107,6 +2108,98 @@ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph, return false; } +/** + * @brief Check if a tensor depends on any node in the unsynced list. + * + * This function checks whether the given tensor depends on any of the unsynced + * nodes by examining the source tensors recursively. + * + * @param tensor The tensor to check dependencies for. + * @param unsynced_nodes Set of nodes that haven't been synchronized yet. + * @return true if the tensor depends on any unsynced node. + */ +static bool depends_on_unsynced(const ggml_tensor * tensor, + const std::set & unsynced_nodes) { + if (tensor == nullptr) { + return false; + } + + // Check if this tensor itself is unsynced + if (unsynced_nodes.count(tensor) > 0) { + return true; + } + + // Check view source + if (tensor->view_src != nullptr && unsynced_nodes.count(tensor->view_src) > 0) { + return true; + } + + return false; +} + +/** + * @brief Check if a node depends on any unsynced nodes through its sources. + * + * @param node The node to check. + * @param unsynced_nodes Set of nodes that haven't been synchronized yet. + * @return true if the node depends on any unsynced node. 
+ */ +static bool node_depends_on_unsynced(const ggml_tensor * node, + const std::set<const ggml_tensor *> & unsynced_nodes) { + for (int s = 0; s < GGML_MAX_SRC; ++s) { + if (depends_on_unsynced(node->src[s], unsynced_nodes)) { + return true; + } + } + return false; +} + +/** + * @brief Get the underlying data pointer for a tensor. + * + * Returns the tensor's own data pointer, or nullptr when the tensor is + * null. Note: views are not resolved to view_src by this helper. + * + * @param tensor The tensor to get the data pointer for. + * @return The underlying data pointer. + */ +static inline void * get_data_ptr(const ggml_tensor * tensor) { + if (tensor == nullptr) { + return nullptr; + } + return tensor->data; +} + +/** + * @brief Check if a node has memory dependencies on pending writes. + * + * This function checks if any of the node's input tensors read from memory + * locations that have pending writes, or if the node's output memory overlaps + * with pending writes. + * + * @param node The node to check. + * @param pending_write_ptrs Set of data pointers with pending writes. + * @return true if there are memory dependencies. + */ +static bool has_memory_dependency(const ggml_tensor * node, + const std::set<void *> & pending_write_ptrs) { + // Check if any source reads from a pending write location + for (int s = 0; s < GGML_MAX_SRC; ++s) { + void * src_ptr = get_data_ptr(node->src[s]); + if (src_ptr != nullptr && pending_write_ptrs.count(src_ptr) > 0) { + return true; + } + } + + // Check if output location has pending writes (WAW hazard) + void * dst_ptr = get_data_ptr(node); + if (dst_ptr != nullptr && pending_write_ptrs.count(dst_ptr) > 0) { + return true; + } + + return false; +} + /** * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API. * * If CANN graph execution is enabled and graph capture is required, this function handles * graph capture, runs the graph, ends capture, and stores the captured graph.
* * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher. + * When multi-stream execution is enabled, nodes are distributed across multiple streams + * for parallel execution. * * @param cann_ctx The CANN backend context. * @param cgraph The ggml computation graph. @@ -2133,31 +2228,176 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx // With the use of CANN graphs, the execution will be performed by the graph launch. static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or("")); + // Check if multi-stream execution is enabled + static bool multi_stream_enabled = [] { + const char * env = getenv("GGML_CANN_MULTI_STREAM"); + return env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0); + }(); + if (!use_cann_graph || cann_graph_capture_required) { - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - if (opt_fusion) { - if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) { - ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]); - i++; - continue; + if (multi_stream_enabled && !use_cann_graph) { + // Multi-stream execution mode using memory-based dependency tracking + // Track data pointers that have pending writes on each stream + std::map pending_writes; // data_ptr -> stream_id + std::set active_streams; // streams with pending work + int current_stream = 0; + + // Ensure stream events are created + for (int s = 0; s < GGML_CANN_NUM_COMPUTE_STREAMS; ++s) { + if (cann_ctx->stream_events[s] == nullptr) { + ACL_CHECK(aclrtCreateEvent(&cann_ctx->stream_events[s])); } } - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || - node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; + // Helper lambda to synchronize all active streams to the target stream + auto sync_all_to_stream = [&](int 
target_stream) { + if (active_streams.empty()) return; + + // Record events on all active streams + for (int s : active_streams) { + ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[s], cann_ctx->stream(s))); + } + // Wait for all events on the target stream + for (int s : active_streams) { + ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[s])); + } + // Clear tracking + pending_writes.clear(); + active_streams.clear(); + }; + + // Helper lambda to wait for a specific stream on the target stream + auto wait_for_stream = [&](int src_stream, int target_stream) { + if (src_stream == target_stream) return; + ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[src_stream], cann_ctx->stream(src_stream))); + ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[src_stream])); + }; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (opt_fusion) { + if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) { + // Fusion ops need synchronization - execute on stream 0 + sync_all_to_stream(0); + + // Execute fused op on stream 0 + ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]); + + // Track the output + void * out_ptr = get_data_ptr(cgraph->nodes[i + 1]); + if (out_ptr) { + pending_writes[out_ptr] = 0; + active_streams.insert(0); + } + i++; + current_stream = 1; // Next node goes to stream 1 + continue; + } + } + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || + node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + + // Find which streams we depend on based on input memory locations + std::set dependent_streams; + for (int s = 0; s < GGML_MAX_SRC; ++s) { + void * src_ptr = get_data_ptr(node->src[s]); + if (src_ptr != nullptr) { + auto it = 
pending_writes.find(src_ptr); + if (it != pending_writes.end()) { + dependent_streams.insert(it->second); + } + } + } + + // Check for WAW hazard (output location has pending write) + void * dst_ptr = get_data_ptr(node); + if (dst_ptr != nullptr) { + auto it = pending_writes.find(dst_ptr); + if (it != pending_writes.end()) { + dependent_streams.insert(it->second); + } + } + + // Choose which stream to execute on + int exec_stream; + if (dependent_streams.empty()) { + // No dependencies - use round-robin + exec_stream = current_stream; + current_stream = (current_stream + 1) % GGML_CANN_NUM_COMPUTE_STREAMS; + } else if (dependent_streams.size() == 1) { + // Single dependency - execute on the same stream to avoid sync overhead + exec_stream = *dependent_streams.begin(); + } else { + // Multiple dependencies - sync all to stream 0 and execute there + sync_all_to_stream(0); + exec_stream = 0; + current_stream = 1; + } + + // If we depend on a different stream, wait for it + if (dependent_streams.size() == 1 && *dependent_streams.begin() != exec_stream) { + wait_for_stream(*dependent_streams.begin(), exec_stream); + } + + // Execute the node on the chosen stream + // Temporarily swap the default stream + aclrtStream original_stream = cann_ctx->streams[0]; + cann_ctx->streams[0] = cann_ctx->stream(exec_stream); + + bool ok = ggml_cann_compute_forward(*cann_ctx, node); + if (!ok) { + GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + + // Restore the original stream + cann_ctx->streams[0] = original_stream; + + // Track the output location + if (dst_ptr != nullptr) { + pending_writes[dst_ptr] = exec_stream; + active_streams.insert(exec_stream); + } } - if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { - continue; + // Final synchronization - wait for all streams to complete + for (int s : active_streams) { + ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream(s))); } + } else { + // Single-stream 
execution mode (original behavior) + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (opt_fusion) { + if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) { + ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]); + i++; + continue; + } + } - bool ok = ggml_cann_compute_forward(*cann_ctx, node); - if (!ok) { - GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || + node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + + bool ok = ggml_cann_compute_forward(*cann_ctx, node); + if (!ok) { + GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); } - GGML_ASSERT(ok); } }