cann: fix multi-stream execution with memory-based dependency tracking
- Replace tensor-pointer-based dependency tracking with memory-address-based tracking - Use std::map<void*, int> to track pending writes per stream - Implement smart stream selection: - No dependencies: round-robin distribution - Single dependency: execute on same stream (avoid sync overhead) - Multiple dependencies: sync all streams - Add WAW (Write-After-Write) hazard detection - Fix output corruption issue when using multi-stream execution Enable with: GGML_CANN_MULTI_STREAM=1
This commit is contained in:
parent
c1792d58b5
commit
87e12c60cd
|
|
@ -46,6 +46,7 @@
|
||||||
|
|
||||||
#define MATRIX_ROW_PADDING 512
|
#define MATRIX_ROW_PADDING 512
|
||||||
#define GGML_CANN_MAX_STREAMS 8
|
#define GGML_CANN_MAX_STREAMS 8
|
||||||
|
#define GGML_CANN_NUM_COMPUTE_STREAMS 4 // Number of streams for parallel compute
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Handles CANN-related errors by printing an error message and
|
* @brief Handles CANN-related errors by printing an error message and
|
||||||
|
|
@ -564,6 +565,12 @@ struct ggml_backend_cann_context {
|
||||||
|
|
||||||
aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */
|
aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */
|
||||||
|
|
||||||
|
// Multi-stream parallel execution support
|
||||||
|
bool multi_stream_enabled = false; /**< Whether multi-stream execution is enabled. */
|
||||||
|
int current_stream_idx = 0; /**< Current stream index for round-robin scheduling. */
|
||||||
|
aclrtEvent stream_events[GGML_CANN_NUM_COMPUTE_STREAMS] = { nullptr }; /**< Events for stream synchronization. */
|
||||||
|
std::vector<const ggml_tensor *> unsynced_nodes; /**< Nodes that have been executed but not synced. */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Constructor for initializing the context with a given device.
|
* @brief Constructor for initializing the context with a given device.
|
||||||
* @param device Device ID.
|
* @param device Device ID.
|
||||||
|
|
@ -592,6 +599,12 @@ struct ggml_backend_cann_context {
|
||||||
ACL_CHECK(aclrtDestroyStream(streams[i]));
|
ACL_CHECK(aclrtDestroyStream(streams[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Clean up multi-stream events
|
||||||
|
for (int i = 0; i < GGML_CANN_NUM_COMPUTE_STREAMS; ++i) {
|
||||||
|
if (stream_events[i] != nullptr) {
|
||||||
|
ACL_CHECK(aclrtDestroyEvent(stream_events[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
#include <map>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
|
@ -2107,6 +2108,98 @@ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Check if a tensor depends on any node in the unsynced list.
|
||||||
|
*
|
||||||
|
* This function checks whether the given tensor depends on any of the unsynced
|
||||||
|
* nodes by examining the source tensors recursively.
|
||||||
|
*
|
||||||
|
* @param tensor The tensor to check dependencies for.
|
||||||
|
* @param unsynced_nodes Set of nodes that haven't been synchronized yet.
|
||||||
|
* @return true if the tensor depends on any unsynced node.
|
||||||
|
*/
|
||||||
|
static bool depends_on_unsynced(const ggml_tensor * tensor,
|
||||||
|
const std::set<const ggml_tensor *> & unsynced_nodes) {
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this tensor itself is unsynced
|
||||||
|
if (unsynced_nodes.count(tensor) > 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check view source
|
||||||
|
if (tensor->view_src != nullptr && unsynced_nodes.count(tensor->view_src) > 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Check if a node depends on any unsynced nodes through its sources.
|
||||||
|
*
|
||||||
|
* @param node The node to check.
|
||||||
|
* @param unsynced_nodes Set of nodes that haven't been synchronized yet.
|
||||||
|
* @return true if the node depends on any unsynced node.
|
||||||
|
*/
|
||||||
|
static bool node_depends_on_unsynced(const ggml_tensor * node,
|
||||||
|
const std::set<const ggml_tensor *> & unsynced_nodes) {
|
||||||
|
for (int s = 0; s < GGML_MAX_SRC; ++s) {
|
||||||
|
if (depends_on_unsynced(node->src[s], unsynced_nodes)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the underlying data pointer for a tensor.
|
||||||
|
*
|
||||||
|
* Returns the data pointer of the view source if the tensor is a view,
|
||||||
|
* otherwise returns the tensor's own data pointer.
|
||||||
|
*
|
||||||
|
* @param tensor The tensor to get the data pointer for.
|
||||||
|
* @return The underlying data pointer.
|
||||||
|
*/
|
||||||
|
static inline void * get_data_ptr(const ggml_tensor * tensor) {
|
||||||
|
if (tensor == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return tensor->data;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Check if a node has memory dependencies on pending writes.
|
||||||
|
*
|
||||||
|
* This function checks if any of the node's input tensors read from memory
|
||||||
|
* locations that have pending writes, or if the node's output memory overlaps
|
||||||
|
* with pending writes.
|
||||||
|
*
|
||||||
|
* @param node The node to check.
|
||||||
|
* @param pending_write_ptrs Set of data pointers with pending writes.
|
||||||
|
* @return true if there are memory dependencies.
|
||||||
|
*/
|
||||||
|
static bool has_memory_dependency(const ggml_tensor * node,
|
||||||
|
const std::set<void *> & pending_write_ptrs) {
|
||||||
|
// Check if any source reads from a pending write location
|
||||||
|
for (int s = 0; s < GGML_MAX_SRC; ++s) {
|
||||||
|
void * src_ptr = get_data_ptr(node->src[s]);
|
||||||
|
if (src_ptr != nullptr && pending_write_ptrs.count(src_ptr) > 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if output location has pending writes (WAW hazard)
|
||||||
|
void * dst_ptr = get_data_ptr(node);
|
||||||
|
if (dst_ptr != nullptr && pending_write_ptrs.count(dst_ptr) > 0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
|
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
|
||||||
*
|
*
|
||||||
|
|
@ -2114,6 +2207,8 @@ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
|
||||||
* graph capture, runs the graph, ends capture, and stores the captured graph.
|
* graph capture, runs the graph, ends capture, and stores the captured graph.
|
||||||
*
|
*
|
||||||
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
|
* Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
|
||||||
|
* When multi-stream execution is enabled, nodes are distributed across multiple streams
|
||||||
|
* for parallel execution.
|
||||||
*
|
*
|
||||||
* @param cann_ctx The CANN backend context.
|
* @param cann_ctx The CANN backend context.
|
||||||
* @param cgraph The ggml computation graph.
|
* @param cgraph The ggml computation graph.
|
||||||
|
|
@ -2133,31 +2228,176 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
||||||
// With the use of CANN graphs, the execution will be performed by the graph launch.
|
// With the use of CANN graphs, the execution will be performed by the graph launch.
|
||||||
static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
|
static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
|
||||||
|
|
||||||
|
// Check if multi-stream execution is enabled
|
||||||
|
static bool multi_stream_enabled = [] {
|
||||||
|
const char * env = getenv("GGML_CANN_MULTI_STREAM");
|
||||||
|
return env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
|
||||||
|
}();
|
||||||
|
|
||||||
if (!use_cann_graph || cann_graph_capture_required) {
|
if (!use_cann_graph || cann_graph_capture_required) {
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
if (multi_stream_enabled && !use_cann_graph) {
|
||||||
ggml_tensor * node = cgraph->nodes[i];
|
// Multi-stream execution mode using memory-based dependency tracking
|
||||||
if (opt_fusion) {
|
// Track data pointers that have pending writes on each stream
|
||||||
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
|
std::map<void *, int> pending_writes; // data_ptr -> stream_id
|
||||||
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
|
std::set<int> active_streams; // streams with pending work
|
||||||
i++;
|
int current_stream = 0;
|
||||||
continue;
|
|
||||||
|
// Ensure stream events are created
|
||||||
|
for (int s = 0; s < GGML_CANN_NUM_COMPUTE_STREAMS; ++s) {
|
||||||
|
if (cann_ctx->stream_events[s] == nullptr) {
|
||||||
|
ACL_CHECK(aclrtCreateEvent(&cann_ctx->stream_events[s]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
|
// Helper lambda to synchronize all active streams to the target stream
|
||||||
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
auto sync_all_to_stream = [&](int target_stream) {
|
||||||
continue;
|
if (active_streams.empty()) return;
|
||||||
|
|
||||||
|
// Record events on all active streams
|
||||||
|
for (int s : active_streams) {
|
||||||
|
ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[s], cann_ctx->stream(s)));
|
||||||
|
}
|
||||||
|
// Wait for all events on the target stream
|
||||||
|
for (int s : active_streams) {
|
||||||
|
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[s]));
|
||||||
|
}
|
||||||
|
// Clear tracking
|
||||||
|
pending_writes.clear();
|
||||||
|
active_streams.clear();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper lambda to wait for a specific stream on the target stream
|
||||||
|
auto wait_for_stream = [&](int src_stream, int target_stream) {
|
||||||
|
if (src_stream == target_stream) return;
|
||||||
|
ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[src_stream], cann_ctx->stream(src_stream)));
|
||||||
|
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[src_stream]));
|
||||||
|
};
|
||||||
|
|
||||||
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
|
ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
if (opt_fusion) {
|
||||||
|
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
|
||||||
|
// Fusion ops need synchronization - execute on stream 0
|
||||||
|
sync_all_to_stream(0);
|
||||||
|
|
||||||
|
// Execute fused op on stream 0
|
||||||
|
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
|
||||||
|
|
||||||
|
// Track the output
|
||||||
|
void * out_ptr = get_data_ptr(cgraph->nodes[i + 1]);
|
||||||
|
if (out_ptr) {
|
||||||
|
pending_writes[out_ptr] = 0;
|
||||||
|
active_streams.insert(0);
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
current_stream = 1; // Next node goes to stream 1
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
|
||||||
|
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find which streams we depend on based on input memory locations
|
||||||
|
std::set<int> dependent_streams;
|
||||||
|
for (int s = 0; s < GGML_MAX_SRC; ++s) {
|
||||||
|
void * src_ptr = get_data_ptr(node->src[s]);
|
||||||
|
if (src_ptr != nullptr) {
|
||||||
|
auto it = pending_writes.find(src_ptr);
|
||||||
|
if (it != pending_writes.end()) {
|
||||||
|
dependent_streams.insert(it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for WAW hazard (output location has pending write)
|
||||||
|
void * dst_ptr = get_data_ptr(node);
|
||||||
|
if (dst_ptr != nullptr) {
|
||||||
|
auto it = pending_writes.find(dst_ptr);
|
||||||
|
if (it != pending_writes.end()) {
|
||||||
|
dependent_streams.insert(it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Choose which stream to execute on
|
||||||
|
int exec_stream;
|
||||||
|
if (dependent_streams.empty()) {
|
||||||
|
// No dependencies - use round-robin
|
||||||
|
exec_stream = current_stream;
|
||||||
|
current_stream = (current_stream + 1) % GGML_CANN_NUM_COMPUTE_STREAMS;
|
||||||
|
} else if (dependent_streams.size() == 1) {
|
||||||
|
// Single dependency - execute on the same stream to avoid sync overhead
|
||||||
|
exec_stream = *dependent_streams.begin();
|
||||||
|
} else {
|
||||||
|
// Multiple dependencies - sync all to stream 0 and execute there
|
||||||
|
sync_all_to_stream(0);
|
||||||
|
exec_stream = 0;
|
||||||
|
current_stream = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we depend on a different stream, wait for it
|
||||||
|
if (dependent_streams.size() == 1 && *dependent_streams.begin() != exec_stream) {
|
||||||
|
wait_for_stream(*dependent_streams.begin(), exec_stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the node on the chosen stream
|
||||||
|
// Temporarily swap the default stream
|
||||||
|
aclrtStream original_stream = cann_ctx->streams[0];
|
||||||
|
cann_ctx->streams[0] = cann_ctx->stream(exec_stream);
|
||||||
|
|
||||||
|
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
|
||||||
|
if (!ok) {
|
||||||
|
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||||
|
}
|
||||||
|
GGML_ASSERT(ok);
|
||||||
|
|
||||||
|
// Restore the original stream
|
||||||
|
cann_ctx->streams[0] = original_stream;
|
||||||
|
|
||||||
|
// Track the output location
|
||||||
|
if (dst_ptr != nullptr) {
|
||||||
|
pending_writes[dst_ptr] = exec_stream;
|
||||||
|
active_streams.insert(exec_stream);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
// Final synchronization - wait for all streams to complete
|
||||||
continue;
|
for (int s : active_streams) {
|
||||||
|
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream(s)));
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// Single-stream execution mode (original behavior)
|
||||||
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
|
ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
if (opt_fusion) {
|
||||||
|
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
|
||||||
|
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
|
||||||
|
i++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
|
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
|
||||||
if (!ok) {
|
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||||
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
|
||||||
|
if (!ok) {
|
||||||
|
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||||
|
}
|
||||||
|
GGML_ASSERT(ok);
|
||||||
}
|
}
|
||||||
GGML_ASSERT(ok);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue