diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp
index 36f388ba48..381b5d8664 100644
--- a/ggml/src/ggml-backend-meta.cpp
+++ b/ggml/src/ggml-backend-meta.cpp
@@ -180,7 +180,7 @@ ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev,
 
 ggml_backend_dev_t ggml_backend_meta_device(
         ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) {
-    GGML_ASSERT(n_devs <= 2);
+    GGML_ASSERT(n_devs == 1 || n_devs == 2 || n_devs == 4 || n_devs == 8);
 
     static std::vector> ctxs;
     static std::map meta_devs;
@@ -383,7 +383,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer
         nb[k] = tensor->nb[k];
     }
     if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) {
-        GGML_ASSERT(ne[split_dim] % n_simple_bufs == 0);
+        GGML_ASSERT(ne[split_dim] % (n_simple_bufs*ggml_blck_size(tensor->type)) == 0);
         ne[split_dim] /= n_simple_bufs;
         for (int i = 0; i < GGML_MAX_DIMS; i++) {
             if (tensor->nb[i] > tensor->nb[split_dim]) {
@@ -652,16 +652,17 @@ struct ggml_backend_meta_context {
     struct backend_config {
         ggml_backend_t backend;
 
-        std::vector cgraphs;
-        std::vector nodes;
-        ggml_context * ctx = nullptr;
-        ggml_backend_buffer_t bufs[2] = {nullptr, nullptr}; // Double-buffered to reduce synchronizations.
+        std::vector cgraphs;
+        std::vector nodes;
+        ggml_context * ctx = nullptr;
+        std::vector<ggml_backend_buffer_t> bufs; // Multiple buffers to reduce synchronizations.
 
         backend_config(ggml_backend_t backend) : backend(backend) {}
 
         ~backend_config() {
-            ggml_backend_buffer_free(bufs[1]);
-            ggml_backend_buffer_free(bufs[0]);
+            for (ggml_backend_buffer_t buf : bufs) {
+                ggml_backend_buffer_free(buf);
+            }
             ggml_free(ctx);
         }
     };
@@ -688,6 +689,10 @@ struct ggml_backend_meta_context {
             ggml_backend_free(bc.backend);
         }
     }
+
+    size_t n_reduce_steps() const {
+        return std::ceil(std::log2(backend_configs.size()));
+    }
 };
 
 static const char * ggml_backend_meta_get_name(ggml_backend_t backend) {
@@ -730,7 +735,7 @@ static void ggml_backend_meta_synchronize(ggml_backend_t backend) {
 static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     const size_t n_backends = ggml_backend_meta_n_backends(backend);
     ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context;
-    const size_t n_reduce_steps = std::ceilf(std::log2(n_backends));
+    const size_t n_reduce_steps = backend_ctx->n_reduce_steps();
 
     for (size_t j = 0; j < n_backends; j++) {
         auto & bcj = backend_ctx->backend_configs[j];
@@ -772,25 +777,28 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
     }
 
     ggml_init_params params = {
-        /*.mem_size   =*/ n_subgraphs*2*ggml_tensor_overhead(),
+        /*.mem_size   =*/ n_subgraphs*n_reduce_steps*2*ggml_tensor_overhead(),
         /*.mem_buffer =*/ nullptr,
         /*.no_alloc   =*/ true,
     };
 
+    size_t i_buf = 0; // Alternate between tmp buffers per simple backend to reduce synchronizations.
     // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable:
     bool tmp_buffers_initialized = false;
     auto allreduce_fallback = [&](size_t i) -> ggml_status {
-        size_t i_buf = i % 2; // Alternate between the two tmp buffers per simple backends to reduce synchronizations.
 
         if (!tmp_buffers_initialized) {
             for (size_t j = 0; j < n_backends; j++) {
                 auto & bcj = backend_ctx->backend_configs[j];
-                ggml_backend_buffer_free(bcj.bufs[1]);
-                ggml_backend_buffer_free(bcj.bufs[0]);
+                for (ggml_backend_buffer_t buf : bcj.bufs) {
+                    ggml_backend_buffer_free(buf);
+                }
+                bcj.bufs.clear();
                 ggml_free(bcj.ctx);
                 bcj.ctx = ggml_init(params);
-                bcj.bufs[0] = ggml_backend_alloc_buffer(bcj.backend, max_tmp_size);
-                bcj.bufs[1] = ggml_backend_alloc_buffer(bcj.backend, max_tmp_size);
+                for (size_t k = 0; k < n_reduce_steps + 1; k++) {
+                    bcj.bufs.push_back(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size));
+                }
             }
             tmp_buffers_initialized = true;
         }
@@ -844,6 +852,8 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
             bcj2.cgraphs[i].cgraphs_aux.back().nodes = &bcj2.cgraphs[i].nodes_aux.back();
             bcj1.cgraphs[i].cgraphs_aux.back().n_nodes = 1;
             bcj2.cgraphs[i].cgraphs_aux.back().n_nodes = 1;
+
+            i_buf = (i_buf + 1) % (n_reduce_steps + 1);
        }
 
        for (size_t j = 0; j < n_backends; j++) {
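
Not part of the diff: a minimal standalone sketch of how the backend count, the reduce-step count, and the temporary-buffer rotation via `i_buf` relate. It reuses the `n_reduce_steps` formula and the `(i_buf + 1) % (n_reduce_steps + 1)` rotation from the patch; everything else (the driver loop, the printed output, the interpretation that the extra buffer avoids reusing one that is still in flight) is an illustrative assumption.

```cpp
// Minimal sketch, not part of the diff: backend count -> reduce steps -> tmp buffer rotation.
// Assumes the power-of-two backend counts enforced by the new GGML_ASSERT.
#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
    for (size_t n_backends : {1u, 2u, 4u, 8u}) {
        // Same formula as ggml_backend_meta_context::n_reduce_steps() in the diff:
        const size_t n_reduce_steps = (size_t) std::ceil(std::log2((double) n_backends));
        // The diff allocates n_reduce_steps + 1 buffers per simple backend and rotates
        // through them; presumably so consecutive steps do not reuse the same buffer
        // (our reading of the "reduce synchronizations" comment, not confirmed).
        const size_t n_tmp_bufs = n_reduce_steps + 1;

        size_t i_buf = 0;
        printf("n_backends=%zu: n_reduce_steps=%zu, tmp buffers=%zu, i_buf sequence:",
               n_backends, n_reduce_steps, n_tmp_bufs);
        for (int step = 0; step < 6; step++) {
            printf(" %zu", i_buf);
            i_buf = (i_buf + 1) % n_tmp_bufs; // mirrors the rotation added at the end of the diff
        }
        printf("\n");
    }
    return 0;
}
```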
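
Likewise not from the diff: a toy check of why the split assertion in `ggml_backend_meta_buffer_init_tensor` now multiplies by `ggml_blck_size` — for block-quantized types, each shard along `split_dim` must cover a whole number of blocks. The concrete numbers (4096 elements, 4 shards, block size 32 as for Q4_0) are assumptions chosen for the example.

```cpp
// Illustrative sketch, not part of the diff: the divisibility requirement behind
// the updated GGML_ASSERT. Numbers are example assumptions.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t ne_split_dim  = 4096; // tensor extent along the split dimension
    const int64_t n_simple_bufs = 4;    // simple backends the tensor is sharded across
    const int64_t blck_size     = 32;   // block size of a quantized type, e.g. Q4_0

    // Mirrors: GGML_ASSERT(ne[split_dim] % (n_simple_bufs*ggml_blck_size(tensor->type)) == 0);
    assert(ne_split_dim % (n_simple_bufs*blck_size) == 0);

    // Each shard then covers a whole number of blocks:
    const int64_t ne_per_shard     = ne_split_dim / n_simple_bufs; // 1024 elements
    const int64_t blocks_per_shard = ne_per_shard / blck_size;     // 32 blocks
    assert(blocks_per_shard*blck_size == ne_per_shard);
    return 0;
}
```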