feat: add multi-stream and operator fusion conflict detection
- Add operator_fusion_enabled flag to ggml_backend_cann_context - Implement conflict detection in constructor: * ACL graph mode disables multi-stream (higher performance) * Multi-stream mode disables operator fusion (low benefit) - Remove multi-stream fusion code (fusion disabled in multi-stream) - Keep fusion functionality in single-stream mode - Remove redundant multi_stream_enabled check in graph_compute - Fix unused variable warning (sync_all_to_stream)
This commit is contained in:
parent
4951a4ff7a
commit
dd9e377ed8
|
|
@ -571,6 +571,9 @@ struct ggml_backend_cann_context {
|
|||
aclrtEvent stream_events[GGML_CANN_NUM_COMPUTE_STREAMS] = { nullptr }; /**< Events for stream synchronization. */
|
||||
std::vector<const ggml_tensor *> unsynced_nodes; /**< Nodes that have been executed but not synced. */
|
||||
|
||||
// Operator fusion support
|
||||
bool operator_fusion_enabled = false; /**< Whether operator fusion is enabled. */
|
||||
|
||||
/**
|
||||
* @brief Constructor for initializing the context with a given device.
|
||||
* @param device Device ID.
|
||||
|
|
@ -584,6 +587,38 @@ struct ggml_backend_cann_context {
|
|||
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
|
||||
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
|
||||
#endif
|
||||
|
||||
// Read environment variables for multi-stream and operator fusion
|
||||
bool env_multi_stream = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or(""));
|
||||
bool env_operator_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
|
||||
|
||||
// Handle conflicts and set final values
|
||||
#ifdef USE_ACL_GRAPH
|
||||
if (acl_graph_mode && env_multi_stream) {
|
||||
// ACL graph has higher performance, disable multi-stream
|
||||
multi_stream_enabled = false;
|
||||
operator_fusion_enabled = env_operator_fusion;
|
||||
GGML_LOG_INFO("%s: device %d multi-stream disabled (ACL graph mode has higher performance)\n",
|
||||
__func__, device);
|
||||
} else
|
||||
#endif
|
||||
if (env_multi_stream) {
|
||||
// Multi-stream enabled, disable operator fusion (fusion has low benefit with multi-stream)
|
||||
multi_stream_enabled = true;
|
||||
operator_fusion_enabled = false;
|
||||
if (env_operator_fusion) {
|
||||
GGML_LOG_INFO("%s: device %d operator fusion disabled (low benefit with multi-stream enabled)\n",
|
||||
__func__, device);
|
||||
}
|
||||
GGML_LOG_INFO("%s: device %d multi-stream execution enabled\n", __func__, device);
|
||||
} else {
|
||||
// Default single-stream mode
|
||||
multi_stream_enabled = false;
|
||||
operator_fusion_enabled = env_operator_fusion;
|
||||
if (env_operator_fusion) {
|
||||
GGML_LOG_INFO("%s: device %d operator fusion enabled\n", __func__, device);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -2226,13 +2226,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
|||
#endif // USE_ACL_GRAPH
|
||||
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
|
||||
// With the use of CANN graphs, the execution will be performed by the graph launch.
|
||||
static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
|
||||
|
||||
// Check if multi-stream execution is enabled
|
||||
static bool multi_stream_enabled = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or(""));
|
||||
|
||||
if (!use_cann_graph || cann_graph_capture_required) {
|
||||
if (multi_stream_enabled) {
|
||||
if (cann_ctx->multi_stream_enabled) {
|
||||
// Multi-stream execution mode using memory-based dependency tracking
|
||||
// Note: multi_stream_enabled implies !use_cann_graph (set in graph_compute)
|
||||
// Track data pointers that have pending writes on each stream
|
||||
|
|
@ -2247,23 +2243,6 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
|||
}
|
||||
}
|
||||
|
||||
// Helper lambda to synchronize all active streams to the target stream
|
||||
auto sync_all_to_stream = [&](int target_stream) {
|
||||
if (active_streams.empty()) return;
|
||||
|
||||
// Record events on all active streams
|
||||
for (int s : active_streams) {
|
||||
ACL_CHECK(aclrtRecordEvent(cann_ctx->stream_events[s], cann_ctx->stream(s)));
|
||||
}
|
||||
// Wait for all events on the target stream
|
||||
for (int s : active_streams) {
|
||||
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(target_stream), cann_ctx->stream_events[s]));
|
||||
}
|
||||
// Clear tracking
|
||||
pending_writes.clear();
|
||||
active_streams.clear();
|
||||
};
|
||||
|
||||
// Helper lambda to wait for a specific stream on the target stream
|
||||
auto wait_for_stream = [&](int src_stream, int target_stream) {
|
||||
if (src_stream == target_stream) return;
|
||||
|
|
@ -2273,25 +2252,6 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
|||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
if (opt_fusion) {
|
||||
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
|
||||
// Fusion ops need synchronization - execute on stream 0
|
||||
sync_all_to_stream(0);
|
||||
|
||||
// Execute fused op on stream 0
|
||||
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
|
||||
|
||||
// Track the output
|
||||
void * out_ptr = get_data_ptr(cgraph->nodes[i + 1]);
|
||||
if (out_ptr) {
|
||||
pending_writes[out_ptr] = 0;
|
||||
active_streams.insert(0);
|
||||
}
|
||||
i++;
|
||||
current_stream = 1; // Next node goes to stream 1
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
|
||||
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
|
|
@ -2373,7 +2333,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
|||
// Single-stream execution mode (original behavior)
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
if (opt_fusion) {
|
||||
if (cann_ctx->operator_fusion_enabled) {
|
||||
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
|
||||
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
|
||||
i++;
|
||||
|
|
@ -2438,14 +2398,6 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
|
|||
#ifdef USE_ACL_GRAPH
|
||||
bool use_cann_graph = true;
|
||||
|
||||
// Check if multi-stream execution is enabled (must check before using use_cann_graph)
|
||||
static bool multi_stream_enabled = parse_bool(get_env_as_lowercase("GGML_CANN_MULTI_STREAM").value_or(""));
|
||||
|
||||
// Multi-stream mode is incompatible with ACL graph capture/execution
|
||||
if (multi_stream_enabled) {
|
||||
use_cann_graph = false;
|
||||
}
|
||||
|
||||
if (use_cann_graph) {
|
||||
static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
|
||||
if (!prefill_use_graph) {
|
||||
|
|
@ -2839,11 +2791,7 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_ev
|
|||
*/
|
||||
static void ggml_backend_cann_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) {
|
||||
// Check if graph optimization is disabled via environment variable
|
||||
static bool disable_graph_optimize = [] {
|
||||
const char * env = getenv("GGML_CANN_DISABLE_GRAPH_OPTIMIZE");
|
||||
return env != nullptr;
|
||||
}();
|
||||
|
||||
static bool disable_graph_optimize = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_GRAPH_OPTIMIZE").value_or(""));
|
||||
if (disable_graph_optimize) {
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue