diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 8feddbe680..ca4bf35b75 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -571,6 +571,9 @@ struct ggml_backend_cann_context { aclrtEvent stream_events[GGML_CANN_NUM_COMPUTE_STREAMS] = { nullptr }; /**< Events for stream synchronization. */ std::vector unsynced_nodes; /**< Nodes that have been executed but not synced. */ + // Operator fusion support + bool operator_fusion_enabled = false; /**< Whether operator fusion is enabled. */ + /** * @brief Constructor for initializing the context with a given device. * @param device Device ID. @@ -584,6 +587,38 @@ struct ggml_backend_cann_context { GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER", acl_graph_mode ? "acl graph enabled" : "acl graph disabled"); #endif + + // Read environment variables for multi-stream and operator fusion + bool env_multi_stream = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or("")); + bool env_operator_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or("")); + + // Handle conflicts and set final values +#ifdef USE_ACL_GRAPH + if (acl_graph_mode && env_multi_stream) { + // ACL graph has higher performance, disable multi-stream + multi_stream_enabled = false; + operator_fusion_enabled = env_operator_fusion; + GGML_LOG_INFO("%s: device %d multi-stream disabled (ACL graph mode has higher performance)\n", + __func__, device); + } else +#endif + if (env_multi_stream) { + // Multi-stream enabled, disable operator fusion (fusion has low benefit with multi-stream) + multi_stream_enabled = true; + operator_fusion_enabled = false; + if (env_operator_fusion) { + GGML_LOG_INFO("%s: device %d operator fusion disabled (low benefit with multi-stream enabled)\n", + __func__, device); + } + GGML_LOG_INFO("%s: device %d multi-stream execution enabled\n", __func__, device); + } else { + // Default single-stream mode + multi_stream_enabled = false; + operator_fusion_enabled = env_operator_fusion; + if (env_operator_fusion) { + GGML_LOG_INFO("%s: device %d operator fusion enabled\n", __func__, device); + } + } } /** diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index a2b3517add..48d4eb94e3 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2226,13 +2226,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx #endif // USE_ACL_GRAPH // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. - static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or("")); - - // Check if multi-stream execution is enabled - static bool multi_stream_enabled = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or("")); if (!use_cann_graph || cann_graph_capture_required) { - if (multi_stream_enabled) { + if (cann_ctx->multi_stream_enabled) { // Multi-stream execution mode using memory-based dependency tracking // Note: multi_stream_enabled implies !use_cann_graph (set in graph_compute) // Track data pointers that have pending writes on each stream @@ -2247,23 +2243,6 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx } } - // Helper lambda to synchronize all active streams to the target stream - auto sync_all_to_stream = [&](int target_stream) { - if (active_streams.empty()) return; - - // Record events on all active streams - for (int s : active_streams) { - ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[s], cann_ctx->stream(s))); - } - // Wait for all events on the target stream - for (int s : active_streams) { - ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[s])); - } - // Clear tracking - pending_writes.clear(); - active_streams.clear(); - }; - // Helper lambda to wait for a specific stream on the target stream auto wait_for_stream = [&](int src_stream, int target_stream) { if (src_stream == target_stream) return; @@ -2273,25 +2252,6 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if (opt_fusion) { - if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) { - // Fusion ops need synchronization - execute on stream 0 - sync_all_to_stream(0); - - // Execute fused op on stream 0 - ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]); - - // Track the output - void * out_ptr = get_data_ptr(cgraph->nodes[i + 1]); - if (out_ptr) { - pending_writes[out_ptr] = 0; - active_streams.insert(0); - } - i++; - current_stream = 1; // Next node goes to stream 1 - continue; - } - } if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { @@ -2373,7 +2333,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx // Single-stream execution mode (original behavior) for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if (opt_fusion) { + if (cann_ctx->operator_fusion_enabled) { if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) { ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]); i++; @@ -2438,14 +2398,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, #ifdef USE_ACL_GRAPH bool use_cann_graph = true; - // Check if multi-stream execution is enabled (must check before using use_cann_graph) - static bool multi_stream_enabled = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or("")); - - // Multi-stream mode is incompatible with ACL graph capture/execution - if (multi_stream_enabled) { - use_cann_graph = false; - } - if (use_cann_graph) { static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); if (!prefill_use_graph) { @@ -2839,11 +2791,7 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_ev */ static void ggml_backend_cann_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) { // Check if graph optimization is disabled via environment variable - static bool disable_graph_optimize = [] { - const char * env = getenv("GGML_CANN_DISABLE_GRAPH_OPTIMIZE"); - return env != nullptr; - }(); - + static bool disable_graph_optimize = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_GRAPH_OPTIMIZE").value_or("")); if (disable_graph_optimize) { return; }