From 7c59ff01f283fbcbf0e6af1d19029b9294a4f845 Mon Sep 17 00:00:00 2001 From: aendk Date: Mon, 15 Dec 2025 10:44:38 +0100 Subject: [PATCH 1/6] Adds CPU-to-CUDA copy capability to ggml_backend_cuda_cpy_tensor_async() --- ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ab0f6fe9ce..88005c969a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -59,6 +59,7 @@ #include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/fill.cuh" #include "ggml.h" +#include "ggml-cpu.h" #include #include @@ -2790,11 +2791,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; - if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) { + //enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA + bool copy_from_cpu = ggml_backend_is_cpu(backend_src) && ggml_backend_buffer_is_host(src->buffer); + + if (!(copy_from_cpu || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { return false; } - if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!(copy_from_cpu || ggml_backend_buffer_is_cuda(src->buffer)) || !ggml_backend_buffer_is_cuda(dst->buffer)) { return false; } @@ -2805,14 +2809,19 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; - if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { + if (!copy_from_cpu && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); #endif return false; } - if (backend_src != backend_dst) { + if (copy_from_cpu) { + if (!cuda_ctx_dst->stream()) { + return false; + } + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream())); + } else if (backend_src != backend_dst) { // copy on src stream if (cuda_ctx_src->device == cuda_ctx_dst->device) { CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream())); From 3f59431df7defb7c29d0c0e6c734fe469b62f29f Mon Sep 17 00:00:00 2001 From: aendk Date: Tue, 16 Dec 2025 16:51:54 +0100 Subject: [PATCH 2/6] Adds function to relax sync requirements between input copies on supported backends (CUDA for now) --- ggml/src/ggml-backend.cpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 8547ecc849..aa2c29f5f0 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -11,6 +11,8 @@ #include "ggml-backend.h" #include "ggml-backend-impl.h" #include "ggml-alloc.h" +#include "ggml-cpu.h" +#include "ggml-cuda.h" // TODO add IFDEFs for CUDA-specific parts #include "ggml-impl.h" #include @@ -736,6 +738,19 @@ struct ggml_backend_sched { int debug_prev_graph_size; }; +static void ggml_backend_synchronize_if_required(ggml_backend_t current_backend) { + // TODO add env-flag check here to auto-disable this change + // CUDA backends have an implicit order between execution and memory operations via the CUDA stream. + // Multiple parallel copies are also possible. + // There is consequently no need to synchronize in between computation and subsequent memcpys + if (ggml_backend_is_cuda(current_backend)) { + return; + } + + // in all other cases, just sync. + ggml_backend_synchronize(current_backend); +} + #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) #define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)] #define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)] @@ -1464,7 +1479,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize(split_backend); + ggml_backend_synchronize_if_required(split_backend); } ggml_backend_tensor_copy(input, input_cpy); } else { @@ -1472,7 +1487,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize(split_backend); + ggml_backend_synchronize_if_required(split_backend); } // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used From ef93b6e6a007854b2b1204a9f1e95a51624e60a0 Mon Sep 17 00:00:00 2001 From: aendk Date: Tue, 16 Dec 2025 17:21:00 +0100 Subject: [PATCH 3/6] Exchanges synchronous copy with async copy function. --- ggml/src/ggml-backend.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index aa2c29f5f0..ffd69ad74b 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1481,7 +1481,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } else { ggml_backend_synchronize_if_required(split_backend); } - ggml_backend_tensor_copy(input, input_cpy); + ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); + ggml_backend_synchronize_if_required(split_backend); } else { // wait for the split backend to finish using the input before overwriting it if (sched->events[split_backend_id][sched->cur_copy] != NULL) { From 0c765efe7247258717f9e94be4d44f9d0972087a Mon Sep 17 00:00:00 2001 From: aendk Date: Tue, 16 Dec 2025 17:41:40 +0100 Subject: [PATCH 4/6] Adds macro guards to allow compilation in non-CUDA builds --- ggml/src/ggml-backend.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index ffd69ad74b..a3574e914f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -12,7 +12,9 @@ #include "ggml-backend-impl.h" #include "ggml-alloc.h" #include "ggml-cpu.h" -#include "ggml-cuda.h" // TODO add IFDEFs for CUDA-specific parts +#ifdef GGML_CUDA +#include "ggml-cuda.h" +#endif // GGML_CUDA #include "ggml-impl.h" #include @@ -740,12 +742,15 @@ struct ggml_backend_sched { static void ggml_backend_synchronize_if_required(ggml_backend_t current_backend) { // TODO add env-flag check here to auto-disable this change + +#ifdef GGML_CUDA // CUDA backends have an implicit order between execution and memory operations via the CUDA stream. // Multiple parallel copies are also possible. // There is consequently no need to synchronize in between computation and subsequent memcpys if (ggml_backend_is_cuda(current_backend)) { return; } +#endif // GGML_CUDA // in all other cases, just sync. ggml_backend_synchronize(current_backend); From 9171aca78250638694d6872ce5c0b6475fcd16da Mon Sep 17 00:00:00 2001 From: aendk Date: Thu, 18 Dec 2025 10:25:14 +0100 Subject: [PATCH 5/6] Reworked backend detection in ggml-backend.cpp to avoid linking conflicts --- ggml/src/ggml-backend.cpp | 46 +++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index a3574e914f..26580720dd 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -11,10 +11,6 @@ #include "ggml-backend.h" #include "ggml-backend-impl.h" #include "ggml-alloc.h" -#include "ggml-cpu.h" -#ifdef GGML_CUDA -#include "ggml-cuda.h" -#endif // GGML_CUDA #include "ggml-impl.h" #include @@ -740,22 +736,38 @@ struct ggml_backend_sched { int debug_prev_graph_size; }; -static void ggml_backend_synchronize_if_required(ggml_backend_t current_backend) { - // TODO add env-flag check here to auto-disable this change +static void ggml_backend_synchronize_if_required(ggml_backend_t current_backend, bool backend_implicitly_synced) { -#ifdef GGML_CUDA - // CUDA backends have an implicit order between execution and memory operations via the CUDA stream. - // Multiple parallel copies are also possible. - // There is consequently no need to synchronize in between computation and subsequent memcpys - if (ggml_backend_is_cuda(current_backend)) { + if (backend_implicitly_synced) { return; } -#endif // GGML_CUDA - // in all other cases, just sync. ggml_backend_synchronize(current_backend); } +static bool ggml_backend_implicitly_synced(ggml_backend_t current_backend) { + /* + * Some backends have implicit synchronization mechanisms, which allows several parallel asynchronous memory copies without data races. + * An example for that is the CUDA backend with the CUDA stream. + * For these backends, we can skip costly explicit synchronizations during compute split scheduling. + */ + + static bool disable_scheduler_sync_opt = (getenv("GGML_SCHED_DISABLE_SYNC_OPT") != nullptr); + + if (disable_scheduler_sync_opt) { + return false; + } + + // To not change any APIs or change what ggml-base links to, we can only detect backends by string matching + auto backend_name = ggml_backend_name(current_backend); + if (strncmp(backend_name, "CUDA", 4) == 0) { + return true; + } + + // sync other backends to ensure correctness + return false; +} + #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) #define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)] #define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)] @@ -1472,6 +1484,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_backend_sched_split * split = &splits[split_id]; int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; + // some backends can avoid costly syncs between async copies + bool backend_implicitly_synced = ggml_backend_implicitly_synced(split_backend); // copy the input tensors to the split backend for (int input_id = 0; input_id < split->n_inputs; input_id++) { @@ -1484,16 +1498,16 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize_if_required(split_backend); + ggml_backend_synchronize_if_required(split_backend, backend_implicitly_synced); } ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); - ggml_backend_synchronize_if_required(split_backend); + ggml_backend_synchronize_if_required(split_backend, backend_implicitly_synced); } else { // wait for the split backend to finish using the input before overwriting it if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); } else { - ggml_backend_synchronize_if_required(split_backend); + ggml_backend_synchronize_if_required(split_backend, backend_implicitly_synced); } // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used From 1233fdda5fb915254bc1b68c326e8a50ad6ba76a Mon Sep 17 00:00:00 2001 From: aendk Date: Fri, 19 Dec 2025 11:30:03 +0100 Subject: [PATCH 6/6] Relax requirement of checks in async CUDA copies from backend and buffer type to just buffer type, to avoid linking issues --- ggml/src/ggml-cuda/ggml-cuda.cu | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 88005c969a..b0a6782d90 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -59,7 +59,6 @@ #include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/fill.cuh" #include "ggml.h" -#include "ggml-cpu.h" #include #include @@ -2792,13 +2791,13 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; //enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA - bool copy_from_cpu = ggml_backend_is_cpu(backend_src) && ggml_backend_buffer_is_host(src->buffer); + bool copy_from_host = ggml_backend_buffer_is_host(src->buffer); - if (!(copy_from_cpu || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { + if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { return false; } - if (!(copy_from_cpu || ggml_backend_buffer_is_cuda(src->buffer)) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!(copy_from_host || ggml_backend_buffer_is_cuda(src->buffer)) || !ggml_backend_buffer_is_cuda(dst->buffer)) { return false; } @@ -2809,14 +2808,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; - if (!copy_from_cpu && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { + if (!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); #endif return false; } - if (copy_from_cpu) { + if (copy_from_host) { if (!cuda_ctx_dst->stream()) { return false; }