From 05cc38ba89215b04fc55c4006456ecef833faab5 Mon Sep 17 00:00:00 2001 From: aendk Date: Mon, 15 Dec 2025 10:44:38 +0100 Subject: [PATCH 01/11] Adds CPU-to-CUDA copy capability to ggml_backend_cuda_cpy_tensor_async() --- ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ed1021469a..5843cf05f7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -61,6 +61,7 @@ #include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/fill.cuh" #include "ggml.h" +#include "ggml-cpu.h" #include #include @@ -2793,11 +2794,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; - if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) { + //enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA + bool copy_from_cpu = ggml_backend_is_cpu(backend_src) && ggml_backend_buffer_is_host(src->buffer); + + if (!(copy_from_cpu || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { return false; } - if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!(copy_from_cpu || ggml_backend_buffer_is_cuda(src->buffer)) || !ggml_backend_buffer_is_cuda(dst->buffer)) { return false; } @@ -2808,14 +2812,19 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; - if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { + if (!copy_from_cpu && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); #endif return false; } - if (backend_src != backend_dst) { + if (copy_from_cpu) { + if (!cuda_ctx_dst->stream()) { + return false; + } + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream())); + } else if (backend_src != backend_dst) { // copy on src stream if (cuda_ctx_src->device == cuda_ctx_dst->device) { CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream())); From 377490f1e27da67c7b91ad0f8165011a921baf12 Mon Sep 17 00:00:00 2001 From: aendk Date: Tue, 16 Dec 2025 16:51:54 +0100 Subject: [PATCH 02/11] Adds function to relax sync requirements between input copies on supported backends (CUDA for now) --- ggml/src/ggml-backend.cpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 1b59924b8c..4d73535efc 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -11,6 +11,8 @@ #include "ggml-backend.h" #include "ggml-backend-impl.h" #include "ggml-alloc.h" +#include "ggml-cpu.h" +#include "ggml-cuda.h" // TODO add IFDEFs for CUDA-specific parts #include "ggml-impl.h" #include @@ -736,6 +738,19 @@ struct ggml_backend_sched { int debug_prev_graph_size; }; +static void ggml_backend_synchronize_if_required(ggml_backend_t current_backend) { + // TODO add 
env-flag check here to auto-disable this change + // CUDA backends have an implicit order between execution and memory operations via the CUDA stream. + // Multiple parallel copies are also possible. + // There is consequently no need to synchronize in between computation and subsequent memcpys + if (ggml_backend_is_cuda(current_backend)) { + return; + } + + // in all other cases, just sync. + ggml_backend_synchronize(current_backend); +} + #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) #define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)] #define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)] @@ -1464,7 +1479,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize(split_backend); + ggml_backend_synchronize_if_required(split_backend); } ggml_backend_tensor_copy(input, input_cpy); } else { @@ -1472,7 +1487,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize(split_backend); + ggml_backend_synchronize_if_required(split_backend); } // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used From e7ad9b382fb980148f04cc826316b390d54c70bb Mon Sep 17 00:00:00 2001 From: aendk Date: Tue, 16 Dec 2025 17:21:00 +0100 Subject: [PATCH 03/11] Exchanges synchronous copy with async copy function. 
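For context, this relies on the difference between a blocking copy and a stream-ordered copy: the asynchronous variant is merely enqueued, and a single synchronization at the read boundary replaces one sync per transfer. A minimal standalone CUDA sketch of that difference (illustrative only, not ggml code; the CHECK macro and the h_src/d_dst buffers are made up for the example):

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    #define CHECK(call) do { cudaError_t err_ = (call); if (err_ != cudaSuccess) { \
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); exit(1); } } while (0)

    int main() {
        const size_t n = 1 << 20;
        float * h_src = nullptr;
        float * d_dst = nullptr;
        CHECK(cudaMallocHost((void **) &h_src, n * sizeof(float)));   // pinned host buffer
        CHECK(cudaMalloc((void **) &d_dst, n * sizeof(float)));

        cudaStream_t stream;
        CHECK(cudaStreamCreate(&stream));

        // blocking copy: the host thread waits until the transfer has finished
        CHECK(cudaMemcpy(d_dst, h_src, n * sizeof(float), cudaMemcpyHostToDevice));

        // stream-ordered copy: only enqueued; kernels launched later on the same
        // stream are ordered after it, so no host-side sync is needed in between
        CHECK(cudaMemcpyAsync(d_dst, h_src, n * sizeof(float), cudaMemcpyHostToDevice, stream));

        // one synchronization at the point where the host actually needs the result
        CHECK(cudaStreamSynchronize(stream));

        CHECK(cudaStreamDestroy(stream));
        CHECK(cudaFree(d_dst));
        CHECK(cudaFreeHost(h_src));
        return 0;
    }

With pinned host memory several input copies can be enqueued back to back on the backend stream; the switch from ggml_backend_tensor_copy() to ggml_backend_tensor_copy_async() below builds on exactly that ordering guarantee.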
--- ggml/src/ggml-backend.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 4d73535efc..07b5587821 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1481,7 +1481,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } else { ggml_backend_synchronize_if_required(split_backend); } - ggml_backend_tensor_copy(input, input_cpy); + ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); + ggml_backend_synchronize_if_required(split_backend); } else { // wait for the split backend to finish using the input before overwriting it if (sched->events[split_backend_id][sched->cur_copy] != NULL) { From bc4cdcaca61985293b84deaafef8d96429830d99 Mon Sep 17 00:00:00 2001 From: aendk Date: Tue, 16 Dec 2025 17:41:40 +0100 Subject: [PATCH 04/11] Adds macro guards to allow compilation in non-CUDA builds --- ggml/src/ggml-backend.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 07b5587821..059415c9b9 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -12,7 +12,9 @@ #include "ggml-backend-impl.h" #include "ggml-alloc.h" #include "ggml-cpu.h" -#include "ggml-cuda.h" // TODO add IFDEFs for CUDA-specific parts +#ifdef GGML_CUDA +#include "ggml-cuda.h" +#endif // GGML_CUDA #include "ggml-impl.h" #include @@ -740,12 +742,15 @@ struct ggml_backend_sched { static void ggml_backend_synchronize_if_required(ggml_backend_t current_backend) { // TODO add env-flag check here to auto-disable this change + +#ifdef GGML_CUDA // CUDA backends have an implicit order between execution and memory operations via the CUDA stream. // Multiple parallel copies are also possible. // There is consequently no need to synchronize in between computation and subsequent memcpys if (ggml_backend_is_cuda(current_backend)) { return; } +#endif // GGML_CUDA // in all other cases, just sync. ggml_backend_synchronize(current_backend); From d75d64bb60c7aa51a2a6fb2309f75dcf4e49e42b Mon Sep 17 00:00:00 2001 From: aendk Date: Thu, 18 Dec 2025 10:25:14 +0100 Subject: [PATCH 05/11] Reworked backend detection in ggml-backend.cpp to avoid linking conflicts --- ggml/src/ggml-backend.cpp | 46 +++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 059415c9b9..cf5548964f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -11,10 +11,6 @@ #include "ggml-backend.h" #include "ggml-backend-impl.h" #include "ggml-alloc.h" -#include "ggml-cpu.h" -#ifdef GGML_CUDA -#include "ggml-cuda.h" -#endif // GGML_CUDA #include "ggml-impl.h" #include @@ -740,22 +736,38 @@ struct ggml_backend_sched { int debug_prev_graph_size; }; -static void ggml_backend_synchronize_if_required(ggml_backend_t current_backend) { - // TODO add env-flag check here to auto-disable this change +static void ggml_backend_synchronize_if_required(ggml_backend_t current_backend, bool backend_implicitly_synced) { -#ifdef GGML_CUDA - // CUDA backends have an implicit order between execution and memory operations via the CUDA stream. - // Multiple parallel copies are also possible. 
- // There is consequently no need to synchronize in between computation and subsequent memcpys - if (ggml_backend_is_cuda(current_backend)) { + if (backend_implicitly_synced) { return; } -#endif // GGML_CUDA - // in all other cases, just sync. ggml_backend_synchronize(current_backend); } +static bool ggml_backend_implicitly_synced(ggml_backend_t current_backend) { + /* + * Some backends have implicit synchronization mechanisms, which allows several parallel asynchronous memory copies without data races. + * An example for that is the CUDA backend with the CUDA stream. + * For these backends, we can skip costly explicit synchronizations during compute split scheduling. + */ + + static bool disable_scheduler_sync_opt = (getenv("GGML_SCHED_DISABLE_SYNC_OPT") != nullptr); + + if (disable_scheduler_sync_opt) { + return false; + } + + // To not change any APIs or change what ggml-base links to, we can only detect backends by string matching + auto backend_name = ggml_backend_name(current_backend); + if (strncmp(backend_name, "CUDA", 4) == 0) { + return true; + } + + // sync other backends to ensure correctness + return false; +} + #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) #define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)] #define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)] @@ -1472,6 +1484,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_backend_sched_split * split = &splits[split_id]; int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; + // some backends can avoid costly syncs between async copies + bool backend_implicitly_synced = ggml_backend_implicitly_synced(split_backend); // copy the input tensors to the split backend for (int input_id = 0; input_id < split->n_inputs; input_id++) { @@ -1484,16 +1498,16 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize_if_required(split_backend); + ggml_backend_synchronize_if_required(split_backend, backend_implicitly_synced); } ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); - ggml_backend_synchronize_if_required(split_backend); + ggml_backend_synchronize_if_required(split_backend, backend_implicitly_synced); } else { // wait for the split backend to finish using the input before overwriting it if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize_if_required(split_backend); + ggml_backend_synchronize_if_required(split_backend, backend_implicitly_synced); } // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used From 58a4d043994f553ba3dd216d975283a949d3af0a Mon Sep 17 00:00:00 2001 From: aendk Date: Fri, 19 Dec 2025 11:30:03 +0100 Subject: [PATCH 06/11] Relax requirement of checks in async CUDA copies from backend and buffer type to just buffer type, to avoid linking issues --- ggml/src/ggml-cuda/ggml-cuda.cu | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu 
index 5843cf05f7..81c26d9fa8 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -61,7 +61,6 @@ #include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/fill.cuh" #include "ggml.h" -#include "ggml-cpu.h" #include #include @@ -2795,13 +2794,13 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; //enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA - bool copy_from_cpu = ggml_backend_is_cpu(backend_src) && ggml_backend_buffer_is_host(src->buffer); + bool copy_from_host = ggml_backend_buffer_is_host(src->buffer); - if (!(copy_from_cpu || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { + if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { return false; } - if (!(copy_from_cpu || ggml_backend_buffer_is_cuda(src->buffer)) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!(copy_from_host || ggml_backend_buffer_is_cuda(src->buffer)) || !ggml_backend_buffer_is_cuda(dst->buffer)) { return false; } @@ -2812,14 +2811,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; - if (!copy_from_cpu && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { + if (!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); #endif return false; } - if (copy_from_cpu) { + if (copy_from_host) { if (!cuda_ctx_dst->stream()) { return false; } From 80b32bdadbfbe7f35e10dfb4abc214638882376e Mon Sep 17 00:00:00 2001 From: aendk Date: Fri, 9 Jan 2026 17:07:19 +0100 Subject: [PATCH 07/11] Minor cleanup --- ggml/src/ggml-backend.cpp | 6 ------ ggml/src/ggml-cuda/ggml-cuda.cu | 6 ++---- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index cf5548964f..c768337db0 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -752,12 +752,6 @@ static bool ggml_backend_implicitly_synced(ggml_backend_t current_backend) { * For these backends, we can skip costly explicit synchronizations during compute split scheduling. 
*/ - static bool disable_scheduler_sync_opt = (getenv("GGML_SCHED_DISABLE_SYNC_OPT") != nullptr); - - if (disable_scheduler_sync_opt) { - return false; - } - // To not change any APIs or change what ggml-base links to, we can only detect backends by string matching auto backend_name = ggml_backend_name(current_backend); if (strncmp(backend_name, "CUDA", 4) == 0) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 81c26d9fa8..7445099043 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2811,7 +2811,8 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; - if (!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { + if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) || + !copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); #endif @@ -2819,9 +2820,6 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ } if (copy_from_host) { - if (!cuda_ctx_dst->stream()) { - return false; - } CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream())); } else if (backend_src != backend_dst) { // copy on src stream From b039e01a1ebde9fc542c869193a0756110221b35 Mon Sep 17 00:00:00 2001 From: aendk Date: Mon, 12 Jan 2026 14:16:01 +0100 Subject: [PATCH 08/11] Makes opt-in to relax use of explicit syncs more general. Backends like vulkan which require a synchronization between HtoD copies and graph execution could also adopt this change now. 
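To illustrate the three modes, a small self-contained C++ sketch follows (not part of the change; sync_mode, copy_split_inputs and the print helpers are made-up stand-ins for ggml_backend_sync_mode and the scheduler loop, and the pre-copy sync for user inputs is left out): explicit mode syncs after every input copy, the write/read-boundary mode only after the last one, and implicit mode not at all.

    #include <cstdio>

    // simplified stand-ins for the ggml_backend_sync_mode values introduced below
    enum sync_mode { MODE_EXPLICIT, MODE_WRITE_READ_BOUNDARY, MODE_IMPLICIT };

    static void sync_backend()    { printf("    sync\n"); }
    static void async_copy(int i) { printf("    async copy of input %d\n", i); }

    // mirrors the per-copy decision of ggml_backend_synchronize_if_required():
    // explicit syncs after every copy, boundary only after the last write, implicit never
    static void copy_split_inputs(sync_mode mode, int n_inputs) {
        for (int i = 0; i < n_inputs; i++) {
            async_copy(i);
            const bool last_input = (i + 1) == n_inputs;
            if (mode == MODE_EXPLICIT || (mode == MODE_WRITE_READ_BOUNDARY && last_input)) {
                sync_backend();
            }
        }
    }

    int main() {
        printf("explicit:\n");            copy_split_inputs(MODE_EXPLICIT, 3);
        printf("write/read boundary:\n"); copy_split_inputs(MODE_WRITE_READ_BOUNDARY, 3);
        printf("implicit:\n");            copy_split_inputs(MODE_IMPLICIT, 3);
        return 0;
    }

Compiling and running this prints the synchronization points each mode would produce for a split with three inputs.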
--- ggml/src/ggml-backend.cpp | 42 +++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index c768337db0..8de584c906 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -670,6 +670,12 @@ static bool ggml_is_view_op(enum ggml_op op) { #define GGML_SCHED_MAX_COPIES 4 #endif +enum ggml_backend_sync_mode { + GGML_SPLIT_SYNC_MODE_IMPLICIT = 0, // splits which can rely on implicit sync mechanisms of its backend like a queue or stream + GGML_SPLIT_SYNC_MODE_WRITE_READ_BOUNDARY = 1, // splits which require only a single explicit sync between the last write and the first read + GGML_SPLIT_SYNC_MODE_EXPLICIT = 2 // splits which require explicit synchronization throughout (default) +}; + struct ggml_backend_sched_split { int backend_id; int i_start; @@ -678,6 +684,7 @@ struct ggml_backend_sched_split { int n_inputs; // graph view of this split struct ggml_cgraph graph; + enum ggml_backend_sync_mode backend_sync_mode = GGML_SPLIT_SYNC_MODE_EXPLICIT; }; struct ggml_backend_sched { @@ -736,30 +743,40 @@ struct ggml_backend_sched { int debug_prev_graph_size; }; -static void ggml_backend_synchronize_if_required(ggml_backend_t current_backend, bool backend_implicitly_synced) { - if (backend_implicitly_synced) { +static void ggml_backend_synchronize_if_required(ggml_backend_sched_split * split, ggml_backend_t current_backend, bool is_final_write = 0) { + + if (split->backend_sync_mode == GGML_SPLIT_SYNC_MODE_IMPLICIT) { + return; + } + + if (split->backend_sync_mode == GGML_SPLIT_SYNC_MODE_WRITE_READ_BOUNDARY && !is_final_write) { return; } ggml_backend_synchronize(current_backend); } -static bool ggml_backend_implicitly_synced(ggml_backend_t current_backend) { +static void ggml_backend_implicitly_synced(ggml_backend_sched_split * split, ggml_backend_t current_backend) { /* * Some backends have implicit synchronization mechanisms, which allows several parallel asynchronous memory copies without data races. * An example for that is the CUDA backend with the CUDA stream. * For these backends, we can skip costly explicit synchronizations during compute split scheduling. 
*/ + if (split->backend_sync_mode != GGML_SPLIT_SYNC_MODE_EXPLICIT) { + // indicates that this function has already changed the default value, no repeat check necessary + return; + } // To not change any APIs or change what ggml-base links to, we can only detect backends by string matching auto backend_name = ggml_backend_name(current_backend); if (strncmp(backend_name, "CUDA", 4) == 0) { - return true; + split->backend_sync_mode = GGML_SPLIT_SYNC_MODE_IMPLICIT; + return; } - // sync other backends to ensure correctness - return false; + // retain default explicit synchronization on other backends for correctness + return; } #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) @@ -1478,30 +1495,33 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_backend_sched_split * split = &splits[split_id]; int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; - // some backends can avoid costly syncs between async copies - bool backend_implicitly_synced = ggml_backend_implicitly_synced(split_backend); + + // determine if backend can avoid costly syncs between HtoD async copies + ggml_backend_implicitly_synced(split, split_backend); + // copy the input tensors to the split backend for (int input_id = 0; input_id < split->n_inputs; input_id++) { ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]); struct ggml_tensor * input = split->inputs[input_id]; struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy); + bool last_input = (input_id + 1) == split->n_inputs; if (input->flags & GGML_TENSOR_FLAG_INPUT) { // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize_if_required(split_backend, backend_implicitly_synced); + ggml_backend_synchronize_if_required(split, split_backend); } ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); - ggml_backend_synchronize_if_required(split_backend, backend_implicitly_synced); + ggml_backend_synchronize_if_required(split, split_backend, last_input); } else { // wait for the split backend to finish using the input before overwriting it if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize_if_required(split_backend, backend_implicitly_synced); + ggml_backend_synchronize_if_required(split, split_backend, last_input); } // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used From 6ed448922c577a853ca5e1b49fc626ac18d0ed06 Mon Sep 17 00:00:00 2001 From: aendk Date: Mon, 12 Jan 2026 15:35:38 +0100 Subject: [PATCH 09/11] Reintroduces stricter check for CPU->CUDA backend async copy via GGML_DEVICE_TYPE_CPU. --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7445099043..edfe45f256 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2794,7 +2794,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_buffer_t buf_dst = dst->view_src ? 
dst->view_src->buffer : dst->buffer; //enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA - bool copy_from_host = ggml_backend_buffer_is_host(src->buffer); + bool copy_from_host = ggml_backend_buffer_is_host(src->buffer) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU; if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { return false; From 43f6684f3d6c6e45eca72906828377f6d887b0fd Mon Sep 17 00:00:00 2001 From: aendk Date: Fri, 16 Jan 2026 10:43:56 +0100 Subject: [PATCH 10/11] Corrects initialization of ggml_backend_sync_mode in ggml_backend_sched_split initialization --- ggml/src/ggml-backend.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 8de584c906..aa922e6756 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -671,9 +671,9 @@ static bool ggml_is_view_op(enum ggml_op op) { #endif enum ggml_backend_sync_mode { - GGML_SPLIT_SYNC_MODE_IMPLICIT = 0, // splits which can rely on implicit sync mechanisms of its backend like a queue or stream + GGML_SPLIT_SYNC_MODE_EXPLICIT = 0, // splits which require explicit synchronization throughout (default) GGML_SPLIT_SYNC_MODE_WRITE_READ_BOUNDARY = 1, // splits which require only a single explicit sync between the last write and the first read - GGML_SPLIT_SYNC_MODE_EXPLICIT = 2 // splits which require explicit synchronization throughout (default) + GGML_SPLIT_SYNC_MODE_IMPLICIT = 2 // splits which can rely on implicit sync mechanisms of its backend like a queue or stream }; struct ggml_backend_sched_split { @@ -684,7 +684,7 @@ struct ggml_backend_sched_split { int n_inputs; // graph view of this split struct ggml_cgraph graph; - enum ggml_backend_sync_mode backend_sync_mode = GGML_SPLIT_SYNC_MODE_EXPLICIT; + enum ggml_backend_sync_mode backend_sync_mode; }; struct ggml_backend_sched { From 50344142f4be66630639abe43ca233df37d887df Mon Sep 17 00:00:00 2001 From: aendk Date: Mon, 19 Jan 2026 17:45:33 +0100 Subject: [PATCH 11/11] Simplifies synchronizations to adhere to `saaasg` pattern. 
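The per-split schedule that remains is: one synchronize before the input copies (when no events are available), a run of asynchronous copies, one synchronize before the graph launch, and the asynchronous graph compute itself — presumably the sync/async/async/async/sync/graph sequence the name refers to. A standalone CUDA sketch of that schedule (illustrative only, not ggml code; CHECK, graph_stub and the h/d buffers are made up):

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <cstdlib>

    #define CHECK(call) do { cudaError_t err_ = (call); if (err_ != cudaSuccess) { \
        fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); exit(1); } } while (0)

    // stand-in for the split's graph compute; any kernel launched on the stream works
    __global__ void graph_stub(float * data, size_t n) {
        size_t i = (size_t) blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            data[i] += 1.0f;
        }
    }

    int main() {
        const int n_inputs = 3;
        const size_t n = 1 << 16;
        cudaStream_t stream;
        CHECK(cudaStreamCreate(&stream));

        float * h[n_inputs];
        float * d[n_inputs];
        for (int i = 0; i < n_inputs; i++) {
            CHECK(cudaMallocHost((void **) &h[i], n * sizeof(float)));
            CHECK(cudaMalloc((void **) &d[i], n * sizeof(float)));
        }

        CHECK(cudaStreamSynchronize(stream));   // s: wait for earlier work on the split backend
        for (int i = 0; i < n_inputs; i++) {    // a a a: stream-ordered input copies, no per-copy sync
            CHECK(cudaMemcpyAsync(d[i], h[i], n * sizeof(float), cudaMemcpyHostToDevice, stream));
        }
        CHECK(cudaStreamSynchronize(stream));   // s: single boundary between the writes and the reads
        graph_stub<<<(unsigned) ((n + 255) / 256), 256, 0, stream>>>(d[0], n);  // g: compute stays async
        CHECK(cudaStreamSynchronize(stream));

        for (int i = 0; i < n_inputs; i++) {
            CHECK(cudaFreeHost(h[i]));
            CHECK(cudaFree(d[i]));
        }
        CHECK(cudaStreamDestroy(stream));
        return 0;
    }

On a single CUDA stream the second synchronize is not needed for ordering between the copies and the kernel; it is kept in the sketch only to mirror the scheduler's structure.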
--- ggml/src/ggml-backend.cpp | 59 +++++---------------------------------- 1 file changed, 7 insertions(+), 52 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index aa922e6756..2519526826 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -670,12 +670,6 @@ static bool ggml_is_view_op(enum ggml_op op) { #define GGML_SCHED_MAX_COPIES 4 #endif -enum ggml_backend_sync_mode { - GGML_SPLIT_SYNC_MODE_EXPLICIT = 0, // splits which require explicit synchronization throughout (default) - GGML_SPLIT_SYNC_MODE_WRITE_READ_BOUNDARY = 1, // splits which require only a single explicit sync between the last write and the first read - GGML_SPLIT_SYNC_MODE_IMPLICIT = 2 // splits which can rely on implicit sync mechanisms of its backend like a queue or stream -}; - struct ggml_backend_sched_split { int backend_id; int i_start; @@ -684,7 +678,6 @@ struct ggml_backend_sched_split { int n_inputs; // graph view of this split struct ggml_cgraph graph; - enum ggml_backend_sync_mode backend_sync_mode; }; struct ggml_backend_sched { @@ -743,42 +736,6 @@ struct ggml_backend_sched { int debug_prev_graph_size; }; - -static void ggml_backend_synchronize_if_required(ggml_backend_sched_split * split, ggml_backend_t current_backend, bool is_final_write = 0) { - - if (split->backend_sync_mode == GGML_SPLIT_SYNC_MODE_IMPLICIT) { - return; - } - - if (split->backend_sync_mode == GGML_SPLIT_SYNC_MODE_WRITE_READ_BOUNDARY && !is_final_write) { - return; - } - - ggml_backend_synchronize(current_backend); -} - -static void ggml_backend_implicitly_synced(ggml_backend_sched_split * split, ggml_backend_t current_backend) { - /* - * Some backends have implicit synchronization mechanisms, which allows several parallel asynchronous memory copies without data races. - * An example for that is the CUDA backend with the CUDA stream. - * For these backends, we can skip costly explicit synchronizations during compute split scheduling. 
- */ - if (split->backend_sync_mode != GGML_SPLIT_SYNC_MODE_EXPLICIT) { - // indicates that this function has already changed the default value, no repeat check necessary - return; - } - - // To not change any APIs or change what ggml-base links to, we can only detect backends by string matching - auto backend_name = ggml_backend_name(current_backend); - if (strncmp(backend_name, "CUDA", 4) == 0) { - split->backend_sync_mode = GGML_SPLIT_SYNC_MODE_IMPLICIT; - return; - } - - // retain default explicit synchronization on other backends for correctness - return; -} - #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) #define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)] #define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)] @@ -1496,32 +1453,26 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; - // determine if backend can avoid costly syncs between HtoD async copies - ggml_backend_implicitly_synced(split, split_backend); - + if (sched->events[split_backend_id][sched->cur_copy] == NULL) { + ggml_backend_synchronize(split_backend); + } // copy the input tensors to the split backend for (int input_id = 0; input_id < split->n_inputs; input_id++) { ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]); struct ggml_tensor * input = split->inputs[input_id]; struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy); - bool last_input = (input_id + 1) == split->n_inputs; if (input->flags & GGML_TENSOR_FLAG_INPUT) { // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); - } else { - ggml_backend_synchronize_if_required(split, split_backend); } ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); - ggml_backend_synchronize_if_required(split, split_backend, last_input); } else { // wait for the split backend to finish using the input before overwriting it if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); - } else { - ggml_backend_synchronize_if_required(split, split_backend, last_input); } // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used @@ -1625,6 +1576,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } + if (sched->events[split_backend_id][sched->cur_copy] == NULL) { + ggml_backend_synchronize(split_backend); + } + if (!sched->callback_eval) { enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); if (ec != GGML_STATUS_SUCCESS) {