feat: add multi-stream and operator fusion conflict detection

- Add operator_fusion_enabled flag to ggml_backend_cann_context
- Implement conflict detection in constructor:
  * ACL graph mode disables multi-stream (higher performance)
  * Multi-stream mode disables operator fusion (low benefit)
- Remove multi-stream fusion code (fusion disabled in multi-stream)
- Keep fusion functionality in single-stream mode
- Remove redundant multi_stream_enabled check in graph_compute
- Fix unused variable warning (sync_all_to_stream)
This commit is contained in:
hipudding 2026-02-06 02:29:02 +00:00
parent 4951a4ff7a
commit dd9e377ed8
2 changed files with 38 additions and 55 deletions

View File

@ -571,6 +571,9 @@ struct ggml_backend_cann_context {
aclrtEvent stream_events[GGML_CANN_NUM_COMPUTE_STREAMS] = { nullptr }; /**< Events for stream synchronization. */
std::vector<const ggml_tensor *> unsynced_nodes; /**< Nodes that have been executed but not synced. */
// Operator fusion support
bool operator_fusion_enabled = false; /**< Whether operator fusion is enabled. */
/**
* @brief Constructor for initializing the context with a given device.
* @param device Device ID.
@ -584,6 +587,38 @@ struct ggml_backend_cann_context {
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
#endif
// Read environment variables for multi-stream and operator fusion
bool env_multi_stream = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or(""));
bool env_operator_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
// Handle conflicts and set final values
#ifdef USE_ACL_GRAPH
if (acl_graph_mode && env_multi_stream) {
// ACL graph has higher performance, disable multi-stream
multi_stream_enabled = false;
operator_fusion_enabled = env_operator_fusion;
GGML_LOG_INFO("%s: device %d multi-stream disabled (ACL graph mode has higher performance)\n",
__func__, device);
} else
#endif
if (env_multi_stream) {
// Multi-stream enabled, disable operator fusion (fusion has low benefit with multi-stream)
multi_stream_enabled = true;
operator_fusion_enabled = false;
if (env_operator_fusion) {
GGML_LOG_INFO("%s: device %d operator fusion disabled (low benefit with multi-stream enabled)\n",
__func__, device);
}
GGML_LOG_INFO("%s: device %d multi-stream execution enabled\n", __func__, device);
} else {
// Default single-stream mode
multi_stream_enabled = false;
operator_fusion_enabled = env_operator_fusion;
if (env_operator_fusion) {
GGML_LOG_INFO("%s: device %d operator fusion enabled\n", __func__, device);
}
}
}
/**

View File

@ -2226,13 +2226,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch.
static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
// Check if multi-stream execution is enabled
static bool multi_stream_enabled = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or(""));
if (!use_cann_graph || cann_graph_capture_required) {
if (multi_stream_enabled) {
if (cann_ctx->multi_stream_enabled) {
// Multi-stream execution mode using memory-based dependency tracking
// Note: multi_stream_enabled implies !use_cann_graph (set in graph_compute)
// Track data pointers that have pending writes on each stream
@ -2247,23 +2243,6 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
}
}
// Helper lambda to synchronize all active streams to the target stream
auto sync_all_to_stream = [&](int target_stream) {
if (active_streams.empty()) return;
// Record events on all active streams
for (int s : active_streams) {
ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[s], cann_ctx->stream(s)));
}
// Wait for all events on the target stream
for (int s : active_streams) {
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[s]));
}
// Clear tracking
pending_writes.clear();
active_streams.clear();
};
// Helper lambda to wait for a specific stream on the target stream
auto wait_for_stream = [&](int src_stream, int target_stream) {
if (src_stream == target_stream) return;
@ -2273,25 +2252,6 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (opt_fusion) {
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
// Fusion ops need synchronization - execute on stream 0
sync_all_to_stream(0);
// Execute fused op on stream 0
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
// Track the output
void * out_ptr = get_data_ptr(cgraph->nodes[i + 1]);
if (out_ptr) {
pending_writes[out_ptr] = 0;
active_streams.insert(0);
}
i++;
current_stream = 1; // Next node goes to stream 1
continue;
}
}
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@ -2373,7 +2333,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
// Single-stream execution mode (original behavior)
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (opt_fusion) {
if (cann_ctx->operator_fusion_enabled) {
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
i++;
@ -2438,14 +2398,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
// Check if multi-stream execution is enabled (must check before using use_cann_graph)
static bool multi_stream_enabled = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or(""));
// Multi-stream mode is incompatible with ACL graph capture/execution
if (multi_stream_enabled) {
use_cann_graph = false;
}
if (use_cann_graph) {
static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
if (!prefill_use_graph) {
@ -2839,11 +2791,7 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_ev
*/
static void ggml_backend_cann_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) {
// Check if graph optimization is disabled via environment variable
static bool disable_graph_optimize = [] {
const char * env = getenv("GGML_CANN_DISABLE_GRAPH_OPTIMIZE");
return env != nullptr;
}();
static bool disable_graph_optimize = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_GRAPH_OPTIMIZE").value_or(""));
if (disable_graph_optimize) {
return;
}