cann: fix multi-stream execution with memory-based dependency tracking
- Replace tensor-pointer-based dependency tracking with memory-address-based tracking
- Use std::map<void*, int> to track pending writes per stream
- Implement smart stream selection:
  - No dependencies: round-robin distribution
  - Single dependency: execute on same stream (avoid sync overhead)
  - Multiple dependencies: sync all streams
- Add WAW (Write-After-Write) hazard detection
- Fix output corruption issue when using multi-stream execution

Enable with: GGML_CANN_MULTI_STREAM=1
commit 87e12c60cd (parent c1792d58b5)
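The stream-selection policy described in the commit message can be sketched in isolation as follows. This is a minimal standalone model, not code from this commit: StreamScheduler, choose_stream, and kNumStreams are hypothetical names, and the real logic lives inside evaluate_and_capture_cann_graph in the diff below, where cross-stream joins are done with ACL events rather than a boolean flag.

// Illustrative sketch only -- not part of the commit. StreamScheduler,
// choose_stream and kNumStreams are hypothetical names; the real logic is in
// evaluate_and_capture_cann_graph and synchronizes with ACL events.
#include <cstdio>
#include <map>
#include <set>
#include <vector>

constexpr int kNumStreams = 4; // mirrors GGML_CANN_NUM_COMPUTE_STREAMS

struct StreamScheduler {
    std::map<void *, int> pending_writes; // data pointer -> stream with a pending write
    int round_robin = 0;

    // Decide which stream a node runs on. sync_all is set when every active
    // stream must be joined first (the multiple-dependency case).
    int choose_stream(const std::vector<void *> & inputs, void * output, bool & sync_all) {
        std::set<int> deps;
        for (void * p : inputs) { // RAW hazards: inputs with pending writes
            auto it = pending_writes.find(p);
            if (it != pending_writes.end()) {
                deps.insert(it->second);
            }
        }
        auto waw = pending_writes.find(output); // WAW hazard on the output buffer
        if (waw != pending_writes.end()) {
            deps.insert(waw->second);
        }

        sync_all = false;
        int stream;
        if (deps.empty()) { // no dependencies: round-robin distribution
            stream = round_robin;
            round_robin = (round_robin + 1) % kNumStreams;
        } else if (deps.size() == 1) { // single dependency: stay on that stream
            stream = *deps.begin();
        } else { // multiple dependencies: join everything, continue on stream 0
            sync_all = true;
            pending_writes.clear();
            stream = 0;
            round_robin = 1;
        }
        pending_writes[output] = stream; // record this node's pending write
        return stream;
    }
};

int main() {
    StreamScheduler sched;
    int a = 0, b = 0, c = 0; // stand-ins for device buffers
    bool sync = false;
    std::printf("A -> stream %d\n", sched.choose_stream({}, &a, sync));       // no deps
    std::printf("B -> stream %d\n", sched.choose_stream({}, &b, sync));       // no deps
    std::printf("C -> stream %d\n", sched.choose_stream({&a, &b}, &c, sync)); // two deps
    std::printf("C required a full sync: %d\n", sync ? 1 : 0);
    return 0;
}

In the diff itself, the "join everything" case is handled by sync_all_to_stream(), which records an event on every active stream and makes the target stream wait on each of them before clearing the tracking state.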
@@ -46,6 +46,7 @@
#define MATRIX_ROW_PADDING 512
#define GGML_CANN_MAX_STREAMS 8
#define GGML_CANN_NUM_COMPUTE_STREAMS 4 // Number of streams for parallel compute

/**
* @brief Handles CANN-related errors by printing an error message and
@@ -564,6 +565,12 @@ struct ggml_backend_cann_context {
aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */

// Multi-stream parallel execution support
bool multi_stream_enabled = false; /**< Whether multi-stream execution is enabled. */
int current_stream_idx = 0; /**< Current stream index for round-robin scheduling. */
aclrtEvent stream_events[GGML_CANN_NUM_COMPUTE_STREAMS] = { nullptr }; /**< Events for stream synchronization. */
std::vector<const ggml_tensor *> unsynced_nodes; /**< Nodes that have been executed but not synced. */

/**
* @brief Constructor for initializing the context with a given device.
* @param device Device ID.
@@ -592,6 +599,12 @@ struct ggml_backend_cann_context {
ACL_CHECK(aclrtDestroyStream(streams[i]));
}
}
// Clean up multi-stream events
for (int i = 0; i < GGML_CANN_NUM_COMPUTE_STREAMS; ++i) {
if (stream_events[i] != nullptr) {
ACL_CHECK(aclrtDestroyEvent(stream_events[i]));
}
}
}

/**
@@ -37,6 +37,7 @@
#include <cmath>
#include <cstdio>
#include <cstring>
#include <map>
#include <mutex>
#include <optional>
#include <queue>
@@ -2107,6 +2108,98 @@ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
return false;
}

/**
* @brief Check if a tensor depends on any node in the unsynced list.
*
* This function checks whether the given tensor depends on any of the unsynced
* nodes by examining the source tensors recursively.
*
* @param tensor The tensor to check dependencies for.
* @param unsynced_nodes Set of nodes that haven't been synchronized yet.
* @return true if the tensor depends on any unsynced node.
*/
static bool depends_on_unsynced(const ggml_tensor * tensor,
const std::set<const ggml_tensor *> & unsynced_nodes) {
if (tensor == nullptr) {
return false;
}

// Check if this tensor itself is unsynced
if (unsynced_nodes.count(tensor) > 0) {
return true;
}

// Check view source
if (tensor->view_src != nullptr && unsynced_nodes.count(tensor->view_src) > 0) {
return true;
}

return false;
}

/**
* @brief Check if a node depends on any unsynced nodes through its sources.
*
* @param node The node to check.
* @param unsynced_nodes Set of nodes that haven't been synchronized yet.
* @return true if the node depends on any unsynced node.
*/
static bool node_depends_on_unsynced(const ggml_tensor * node,
const std::set<const ggml_tensor *> & unsynced_nodes) {
for (int s = 0; s < GGML_MAX_SRC; ++s) {
if (depends_on_unsynced(node->src[s], unsynced_nodes)) {
return true;
}
}
return false;
}

/**
* @brief Get the underlying data pointer for a tensor.
*
* Returns the data pointer of the view source if the tensor is a view,
* otherwise returns the tensor's own data pointer.
*
* @param tensor The tensor to get the data pointer for.
* @return The underlying data pointer.
*/
static inline void * get_data_ptr(const ggml_tensor * tensor) {
if (tensor == nullptr) {
return nullptr;
}
return tensor->data;
}

/**
* @brief Check if a node has memory dependencies on pending writes.
*
* This function checks if any of the node's input tensors read from memory
* locations that have pending writes, or if the node's output memory overlaps
* with pending writes.
*
* @param node The node to check.
* @param pending_write_ptrs Set of data pointers with pending writes.
* @return true if there are memory dependencies.
*/
static bool has_memory_dependency(const ggml_tensor * node,
const std::set<void *> & pending_write_ptrs) {
// Check if any source reads from a pending write location
for (int s = 0; s < GGML_MAX_SRC; ++s) {
void * src_ptr = get_data_ptr(node->src[s]);
if (src_ptr != nullptr && pending_write_ptrs.count(src_ptr) > 0) {
return true;
}
}

// Check if output location has pending writes (WAW hazard)
void * dst_ptr = get_data_ptr(node);
if (dst_ptr != nullptr && pending_write_ptrs.count(dst_ptr) > 0) {
return true;
}

return false;
}

/**
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
*
@@ -2114,6 +2207,8 @@ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
* graph capture, runs the graph, ends capture, and stores the captured graph.
*
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
* When multi-stream execution is enabled, nodes are distributed across multiple streams
* for parallel execution.
*
* @param cann_ctx The CANN backend context.
* @param cgraph The ggml computation graph.
@@ -2133,31 +2228,176 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
// With the use of CANN graphs, the execution will be performed by the graph launch.
static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));

// Check if multi-stream execution is enabled
static bool multi_stream_enabled = [] {
const char * env = getenv("GGML_CANN_MULTI_STREAM");
return env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
}();

if (!use_cann_graph || cann_graph_capture_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (opt_fusion) {
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
i++;
continue;
if (multi_stream_enabled && !use_cann_graph) {
// Multi-stream execution mode using memory-based dependency tracking
// Track data pointers that have pending writes on each stream
std::map<void *, int> pending_writes; // data_ptr -> stream_id
std::set<int> active_streams; // streams with pending work
int current_stream = 0;

// Ensure stream events are created
for (int s = 0; s < GGML_CANN_NUM_COMPUTE_STREAMS; ++s) {
if (cann_ctx->stream_events[s] == nullptr) {
ACL_CHECK(aclrtCreateEvent(&cann_ctx->stream_events[s]));
}
}

if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
// Helper lambda to synchronize all active streams to the target stream
auto sync_all_to_stream = [&](int target_stream) {
if (active_streams.empty()) return;

// Record events on all active streams
for (int s : active_streams) {
ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[s], cann_ctx->stream(s)));
}
// Wait for all events on the target stream
for (int s : active_streams) {
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[s]));
}
// Clear tracking
pending_writes.clear();
active_streams.clear();
};

// Helper lambda to wait for a specific stream on the target stream
auto wait_for_stream = [&](int src_stream, int target_stream) {
if (src_stream == target_stream) return;
ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[src_stream], cann_ctx->stream(src_stream)));
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[src_stream]));
};

for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (opt_fusion) {
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
// Fusion ops need synchronization - execute on stream 0
sync_all_to_stream(0);

// Execute fused op on stream 0
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);

// Track the output
void * out_ptr = get_data_ptr(cgraph->nodes[i + 1]);
if (out_ptr) {
pending_writes[out_ptr] = 0;
active_streams.insert(0);
}
i++;
current_stream = 1; // Next node goes to stream 1
continue;
}
}

if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}

if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}

// Find which streams we depend on based on input memory locations
std::set<int> dependent_streams;
for (int s = 0; s < GGML_MAX_SRC; ++s) {
void * src_ptr = get_data_ptr(node->src[s]);
if (src_ptr != nullptr) {
auto it = pending_writes.find(src_ptr);
if (it != pending_writes.end()) {
dependent_streams.insert(it->second);
}
}
}

// Check for WAW hazard (output location has pending write)
void * dst_ptr = get_data_ptr(node);
if (dst_ptr != nullptr) {
auto it = pending_writes.find(dst_ptr);
if (it != pending_writes.end()) {
dependent_streams.insert(it->second);
}
}

// Choose which stream to execute on
int exec_stream;
if (dependent_streams.empty()) {
// No dependencies - use round-robin
exec_stream = current_stream;
current_stream = (current_stream + 1) % GGML_CANN_NUM_COMPUTE_STREAMS;
} else if (dependent_streams.size() == 1) {
// Single dependency - execute on the same stream to avoid sync overhead
exec_stream = *dependent_streams.begin();
} else {
// Multiple dependencies - sync all to stream 0 and execute there
sync_all_to_stream(0);
exec_stream = 0;
current_stream = 1;
}

// If we depend on a different stream, wait for it
if (dependent_streams.size() == 1 && *dependent_streams.begin() != exec_stream) {
wait_for_stream(*dependent_streams.begin(), exec_stream);
}

// Execute the node on the chosen stream
// Temporarily swap the default stream
aclrtStream original_stream = cann_ctx->streams[0];
cann_ctx->streams[0] = cann_ctx->stream(exec_stream);

bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);

// Restore the original stream
cann_ctx->streams[0] = original_stream;

// Track the output location
if (dst_ptr != nullptr) {
pending_writes[dst_ptr] = exec_stream;
active_streams.insert(exec_stream);
}
}

if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
// Final synchronization - wait for all streams to complete
for (int s : active_streams) {
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream(s)));
}
} else {
// Single-stream execution mode (original behavior)
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (opt_fusion) {
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
i++;
continue;
}
}

bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}

if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}

bool ok = ggml_cann_compute_forward(*cann_ctx, node);
if (!ok) {
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
}
GGML_ASSERT(ok);
}
}
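For reference, the cross-stream ordering used by wait_for_stream() and sync_all_to_stream() in the hunk above reduces to the usual record-then-wait event pattern. The sketch below is illustrative only and not part of the commit: it assumes an Ascend device is available, uses the ACL runtime calls that appear in the diff (aclrtCreateEvent, aclrtRecordEvent, aclrtStreamWaitEvent, aclrtSynchronizeStream, aclrtDestroyEvent, aclrtDestroyStream), treats aclInit, aclrtSetDevice, aclrtCreateStream, aclrtResetDevice, and aclFinalize as the standard ACL setup/teardown entry points, and defines a local CHECK_ACL macro in place of the backend's ACL_CHECK.

// Illustrative sketch only -- not part of the commit. Shows the
// record-event / wait-event pattern used by wait_for_stream() above.
#include <acl/acl.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_ACL(call)                                              \
    do {                                                             \
        aclError err_ = (call);                                      \
        if (err_ != ACL_SUCCESS) {                                   \
            std::fprintf(stderr, "ACL error %d at %s:%d\n",          \
                         (int) err_, __FILE__, __LINE__);            \
            std::exit(1);                                            \
        }                                                            \
    } while (0)

int main() {
    CHECK_ACL(aclInit(nullptr));
    CHECK_ACL(aclrtSetDevice(0));

    aclrtStream producer = nullptr, consumer = nullptr;
    CHECK_ACL(aclrtCreateStream(&producer));
    CHECK_ACL(aclrtCreateStream(&consumer));

    aclrtEvent ready = nullptr;
    CHECK_ACL(aclrtCreateEvent(&ready));

    // ... enqueue the producing kernel on `producer` here ...

    // Cross-stream dependency: record on the producer, wait on the consumer.
    CHECK_ACL(aclrtRecordEvent(ready, producer));
    CHECK_ACL(aclrtStreamWaitEvent(consumer, ready));

    // ... enqueue the dependent kernel on `consumer` here ...

    // Final join, as at the end of the multi-stream loop above.
    CHECK_ACL(aclrtSynchronizeStream(producer));
    CHECK_ACL(aclrtSynchronizeStream(consumer));

    CHECK_ACL(aclrtDestroyEvent(ready));
    CHECK_ACL(aclrtDestroyStream(producer));
    CHECK_ACL(aclrtDestroyStream(consumer));
    CHECK_ACL(aclrtResetDevice(0));
    CHECK_ACL(aclFinalize());
    return 0;
}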