add support for 4/8 GPUs

Johannes Gäßler 2026-02-07 19:18:36 +01:00
parent 4b8aa26650
commit 2ffa49decc
1 changed file with 25 additions and 15 deletions


@@ -180,7 +180,7 @@ ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev,
 ggml_backend_dev_t ggml_backend_meta_device(
         ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
-    GGML_ASSERT(n_devs <= 2);
+    GGML_ASSERT(n_devs == 1 || n_devs == 2 || n_devs == 4 || n_devs == 8);
     static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs;
     static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs;
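The loosened assertion trades the old hard cap of two devices for any power of two up to eight. Power-of-two counts are what make a butterfly-style reduction line up cleanly: at every step each backend has exactly one partner. A minimal standalone sketch of such a pairing, assuming XOR-based partner selection; this schedule is one standard butterfly pattern, not something shown in this hunk:

#include <cstdio>

int main() {
    const int n = 8; // must be a power of two, as the new GGML_ASSERT enforces
    for (int s = 0; (1 << s) < n; s++) {       // ceil(log2(n)) steps in total
        for (int j = 0; j < n; j++) {
            // At step s, backend j exchanges partial sums with j ^ (1 << s).
            printf("step %d: backend %d <-> backend %d\n", s, j, j ^ (1 << s));
        }
    }
    return 0;
}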
@@ -383,7 +383,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
         nb[k] = tensor->nb[k];
     }
     if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
-        GGML_ASSERT(ne[split_dim] % n_simple_bufs == 0);
+        GGML_ASSERT(ne[split_dim] % (n_simple_bufs*ggml_blck_size(tensor->type)) == 0);
         ne[split_dim] /= n_simple_bufs;
         for (int i = 0; i < GGML_MAX_DIMS; i++) {
             if (tensor->nb[i] > tensor->nb[split_dim]) {
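The stricter divisibility check accounts for quantized tensor types, which store elements in fixed-size blocks (for example, Q4_0 groups 32 elements per block). Splitting ne[split_dim] across n_simple_bufs devices must leave every shard a whole number of blocks, hence the extra factor of ggml_blck_size(tensor->type). A standalone sketch of the same condition; can_split is a hypothetical helper, not part of ggml:

#include <cassert>
#include <cstdint>

// Each device must receive whole quantization blocks, so the split
// dimension must be divisible by n_devs * block size.
static bool can_split(int64_t ne, int64_t n_devs, int64_t blck_size) {
    return ne % (n_devs * blck_size) == 0;
}

int main() {
    assert( can_split(4096, 8, 32)); // 4096/8 = 512 elements = 16 whole blocks per device
    assert(!can_split(4096, 8, 33)); // a block would straddle a shard boundary
    return 0;
}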
@@ -652,16 +652,17 @@ struct ggml_backend_meta_context {
     struct backend_config {
         ggml_backend_t backend;
-        std::vector<cgraph_config> cgraphs;
-        std::vector<ggml_tensor *> nodes;
-        ggml_context * ctx = nullptr;
-        ggml_backend_buffer_t bufs[2] = {nullptr, nullptr}; // Double-buffered to reduce synchronizations.
+        std::vector<cgraph_config>         cgraphs;
+        std::vector<ggml_tensor *>         nodes;
+        ggml_context *                     ctx = nullptr;
+        std::vector<ggml_backend_buffer_t> bufs; // Multiple buffers to reduce synchronizations.

         backend_config(ggml_backend_t backend) : backend(backend) {}

         ~backend_config() {
-            ggml_backend_buffer_free(bufs[1]);
-            ggml_backend_buffer_free(bufs[0]);
+            for (ggml_backend_buffer_t buf : bufs) {
+                ggml_backend_buffer_free(buf);
+            }
             ggml_free(ctx);
         }
     };
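Replacing the fixed two-element array with a vector is what lets the buffer count scale with the device count: the allocation loop later in this diff creates one temporary buffer per reduce step plus one extra. A compile-time sanity check of that sizing, assuming power-of-two backend counts; reduce_steps is a hypothetical stand-in for the new n_reduce_steps() member:

#include <cstddef>

constexpr size_t reduce_steps(size_t n) { // ceil(log2(n)) for powers of two
    size_t s = 0;
    while ((size_t(1) << s) < n) {
        s++;
    }
    return s;
}

static_assert(reduce_steps(2) + 1 == 2, "2 GPUs: the old fixed pair of buffers");
static_assert(reduce_steps(4) + 1 == 3, "4 GPUs: three tmp buffers per backend");
static_assert(reduce_steps(8) + 1 == 4, "8 GPUs: four tmp buffers per backend");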
@@ -688,6 +689,10 @@ struct ggml_backend_meta_context {
             ggml_backend_free(bc.backend);
         }
     }
+
+    size_t n_reduce_steps() const {
+        return std::ceil(std::log2(backend_configs.size()));
+    }
 };

 static const char * ggml_backend_meta_get_name(ggml_backend_t backend) {
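For the device counts the assertion admits (1, 2, 4, 8), log2 is exact in double precision, so std::ceil(std::log2(...)) returns exactly 0, 1, 2, or 3 with no rounding hazard; centralizing it in a const member also tidies the replaced call below, which narrowed std::log2's double result through std::ceilf. A quick check of the values, with an integer-only equivalent (a sketch for comparison, not part of the commit):

#include <cmath>
#include <cstdio>

static size_t reduce_steps_int(size_t n) { // integer-only equivalent
    size_t s = 0;
    while ((size_t(1) << s) < n) {
        s++;
    }
    return s;
}

int main() {
    for (size_t n : {1, 2, 4, 8}) {
        printf("n_backends=%zu -> n_reduce_steps=%zu (float: %zu)\n",
               n, reduce_steps_int(n), (size_t) std::ceil(std::log2((double) n)));
    }
    return 0;
}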
@@ -730,7 +735,7 @@ static void ggml_backend_meta_synchronize(ggml_backend_t backend) {
 static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     const size_t n_backends = ggml_backend_meta_n_backends(backend);
     ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;
-    const size_t n_reduce_steps = std::ceilf(std::log2(n_backends));
+    const size_t n_reduce_steps = backend_ctx->n_reduce_steps();

     for (size_t j = 0; j < n_backends; j++) {
         auto & bcj = backend_ctx->backend_configs[j];
@@ -772,25 +777,28 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
     }

     ggml_init_params params = {
-        /*.mem_size   =*/ n_subgraphs*2*ggml_tensor_overhead(),
+        /*.mem_size   =*/ n_subgraphs*n_reduce_steps*2*ggml_tensor_overhead(),
         /*.mem_buffer =*/ nullptr,
         /*.no_alloc   =*/ true,
     };

+    size_t i_buf = 0; // Alternate between tmp buffers per simple backend to reduce synchronizations.
     // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable:
     bool tmp_buffers_initialized = false;
     auto allreduce_fallback = [&](size_t i) -> ggml_status {
-        size_t i_buf = i % 2; // Alternate between the two tmp buffers per simple backends to reduce synchronizations.

         if (!tmp_buffers_initialized) {
             for (size_t j = 0; j < n_backends; j++) {
                 auto & bcj = backend_ctx->backend_configs[j];
-                ggml_backend_buffer_free(bcj.bufs[1]);
-                ggml_backend_buffer_free(bcj.bufs[0]);
+                for (ggml_backend_buffer_t buf : bcj.bufs) {
+                    ggml_backend_buffer_free(buf);
+                }
+                bcj.bufs.clear();
                 ggml_free(bcj.ctx);
                 bcj.ctx = ggml_init(params);
-                bcj.bufs[0] = ggml_backend_alloc_buffer(bcj.backend, max_tmp_size);
-                bcj.bufs[1] = ggml_backend_alloc_buffer(bcj.backend, max_tmp_size);
+                for (size_t k = 0; k < n_reduce_steps + 1; k++) {
+                    bcj.bufs.push_back(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size));
+                }
             }
             tmp_buffers_initialized = true;
         }
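The context sizing changes in step with the buffer count: where the old fallback reserved bookkeeping space for two auxiliary tensors per subgraph (a single reduce step), the new mem_size multiplies by n_reduce_steps as well. A hedged reading of the arithmetic, assuming two aux tensors per subgraph per reduce step; meta_ctx_mem_size and tensor_overhead are stand-ins for illustration:

#include <cstddef>

// Sketch of the new reservation; tensor_overhead stands in for
// ggml_tensor_overhead().
size_t meta_ctx_mem_size(size_t n_subgraphs, size_t n_reduce_steps, size_t tensor_overhead) {
    return n_subgraphs * n_reduce_steps * 2 * tensor_overhead;
}
// With 2 GPUs (n_reduce_steps = 1) this reduces to the old formula;
// with 8 GPUs (n_reduce_steps = 3) it reserves 3x the tensor slots.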
@@ -844,6 +852,8 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
             bcj2.cgraphs[i].cgraphs_aux.back().nodes = &bcj2.cgraphs[i].nodes_aux.back();
             bcj1.cgraphs[i].cgraphs_aux.back().n_nodes = 1;
             bcj2.cgraphs[i].cgraphs_aux.back().n_nodes = 1;
+
+            i_buf = (i_buf + 1) % (n_reduce_steps + 1);
         }

     for (size_t j = 0; j < n_backends; j++) {
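The final piece is the rotation itself: i_buf now cycles through n_reduce_steps + 1 temporary buffers instead of alternating between two, so consecutive allreduce launches never reuse a buffer that an in-flight step might still be reading. A small standalone trace of the rotation for 8 GPUs (illustration only, outside ggml):

#include <cstddef>
#include <cstdio>

int main() {
    const size_t n_reduce_steps = 3;          // 8 backends: ceil(log2(8))
    const size_t n_bufs = n_reduce_steps + 1; // 4 tmp buffers per backend
    size_t i_buf = 0;
    for (size_t launch = 0; launch < 8; launch++) {
        printf("launch %zu uses tmp buffer %zu\n", launch, i_buf);
        i_buf = (i_buf + 1) % n_bufs;         // same update as in the diff
    }
    return 0;
}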