diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index 1378ba9f5b..4e2f1ab0f2 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -1009,8 +1009,8 @@ public:
     bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);

     struct stored_graph {
-        ggml_context_ptr ctx_ptr;
-        ggml_cgraph * graph;
+        std::vector<uint8_t> buffer;
+        ggml_cgraph * graph;
     };

 private:
@@ -1518,10 +1518,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
     LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);

     size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
-
+    if (stored_graphs[device].buffer.size() < buf_size) {
+        stored_graphs[device].buffer.resize(buf_size);
+    }
     struct ggml_init_params params = {
         /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ NULL,
+        /*.mem_buffer =*/ stored_graphs[device].buffer.data(),
         /*.no_alloc   =*/ true,
     };
     ggml_context_ptr ctx_ptr { ggml_init(params) };
@@ -1551,7 +1553,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
     }
     ggml_status status = ggml_backend_graph_compute(backends[device], graph);
     GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
-    stored_graphs[device].ctx_ptr.swap(ctx_ptr);
     stored_graphs[device].graph = graph;
     return true;
 }