diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp
index 381b5d8664..2304310bf0 100644
--- a/ggml/src/ggml-backend-meta.cpp
+++ b/ggml/src/ggml-backend-meta.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include <cstring>
 #include
 #include
 #include
@@ -668,6 +669,8 @@ struct ggml_backend_meta_context {
     };

     std::string name;
     std::vector backend_configs;
+    size_t max_tmp_size  = 0;
+    size_t max_subgraphs = 0;

     ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) {
         const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev);
@@ -693,6 +696,23 @@
     size_t n_reduce_steps() const {
         return std::ceil(std::log2(backend_configs.size()));
     }
+
+    ggml_tensor * get_next_tensor(size_t j, std::vector<ggml_tensor *> & tensors, ggml_tensor * node) {
+        ggml_tensor * next = tensors[j] == nullptr ? ggml_get_first_tensor(backend_configs[j].ctx)
+                                                   : ggml_get_next_tensor(backend_configs[j].ctx, tensors[j]);
+        if (next == nullptr) {
+            next = ggml_new_tensor_1d(backend_configs[j].ctx, GGML_TYPE_F32, 1);
+        }
+        memset(next, 0, sizeof(ggml_tensor));
+        next->op   = GGML_OP_NONE;
+        next->type = node->type;
+        for (int dim = 0; dim < GGML_MAX_DIMS; dim++) {
+            next->ne[dim] = node->ne[dim];
+            next->nb[dim] = node->nb[dim];
+        }
+        tensors[j] = next;
+        return next;
+    }
 };

 static const char * ggml_backend_meta_get_name(ggml_backend_t backend) {
@@ -776,32 +796,38 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
         GGML_ASSERT(i_start == cgraph->n_nodes);
     }

-    ggml_init_params params = {
-        /*.mem_size   =*/ n_subgraphs*n_reduce_steps*2*ggml_tensor_overhead(),
-        /*.mem_buffer =*/ nullptr,
-        /*.no_alloc   =*/ true,
-    };
+    if (max_tmp_size > backend_ctx->max_tmp_size) {
+        for (size_t j = 0; j < n_backends; j++) {
+            auto & bcj = backend_ctx->backend_configs[j];
+            for (ggml_backend_buffer_t buf : bcj.bufs) {
+                ggml_backend_buffer_free(buf);
+            }
+            bcj.bufs.clear();
+            for (size_t k = 0; k < n_reduce_steps + 1; k++) {
+                bcj.bufs.push_back(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size));
+            }
+        }
+        backend_ctx->max_tmp_size = max_tmp_size;
+    }
+    if (n_subgraphs > backend_ctx->max_subgraphs) {
+        ggml_init_params params = {
+            /*.mem_size   =*/ n_subgraphs*n_reduce_steps*2*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        for (size_t j = 0; j < n_backends; j++) {
+            auto & bcj = backend_ctx->backend_configs[j];
+            ggml_free(bcj.ctx);
+            bcj.ctx = ggml_init(params);
+        }
+        backend_ctx->max_subgraphs = n_subgraphs;
+    }

     size_t i_buf = 0; // Alternate between tmp buffers per simple backend to reduce synchronizations.
+    std::vector<ggml_tensor *> tensors(n_backends, nullptr);

     // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable:
-    bool tmp_buffers_initialized = false;
     auto allreduce_fallback = [&](size_t i) -> ggml_status {
-        if (!tmp_buffers_initialized) {
-            for (size_t j = 0; j < n_backends; j++) {
-                auto & bcj = backend_ctx->backend_configs[j];
-                for (ggml_backend_buffer_t buf : bcj.bufs) {
-                    ggml_backend_buffer_free(buf);
-                }
-                bcj.bufs.clear();
-                ggml_free(bcj.ctx);
-                bcj.ctx = ggml_init(params);
-                for (size_t k = 0; k < n_reduce_steps + 1; k++) {
-                    bcj.bufs.push_back(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size));
-                }
-            }
-            tmp_buffers_initialized = true;
-        }

         for (size_t j = 0; j < n_backends; j++) {
             auto & bcj = backend_ctx->backend_configs[j];
@@ -826,8 +852,8 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
             GGML_ASSERT(ggml_is_contiguous(node1));
             GGML_ASSERT(ggml_is_contiguous(node2));

-            ggml_tensor * node_tmp_1 = ggml_dup_tensor(bcj1.ctx, node1);
-            ggml_tensor * node_tmp_2 = ggml_dup_tensor(bcj2.ctx, node2);
+            ggml_tensor * node_tmp_1 = backend_ctx->get_next_tensor(j,       tensors, node1);
+            ggml_tensor * node_tmp_2 = backend_ctx->get_next_tensor(j_other, tensors, node2);
             node_tmp_1->buffer = bcj1.bufs[i_buf];
             node_tmp_2->buffer = bcj2.bufs[i_buf];
             node_tmp_1->data = ggml_backend_buffer_get_base(bcj1.bufs[i_buf]);
@@ -837,8 +863,18 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,

             ggml_backend_tensor_shfl_async(bcj1.backend, bcj2.backend, node1, node2, node_tmp_1, node_tmp_2);

-            ggml_tensor * node_red_1 = ggml_add_inplace(bcj1.ctx, node1, node_tmp_1);
-            ggml_tensor * node_red_2 = ggml_add_inplace(bcj2.ctx, node2, node_tmp_2);
+            ggml_tensor * node_red_1 = backend_ctx->get_next_tensor(j,       tensors, node1);
+            ggml_tensor * node_red_2 = backend_ctx->get_next_tensor(j_other, tensors, node2);
+            node_red_1->view_src  = node1->view_src == nullptr ? node1 : node1->view_src;
+            node_red_2->view_src  = node2->view_src == nullptr ? node2 : node2->view_src;
+            node_red_1->view_offs = node1->view_offs;
+            node_red_2->view_offs = node2->view_offs;
+            node_red_1->op = GGML_OP_ADD;
+            node_red_2->op = GGML_OP_ADD;
+            node_red_1->src[0] = node1;
+            node_red_2->src[0] = node2;
+            node_red_1->src[1] = node_tmp_1;
+            node_red_2->src[1] = node_tmp_2;
             node_red_1->flags |= GGML_TENSOR_FLAG_COMPUTE;
             node_red_2->flags |= GGML_TENSOR_FLAG_COMPUTE;
             ggml_backend_view_init(node_red_1);
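Note on the allocation change: the temporary reduce buffers and the per-backend tensor contexts are no longer torn down and rebuilt lazily inside allreduce_fallback on every graph compute. They are now cached in ggml_backend_meta_context and reallocated only when a new graph exceeds the high-water marks max_tmp_size / max_subgraphs. A minimal standalone sketch of that grow-only pattern follows; Buffer, buffer_alloc, and ensure_tmp_buffers are illustrative stand-ins, not ggml API:

    #include <cstdlib>
    #include <vector>

    // Illustrative stand-ins for ggml_backend_alloc_buffer / ggml_backend_buffer_free.
    struct Buffer { void * data; size_t size; };
    static Buffer * buffer_alloc(size_t size) { return new Buffer{malloc(size), size}; }
    static void     buffer_free(Buffer * buf) { if (buf) { free(buf->data); delete buf; } }

    struct BackendConfig {
        std::vector<Buffer *> bufs;
    };

    struct MetaContext {
        std::vector<BackendConfig> backend_configs;
        size_t max_tmp_size = 0; // high-water mark, mirrors the new context member

        // Reallocate the per-backend tmp buffers only when the current graph needs
        // more space than any previous graph did; otherwise reuse the cached ones.
        void ensure_tmp_buffers(size_t tmp_size, size_t n_bufs_per_backend) {
            if (tmp_size <= max_tmp_size) {
                return; // cached buffers are already large enough
            }
            for (BackendConfig & bc : backend_configs) {
                for (Buffer * buf : bc.bufs) {
                    buffer_free(buf);
                }
                bc.bufs.clear();
                for (size_t k = 0; k < n_bufs_per_backend; k++) {
                    bc.bufs.push_back(buffer_alloc(tmp_size));
                }
            }
            max_tmp_size = tmp_size;
        }
    };

The same trade-off as in the patch applies: buffers never shrink, so memory usage is bounded by the largest graph seen so far, in exchange for an allocation-free steady-state compute path.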
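For context on n_reduce_steps() returning ceil(log2(n_backends)): each reduce step pairs every backend j with a partner j_other, shuffles the partial sums across the pair (ggml_backend_tensor_shfl_async) and adds them in place, so after log2(n) rounds every backend holds the full sum. The sketch below prints such a pairing schedule assuming the standard XOR-based (recursive-doubling) partner choice and a power-of-two backend count; the hunks shown here do not reveal how j_other is actually derived:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int n = 8; // number of simple backends, assumed a power of two here
        const int steps = (int) std::ceil(std::log2((double) n));
        for (int k = 0; k < steps; k++) {
            printf("step %d:", k);
            for (int j = 0; j < n; j++) {
                const int j_other = j ^ (1 << k); // assumed exchange partner for this step
                if (j < j_other) {
                    printf(" %d<->%d", j, j_other);
                }
            }
            printf("\n");
        }
        return 0;
    }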
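Finally, get_next_tensor is a cursor-style recycler: it walks the tensors already allocated in the backend's context (ggml_get_first_tensor / ggml_get_next_tensor), reuses the next one if present, allocates a fresh tensor only once the context runs out, and re-stamps the reused tensor with the shape and type of the node being reduced. Since the tensors vector is reset to nullptr on every compute, the same context tensors are recycled graph after graph. A generic sketch of this reuse-before-allocate cursor, independent of ggml (RecyclingPool is a hypothetical name):

    #include <iterator>
    #include <list>

    // Reuse-before-allocate cursor: hand out previously created slots in order,
    // and only grow the pool once all existing slots have been handed out.
    template <typename T>
    struct RecyclingPool {
        std::list<T> slots;
        typename std::list<T>::iterator cursor = slots.end();

        void reset() { cursor = slots.begin(); } // analogous to re-creating the tensors vector

        T & next() {
            if (cursor == slots.end()) {
                slots.emplace_back();              // pool exhausted: allocate a new slot
                cursor = std::prev(slots.end());
            }
            T & slot = *cursor;
            ++cursor;
            slot = T{};                            // re-initialize before reuse (like the memset)
            return slot;
        }
    };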