cann: fix multi-stream to properly disable ACL graph mode
When GGML_CANN_MULTI_STREAM=1 is set, ACL graph capture/execution must be disabled since they are incompatible. The previous code had a bug where the prefill_use_graph check would overwrite use_cann_graph after it was set to false for multi-stream mode. Fix by wrapping the prefill_use_graph check inside if (use_cann_graph) to ensure it only runs when ACL graph is not already disabled.
This commit is contained in:
parent
87e12c60cd
commit
906bfed0ca
|
|
@ -2235,8 +2235,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
|||
}();
|
||||
|
||||
if (!use_cann_graph || cann_graph_capture_required) {
|
||||
if (multi_stream_enabled && !use_cann_graph) {
|
||||
if (multi_stream_enabled) {
|
||||
// Multi-stream execution mode using memory-based dependency tracking
|
||||
// Note: multi_stream_enabled implies !use_cann_graph (set in graph_compute)
|
||||
// Track data pointers that have pending writes on each stream
|
||||
std::map<void *, int> pending_writes; // data_ptr -> stream_id
|
||||
std::set<int> active_streams; // streams with pending work
|
||||
|
|
@ -2440,17 +2441,30 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
|||
#ifdef USE_ACL_GRAPH
|
||||
bool use_cann_graph = true;
|
||||
|
||||
static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
|
||||
if (!prefill_use_graph) {
|
||||
// Do not use acl_graph for prefill.
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
// TODO: Optimize here. Currently, we can only
|
||||
// get seq_len by FA's input.
|
||||
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
// Q -> src[0], shape: [B, S, N, D]
|
||||
use_cann_graph = (node->src[0]->ne[1] == 1);
|
||||
break;
|
||||
// Check if multi-stream execution is enabled (must check before using use_cann_graph)
|
||||
static bool multi_stream_enabled = [] {
|
||||
const char * env = getenv("GGML_CANN_MULTI_STREAM");
|
||||
return env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
|
||||
}();
|
||||
|
||||
// Multi-stream mode is incompatible with ACL graph capture/execution
|
||||
if (multi_stream_enabled) {
|
||||
use_cann_graph = false;
|
||||
}
|
||||
|
||||
if (use_cann_graph) {
|
||||
static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
|
||||
if (!prefill_use_graph) {
|
||||
// Do not use acl_graph for prefill.
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
// TODO: Optimize here. Currently, we can only
|
||||
// get seq_len by FA's input.
|
||||
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
// Q -> src[0], shape: [B, S, N, D]
|
||||
use_cann_graph = (node->src[0]->ne[1] == 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue