cann: fix multi-stream execution with memory-based dependency tracking

- Replace tensor-pointer-based dependency tracking with memory-address-based tracking
- Use std::map<void*, int> to map each written address to the stream with the pending write
- Implement smart stream selection (sketched below):
  - No dependencies: round-robin distribution across compute streams
  - Single dependency: execute on the same stream (avoids sync overhead)
  - Multiple dependencies: sync all active streams, then execute on stream 0
- Add WAW (Write-After-Write) hazard detection
- Fix output corruption when using multi-stream execution

Enable with: GGML_CANN_MULTI_STREAM=1
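
The selection policy is small enough to sketch on its own. The following is a hedged, standalone C++ illustration of the dispatch rule; choose_stream, read_ptrs, write_ptr, round_robin, and n_streams are hypothetical names for this sketch, not part of the commit (the actual implementation is the loop in evaluate_and_capture_cann_graph in the diff below):

#include <map>
#include <set>

// Hedged sketch (names hypothetical): decide which stream a node runs on,
// given the addresses it reads, the address it writes, and the map of
// pending writes (address -> stream that produced it).
static int choose_stream(const std::set<void *> & read_ptrs,
                         void *                   write_ptr,
                         std::map<void *, int> &  pending_writes,
                         int &                    round_robin,
                         int                      n_streams) {
    std::set<int> deps;
    for (void * p : read_ptrs) {                  // RAW: input has a pending write
        auto it = pending_writes.find(p);
        if (it != pending_writes.end()) {
            deps.insert(it->second);
        }
    }
    if (write_ptr != nullptr) {                   // WAW: output has a pending write
        auto it = pending_writes.find(write_ptr);
        if (it != pending_writes.end()) {
            deps.insert(it->second);
        }
    }
    int stream;
    if (deps.empty()) {
        // No dependencies: spread independent nodes round-robin.
        stream = round_robin;
        round_robin = (round_robin + 1) % n_streams;
    } else if (deps.size() == 1) {
        // One producer: run on its stream, no cross-stream event needed.
        stream = *deps.begin();
    } else {
        // Several producers: the real code records an event on every active
        // stream, waits on stream 0, clears the tracking, and runs there.
        pending_writes.clear();
        stream = 0;
        round_robin = 1 % n_streams;              // resume fan-out on stream 1
    }
    if (write_ptr != nullptr) {
        pending_writes[write_ptr] = stream;       // this node's write is now pending
    }
    return stream;
}

Running a single-dependency node on its producer's stream lets the stream's own FIFO ordering do the synchronization for free; events are only needed when ordering work across different streams.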
hipudding 2026-02-03 06:32:50 +00:00
parent c1792d58b5
commit 87e12c60cd
2 changed files with 269 additions and 16 deletions

ggml/src/ggml-cann/common.h

@@ -46,6 +46,7 @@
#define MATRIX_ROW_PADDING 512
#define GGML_CANN_MAX_STREAMS 8
#define GGML_CANN_NUM_COMPUTE_STREAMS 4 // Number of streams for parallel compute

/**
 * @brief Handles CANN-related errors by printing an error message and
@@ -564,6 +565,12 @@ struct ggml_backend_cann_context {
    aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */

    // Multi-stream parallel execution support
    bool multi_stream_enabled = false; /**< Whether multi-stream execution is enabled. */
    int current_stream_idx = 0; /**< Current stream index for round-robin scheduling. */
    aclrtEvent stream_events[GGML_CANN_NUM_COMPUTE_STREAMS] = { nullptr }; /**< Events for stream synchronization. */
    std::vector<const ggml_tensor *> unsynced_nodes; /**< Nodes that have been executed but not synced. */

    /**
     * @brief Constructor for initializing the context with a given device.
     * @param device Device ID.
@@ -592,6 +599,12 @@ struct ggml_backend_cann_context {
                ACL_CHECK(aclrtDestroyStream(streams[i]));
            }
        }
        // Clean up multi-stream events
        for (int i = 0; i < GGML_CANN_NUM_COMPUTE_STREAMS; ++i) {
            if (stream_events[i] != nullptr) {
                ACL_CHECK(aclrtDestroyEvent(stream_events[i]));
            }
        }
    }

/**

ggml/src/ggml-cann/ggml-cann.cpp

@@ -37,6 +37,7 @@
#include <cmath>
#include <cstdio>
#include <cstring>
#include <map>
#include <mutex>
#include <optional>
#include <queue>
@@ -2107,6 +2108,98 @@ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
    return false;
}

/**
 * @brief Check if a tensor depends on any node in the unsynced list.
 *
 * This function checks whether the given tensor depends on any of the unsynced
 * nodes, either directly or through its view source.
 *
 * @param tensor The tensor to check dependencies for.
 * @param unsynced_nodes Set of nodes that haven't been synchronized yet.
 * @return true if the tensor depends on any unsynced node.
 */
static bool depends_on_unsynced(const ggml_tensor * tensor,
                                const std::set<const ggml_tensor *> & unsynced_nodes) {
    if (tensor == nullptr) {
        return false;
    }
    // Check if this tensor itself is unsynced
    if (unsynced_nodes.count(tensor) > 0) {
        return true;
    }
    // Check view source
    if (tensor->view_src != nullptr && unsynced_nodes.count(tensor->view_src) > 0) {
        return true;
    }
    return false;
}

/**
 * @brief Check if a node depends on any unsynced nodes through its sources.
 *
 * @param node The node to check.
 * @param unsynced_nodes Set of nodes that haven't been synchronized yet.
 * @return true if the node depends on any unsynced node.
 */
static bool node_depends_on_unsynced(const ggml_tensor * node,
                                     const std::set<const ggml_tensor *> & unsynced_nodes) {
    for (int s = 0; s < GGML_MAX_SRC; ++s) {
        if (depends_on_unsynced(node->src[s], unsynced_nodes)) {
            return true;
        }
    }
    return false;
}

/**
 * @brief Get the underlying data pointer for a tensor.
 *
 * Returns the tensor's data pointer. For views, ggml sets tensor->data to
 * point into the view source's buffer (at the view offset), so the returned
 * pointer can be compared directly against other tensors' data pointers.
 *
 * @param tensor The tensor to get the data pointer for.
 * @return The underlying data pointer, or nullptr if the tensor is null.
 */
static inline void * get_data_ptr(const ggml_tensor * tensor) {
    if (tensor == nullptr) {
        return nullptr;
    }
    return tensor->data;
}

/**
 * @brief Check if a node has memory dependencies on pending writes.
 *
 * This function checks if any of the node's input tensors read from memory
 * locations that have pending writes (RAW hazard), or if the node's output
 * memory overlaps with pending writes (WAW hazard).
 *
 * @param node The node to check.
 * @param pending_write_ptrs Set of data pointers with pending writes.
 * @return true if there are memory dependencies.
 */
static bool has_memory_dependency(const ggml_tensor * node,
                                  const std::set<void *> & pending_write_ptrs) {
    // Check if any source reads from a pending write location (RAW hazard)
    for (int s = 0; s < GGML_MAX_SRC; ++s) {
        void * src_ptr = get_data_ptr(node->src[s]);
        if (src_ptr != nullptr && pending_write_ptrs.count(src_ptr) > 0) {
            return true;
        }
    }
    // Check if the output location has a pending write (WAW hazard)
    void * dst_ptr = get_data_ptr(node);
    if (dst_ptr != nullptr && pending_write_ptrs.count(dst_ptr) > 0) {
        return true;
    }
    return false;
}
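
The two set-based helpers above are keyed on tensor identity rather than memory addresses. A hedged sketch of how they could gate a flush, reusing the helpers and types above (schedule_node is a hypothetical driver, not part of this commit; the address-based path in the execution hunk further down is what the commit actually wires up):

// Hypothetical driver for the helpers above, not part of this commit:
// flush before any node that reads a not-yet-synchronized output.
static void schedule_node(const ggml_tensor * node,
                          std::set<const ggml_tensor *> & unsynced) {
    if (node_depends_on_unsynced(node, unsynced)) {
        // ... synchronize all compute streams here ...
        unsynced.clear();   // previous writers are now visible everywhere
    }
    unsynced.insert(node);  // this node now has a pending, unsynced write
}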
/**
 * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
 *
@@ -2114,6 +2207,8 @@ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
 * graph capture, runs the graph, ends capture, and stores the captured graph.
 *
 * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
 * When multi-stream execution is enabled, nodes are distributed across multiple streams
 * for parallel execution.
 *
 * @param cann_ctx The CANN backend context.
 * @param cgraph The ggml computation graph.
@@ -2133,31 +2228,176 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
    // With the use of CANN graphs, the execution will be performed by the graph launch.
    static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));

    // Check if multi-stream execution is enabled
    static bool multi_stream_enabled = [] {
        const char * env = getenv("GGML_CANN_MULTI_STREAM");
        return env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
    }();

    if (!use_cann_graph || cann_graph_capture_required) {
        if (multi_stream_enabled && !use_cann_graph) {
            // Multi-stream execution mode using memory-based dependency tracking
            // Track data pointers that have pending writes on each stream
            std::map<void *, int> pending_writes;  // data_ptr -> stream_id
            std::set<int> active_streams;          // streams with pending work
            int current_stream = 0;

            // Ensure stream events are created
            for (int s = 0; s < GGML_CANN_NUM_COMPUTE_STREAMS; ++s) {
                if (cann_ctx->stream_events[s] == nullptr) {
                    ACL_CHECK(aclrtCreateEvent(&cann_ctx->stream_events[s]));
                }
            }

            // Helper lambda to synchronize all active streams to the target stream
            auto sync_all_to_stream = [&](int target_stream) {
                if (active_streams.empty()) return;
                // Record events on all active streams
                for (int s : active_streams) {
                    ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[s], cann_ctx->stream(s)));
                }
                // Wait for all events on the target stream
                for (int s : active_streams) {
                    ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[s]));
                }
                // Clear tracking
                pending_writes.clear();
                active_streams.clear();
            };

            // Helper lambda to wait for a specific stream on the target stream
            auto wait_for_stream = [&](int src_stream, int target_stream) {
                if (src_stream == target_stream) return;
                ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[src_stream], cann_ctx->stream(src_stream)));
                ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[src_stream]));
            };

            for (int i = 0; i < cgraph->n_nodes; i++) {
                ggml_tensor * node = cgraph->nodes[i];

                if (opt_fusion) {
                    if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
                        // Fusion ops need synchronization - execute on stream 0
                        sync_all_to_stream(0);
                        // Execute fused op on stream 0
                        ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
                        // Track the output
                        void * out_ptr = get_data_ptr(cgraph->nodes[i + 1]);
                        if (out_ptr) {
                            pending_writes[out_ptr] = 0;
                            active_streams.insert(0);
                        }
                        i++;
                        current_stream = 1; // Next node goes to stream 1
                        continue;
                    }
                }

                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
                    node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
                    continue;
                }
                if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
                    continue;
                }

                // Find which streams we depend on based on input memory locations
                std::set<int> dependent_streams;
                for (int s = 0; s < GGML_MAX_SRC; ++s) {
                    void * src_ptr = get_data_ptr(node->src[s]);
                    if (src_ptr != nullptr) {
                        auto it = pending_writes.find(src_ptr);
                        if (it != pending_writes.end()) {
                            dependent_streams.insert(it->second);
                        }
                    }
                }
                // Check for WAW hazard (output location has a pending write)
                void * dst_ptr = get_data_ptr(node);
                if (dst_ptr != nullptr) {
                    auto it = pending_writes.find(dst_ptr);
                    if (it != pending_writes.end()) {
                        dependent_streams.insert(it->second);
                    }
                }

                // Choose which stream to execute on
                int exec_stream;
                if (dependent_streams.empty()) {
                    // No dependencies - use round-robin
                    exec_stream = current_stream;
                    current_stream = (current_stream + 1) % GGML_CANN_NUM_COMPUTE_STREAMS;
                } else if (dependent_streams.size() == 1) {
                    // Single dependency - execute on the same stream to avoid sync overhead
                    exec_stream = *dependent_streams.begin();
                } else {
                    // Multiple dependencies - sync all to stream 0 and execute there
                    sync_all_to_stream(0);
                    exec_stream = 0;
                    current_stream = 1;
                }

                // If we depend on a different stream, wait for it
                if (dependent_streams.size() == 1 && *dependent_streams.begin() != exec_stream) {
                    wait_for_stream(*dependent_streams.begin(), exec_stream);
                }

                // Execute the node on the chosen stream
                // Temporarily swap the default stream
                aclrtStream original_stream = cann_ctx->streams[0];
                cann_ctx->streams[0] = cann_ctx->stream(exec_stream);
                bool ok = ggml_cann_compute_forward(*cann_ctx, node);
                if (!ok) {
                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
                }
                GGML_ASSERT(ok);
                // Restore the original stream
                cann_ctx->streams[0] = original_stream;

                // Track the output location
                if (dst_ptr != nullptr) {
                    pending_writes[dst_ptr] = exec_stream;
                    active_streams.insert(exec_stream);
                }
            }

            // Final synchronization - wait for all streams to complete
            for (int s : active_streams) {
                ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream(s)));
            }
        } else {
            // Single-stream execution mode (original behavior)
            for (int i = 0; i < cgraph->n_nodes; i++) {
                ggml_tensor * node = cgraph->nodes[i];
                if (opt_fusion) {
                    if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
                        ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
                        i++;
                        continue;
                    }
                }
                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
                    node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
                    continue;
                }
                if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
                    continue;
                }
                bool ok = ggml_cann_compute_forward(*cann_ctx, node);
                if (!ok) {
                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
                }
                GGML_ASSERT(ok);
            }
        }
    }
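
Both lambdas above reduce to the same record/wait primitive. For reference, a minimal hedged sketch of that pattern (order_after is a hypothetical name; assumes an initialized ACL runtime and a pre-created event, and ignores the aclError return codes that the real code wraps in ACL_CHECK; acl/acl.h is the usual CANN runtime header):

#include <acl/acl.h>

// Order two streams: work submitted to `consumer` after this call will not
// start until everything already submitted to `producer` has completed.
// Both calls return immediately on the host; only device order is affected.
static void order_after(aclrtStream producer, aclrtStream consumer, aclrtEvent ev) {
    aclrtRecordEvent(ev, producer);      // snapshot the producer stream's tail
    aclrtStreamWaitEvent(consumer, ev);  // gate the consumer on that snapshot
}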