Andreas Kieslinger 2025-12-17 05:51:07 +02:00 committed by GitHub
commit ec8f17cd7d
4 changed files with 94 additions and 68 deletions

View File

@@ -46,15 +46,16 @@ extern "C" {
         // (optional) initialize a tensor in the buffer (eg. add tensor extras)
         enum ggml_status (*init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
         // tensor data access
-        void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
+        void (*memset_tensor)    (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
         void (*set_tensor)       (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void (*set_tensor_async) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
         void (*get_tensor)       (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
         // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported)
         bool (*cpy_tensor)       (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
         // clear the entire buffer
         void (*clear)            (ggml_backend_buffer_t buffer, uint8_t value);
         // (optional) reset any internal state due to tensor initialization, such as tensor extras
         void (*reset)            (ggml_backend_buffer_t buffer);
     };

     struct ggml_backend_buffer {
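
Context for the hunk above (editorial note, not part of the commit): set_tensor_async is an optional slot with the same signature as set_tensor, so callers must check it for NULL before dispatching through it. A minimal standalone sketch of that optional-vtable-slot pattern, with hypothetical names unrelated to the ggml API:

// Standalone sketch of the optional-vtable-slot pattern: a struct of function
// pointers where one entry may be NULL, and the caller checks before dispatch.
#include <stdio.h>
#include <stddef.h>

typedef struct {
    void (*set_tensor)      (const void * data, size_t size);  // required
    void (*set_tensor_async)(const void * data, size_t size);  // optional, may be NULL
} buffer_iface;

static void set_sync (const void * data, size_t size) { (void)data; printf("sync copy of %zu bytes\n",  size); }
static void set_async(const void * data, size_t size) { (void)data; printf("async copy of %zu bytes\n", size); }

static void upload(const buffer_iface * iface, const void * data, size_t size) {
    // prefer the async entry point when the backend provides one
    if (iface->set_tensor_async != NULL) {
        iface->set_tensor_async(data, size);
    } else {
        iface->set_tensor(data, size);
    }
}

int main(void) {
    buffer_iface sync_only  = { set_sync, NULL      };
    buffer_iface async_able = { set_sync, set_async };
    int payload[4] = {0};
    upload(&sync_only,  payload, sizeof(payload));  // falls back to the sync path
    upload(&async_able, payload, sizeof(payload));  // dispatches the async path
    return 0;
}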

View File

@@ -21,6 +21,7 @@
 #include <string.h>
 #include <algorithm>
 #include <vector>
+#include <unordered_map>

 #ifdef __APPLE__
 #include <sys/types.h>
@@ -289,7 +290,15 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz
     GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");

-    buf->iface.set_tensor(buf, tensor, data, offset, size);
+    // do not synchronize directly after dispatching async tensor copies
+    static bool disable_sync_optimization = (getenv("GGML_CUDA_DISABLE_SYNC_OPTIMIZATION") != nullptr);
+    if (!disable_sync_optimization && buf->iface.set_tensor_async != NULL) {
+        buf->iface.set_tensor_async(buf, tensor, data, offset, size);
+    } else {
+        buf->iface.set_tensor(buf, tensor, data, offset, size);
+    }
 }

 void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
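
Editorial note on the gate above: getenv() is evaluated once and cached in a function-local static, and the async path is only taken when the buffer actually provides set_tensor_async. A minimal sketch of the same check-once idiom in isolation (the wrapper function name is illustrative, not from this commit):

// Standalone sketch of the opt-out idiom above: the environment variable is read once,
// on the first call, and the cached result is reused for the lifetime of the process.
#include <cstdio>
#include <cstdlib>

static bool sync_optimization_disabled() {
    // initialization of a function-local static runs once (thread-safe since C++11)
    static bool disabled = (std::getenv("GGML_CUDA_DISABLE_SYNC_OPTIMIZATION") != nullptr);
    return disabled;
}

int main() {
    std::printf("sync optimization disabled: %s\n", sync_optimization_disabled() ? "yes" : "no");
    return 0;
}

Running with GGML_CUDA_DISABLE_SYNC_OPTIMIZATION=1 in the environment therefore forces every ggml_backend_tensor_set call back onto the synchronous set_tensor path.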
@@ -602,15 +611,16 @@ static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_
 }

 static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = {
     /* .free_buffer      = */ ggml_backend_multi_buffer_free_buffer,
     /* .get_base         = */ NULL,
     /* .init_tensor      = */ NULL,
     /* .memset_tensor    = */ NULL,
     /* .set_tensor       = */ NULL,
+    /* .set_tensor_async = */ NULL,
     /* .get_tensor       = */ NULL,
     /* .cpy_tensor       = */ NULL,
     /* .clear            = */ ggml_backend_multi_buffer_clear,
     /* .reset            = */ NULL,
 };

 ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) {
@@ -1453,6 +1463,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         int split_backend_id = split->backend_id;
         ggml_backend_t split_backend = sched->backends[split_backend_id];
+        std::unordered_map<ggml_backend_t, bool> backends_to_sync;

         // copy the input tensors to the split backend
         for (int input_id = 0; input_id < split->n_inputs; input_id++) {
             ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
@@ -1464,7 +1475,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
                     ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
                 } else {
-                    ggml_backend_synchronize(split_backend);
+                    backends_to_sync[split_backend] = true;
                 }
                 ggml_backend_tensor_copy(input, input_cpy);
             } else {
@@ -1472,7 +1483,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                 if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
                     ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
                 } else {
-                    ggml_backend_synchronize(split_backend);
+                    backends_to_sync[split_backend] = true;
                 }

                 // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used
@@ -1487,7 +1498,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                     const int64_t n_expert   = node->op == GGML_OP_MUL_MAT_ID ? input->ne[2] : input->ne[1];
                     const size_t expert_size = node->op == GGML_OP_MUL_MAT_ID ? input->nb[2] : input->nb[1];

-                    ggml_backend_synchronize(input_backend);
+                    backends_to_sync[input_backend] = true;

                     // get the ids
                     ggml_tensor * ids_tensor = node->src[2];
@@ -1506,7 +1517,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                     if (ids_tensor != prev_ids_tensor) {
                         ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
                         ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
-                        ggml_backend_synchronize(ids_backend);
+                        backends_to_sync[ids_backend] = true;

                         // find the used experts
                         used_ids.clear();
@@ -1564,11 +1575,11 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
                 // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events
                 // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface
                 if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) {
-                    ggml_backend_synchronize(input_backend);
+                    backends_to_sync[input_backend] = true;
                     if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
                         ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
                     } else {
-                        ggml_backend_synchronize(split_backend);
+                        backends_to_sync[split_backend] = true;
                     }
                     ggml_backend_tensor_copy(input, input_cpy);
                 }
@@ -1576,6 +1587,11 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
             }
         }

+        // sync in bulk instead of between async copies
+        for (auto& elem : backends_to_sync) {
+            ggml_backend_synchronize(elem.first);
+        }
+
         if (!sched->callback_eval) {
             enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
             if (ec != GGML_STATUS_SUCCESS) {
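
Editorial note on the scheduler hunks above: rather than calling ggml_backend_synchronize after every input copy, the scheduler now records each touched backend in backends_to_sync (an unordered_map used as a set) and synchronizes each unique backend exactly once before launching the split's graph compute. A standalone sketch of that dedup-then-flush pattern, with placeholder types instead of the real ggml scheduler state:

// Standalone sketch: instead of paying for a synchronization after every copy,
// record which "backends" were touched and sync each one exactly once at the end.
// All names here are illustrative, not ggml API.
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

struct fake_backend { std::string name; };

static void fake_synchronize(const fake_backend * b) {
    std::printf("synchronizing %s\n", b->name.c_str());
}

int main() {
    fake_backend cpu{"cpu"}, gpu{"gpu"};
    std::vector<fake_backend *> copy_sources = { &gpu, &cpu, &gpu, &gpu, &cpu };

    std::unordered_map<fake_backend *, bool> backends_to_sync;
    for (fake_backend * b : copy_sources) {
        // ... an async copy from b would be dispatched here ...
        backends_to_sync[b] = true;   // the map deduplicates repeated backends
    }

    // sync in bulk instead of between async copies
    for (auto & elem : backends_to_sync) {
        fake_synchronize(elem.first); // runs once per unique backend
    }
    return 0;
}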
@@ -2165,27 +2181,29 @@ static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
 }

 static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
     /* .free_buffer      = */ ggml_backend_cpu_buffer_free_buffer,
     /* .get_base         = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor      = */ NULL, // no initialization required
     /* .memset_tensor    = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor       = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .set_tensor_async = */ NULL,
     /* .get_tensor       = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor       = */ ggml_backend_cpu_buffer_cpy_tensor,
     /* .clear            = */ ggml_backend_cpu_buffer_clear,
     /* .reset            = */ NULL,
 };

 static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
     /* .free_buffer      = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
     /* .get_base         = */ ggml_backend_cpu_buffer_get_base,
     /* .init_tensor      = */ NULL, // no initialization required
     /* .memset_tensor    = */ ggml_backend_cpu_buffer_memset_tensor,
     /* .set_tensor       = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .set_tensor_async = */ NULL,
     /* .get_tensor       = */ ggml_backend_cpu_buffer_get_tensor,
     /* .cpy_tensor       = */ ggml_backend_cpu_buffer_cpy_tensor,
     /* .clear            = */ ggml_backend_cpu_buffer_clear,
     /* .reset            = */ NULL,
 };

 // CPU backend buffer type

View File

@@ -105,15 +105,16 @@ static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
 }

 static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
     /* .free_buffer      = */ ggml_backend_amx_buffer_free_buffer,
     /* .get_base         = */ ggml_backend_amx_buffer_get_base,
     /* .init_tensor      = */ ggml_backend_amx_buffer_init_tensor,
     /* .memset_tensor    = */ ggml_backend_amx_buffer_memset_tensor,
     /* .set_tensor       = */ ggml_backend_amx_buffer_set_tensor,
+    /* .set_tensor_async = */ nullptr,
     /* .get_tensor       = */ nullptr,
     /* .cpy_tensor       = */ nullptr,
     /* .clear            = */ ggml_backend_amx_buffer_clear,
     /* .reset            = */ nullptr,
 };

 static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {

View File

@@ -623,11 +623,15 @@ static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer,
     CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
 }

-static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+static void ggml_backend_cuda_buffer_set_tensor_async(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

     ggml_cuda_set_device(ctx->device);
     CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+}
+
+static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_set_tensor_async(buffer, tensor, data, offset, size);
     CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
 }
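
Editorial note on the hunk above: the synchronous ggml_backend_cuda_buffer_set_tensor is now a thin wrapper that reuses the async path and then waits on cudaStreamPerThread, while the async variant returns as soon as the copy is enqueued. A minimal standalone CUDA sketch of that enqueue-then-synchronize-later pattern (illustrative only, not the ggml code; pinned host memory is used here because cudaMemcpyAsync from pageable memory may not overlap, and the source buffer must stay valid until the stream is synchronized):

// Standalone CUDA sketch: enqueue a host-to-device copy on the per-thread default
// stream, do other work, then synchronize once before the data is actually needed.
// Compile with nvcc; error handling is reduced to a macro.
#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_CHECK(call)                                                        \
    do {                                                                        \
        cudaError_t err_ = (call);                                              \
        if (err_ != cudaSuccess) {                                              \
            std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); \
            return 1;                                                           \
        }                                                                       \
    } while (0)

int main() {
    const size_t n = 1 << 20;
    float * host = nullptr;
    float * dev  = nullptr;
    CUDA_CHECK(cudaMallocHost((void **)&host, n * sizeof(float)));  // pinned, so the copy can overlap
    CUDA_CHECK(cudaMalloc((void **)&dev, n * sizeof(float)));

    // enqueue the upload; this returns before the copy has finished
    CUDA_CHECK(cudaMemcpyAsync(dev, host, n * sizeof(float), cudaMemcpyHostToDevice, cudaStreamPerThread));

    // ... other host-side work can run here while the copy is in flight ...

    // one explicit synchronization before the device data is consumed
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));

    CUDA_CHECK(cudaFree(dev));
    CUDA_CHECK(cudaFreeHost(host));
    return 0;
}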
@@ -669,15 +673,16 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
 }

 static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
     /* .free_buffer      = */ ggml_backend_cuda_buffer_free_buffer,
     /* .get_base         = */ ggml_backend_cuda_buffer_get_base,
     /* .init_tensor      = */ ggml_backend_cuda_buffer_init_tensor,
     /* .memset_tensor    = */ ggml_backend_cuda_buffer_memset_tensor,
     /* .set_tensor       = */ ggml_backend_cuda_buffer_set_tensor,
+    /* .set_tensor_async = */ ggml_backend_cuda_buffer_set_tensor_async,
     /* .get_tensor       = */ ggml_backend_cuda_buffer_get_tensor,
     /* .cpy_tensor       = */ ggml_backend_cuda_buffer_cpy_tensor,
     /* .clear            = */ ggml_backend_cuda_buffer_clear,
     /* .reset            = */ NULL,
 };

 // cuda buffer type
@@ -981,15 +986,16 @@ static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, u
 }

 static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
     /* .free_buffer      = */ ggml_backend_cuda_split_buffer_free_buffer,
     /* .get_base         = */ ggml_backend_cuda_split_buffer_get_base,
     /* .init_tensor      = */ ggml_backend_cuda_split_buffer_init_tensor,
     /* .memset_tensor    = */ NULL,
     /* .set_tensor       = */ ggml_backend_cuda_split_buffer_set_tensor,
+    /* .set_tensor_async = */ NULL,
     /* .get_tensor       = */ ggml_backend_cuda_split_buffer_get_tensor,
     /* .cpy_tensor       = */ NULL,
     /* .clear            = */ ggml_backend_cuda_split_buffer_clear,
     /* .reset            = */ NULL,
 };

 // cuda split buffer type