diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index d32ebe9b1e..691eb0397a 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2235,8 +2235,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx }(); if (!use_cann_graph || cann_graph_capture_required) { - if (multi_stream_enabled && !use_cann_graph) { + if (multi_stream_enabled) { // Multi-stream execution mode using memory-based dependency tracking + // Note: multi_stream_enabled implies !use_cann_graph (set in graph_compute) // Track data pointers that have pending writes on each stream std::map pending_writes; // data_ptr -> stream_id std::set active_streams; // streams with pending work @@ -2440,17 +2441,30 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, #ifdef USE_ACL_GRAPH bool use_cann_graph = true; - static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); - if (!prefill_use_graph) { - // Do not use acl_graph for prefill. - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - // TODO: Optimize here. Currently, we can only - // get seq_len by FA's input. - if (node->op == GGML_OP_FLASH_ATTN_EXT) { - // Q -> src[0], shape: [B, S, N, D] - use_cann_graph = (node->src[0]->ne[1] == 1); - break; + // Check if multi-stream execution is enabled (must check before using use_cann_graph) + static bool multi_stream_enabled = [] { + const char * env = getenv("GGML_CANN_MULTI_STREAM"); + return env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0); + }(); + + // Multi-stream mode is incompatible with ACL graph capture/execution + if (multi_stream_enabled) { + use_cann_graph = false; + } + + if (use_cann_graph) { + static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); + if (!prefill_use_graph) { + // Do not use acl_graph for prefill. + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + // TODO: Optimize here. Currently, we can only + // get seq_len by FA's input. + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + // Q -> src[0], shape: [B, S, N, D] + use_cann_graph = (node->src[0]->ne[1] == 1); + break; + } } } }