rpc : cache and reuse compute graphs (#15405)

Store the last computed graph and reuse it when possible.
Also, do not return a response from GRAPH_COMPUTE; instead, assume it
always completes successfully. If this is not the case, the server closes
the connection. This saves a network round trip to the server.
Radoslav Gerganov 2025-11-28 10:33:51 +02:00 committed by GitHub
parent 6bca76ff5e
commit 15d2b46b4d
2 changed files with 92 additions and 21 deletions
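
In protocol terms, a graph computation used to cost a full graph serialization plus a blocking round trip on every call; after this change only the first submission of a given graph pays that cost, and byte-identical repeats are fire-and-forget. A rough sketch of the exchange (illustrative, not a literal wire trace):

    // Before: every call ships the whole graph and waits for a 1-byte reply.
    //   client -> server : GRAPH_COMPUTE + serialized graph
    //   server -> client : rpc_msg_graph_compute_rsp { uint8_t result }
    //
    // After: only the first call ships the graph; identical repeats send a
    // 4-byte device id and do not wait for a reply.
    //   client -> server : GRAPH_COMPUTE + serialized graph     (first time)
    //   client -> server : GRAPH_RECOMPUTE { uint32_t device }  (repeats)
    //   (a failed computation asserts on the server and closes the connection,
    //    which the client observes on its next socket operation)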

ggml/include/ggml-rpc.h

@@ -8,7 +8,7 @@ extern "C" {
 #endif
 
 #define RPC_PROTO_MAJOR_VERSION    3
-#define RPC_PROTO_MINOR_VERSION    0
+#define RPC_PROTO_MINOR_VERSION    5
 #define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
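
The minor version bump advertises the new command to peers. A hedged sketch of how a client might gate its use on the negotiated version, assuming the HELLO response carries `major`/`minor` fields:

    // Hypothetical guard, not part of this patch: only issue
    // RPC_CMD_GRAPH_RECOMPUTE against servers that speak protocol 3.5+.
    static bool server_supports_graph_recompute(const rpc_msg_hello_rsp & hello) {
        return hello.major == RPC_PROTO_MAJOR_VERSION && hello.minor >= 5;
    }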

ggml/src/ggml-rpc/ggml-rpc.cpp

@@ -106,6 +106,7 @@ enum rpc_cmd {
     RPC_CMD_GET_ALLOC_SIZE,
     RPC_CMD_HELLO,
     RPC_CMD_DEVICE_COUNT,
+    RPC_CMD_GRAPH_RECOMPUTE,
     RPC_CMD_COUNT,
 };
@@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp {
     uint8_t result;
 };
 
-struct rpc_msg_graph_compute_rsp {
-    uint8_t result;
-};
-
 struct rpc_msg_get_device_memory_req {
     uint32_t device;
 };
@@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp {
     uint64_t free_mem;
     uint64_t total_mem;
 };
 
+struct rpc_msg_graph_recompute_req {
+    uint32_t device;
+};
+
 #pragma pack(pop)
 
 // RPC data structures
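
Since these structs sit inside the packed region closed by `#pragma pack(pop)`, the new request is exactly four bytes on the wire. A sanity check one could add (illustrative, not in the patch):

    static_assert(sizeof(rpc_msg_graph_recompute_req) == sizeof(uint32_t),
                  "GRAPH_RECOMPUTE request is a single packed device id");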
@@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context {
     size_t max_size;
 };
 
+struct graph_cache {
+    bool is_cached(const ggml_cgraph * cgraph) {
+        if ((int)last_graph.size() != cgraph->n_nodes) {
+            return false;
+        }
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    void add(const ggml_cgraph * cgraph) {
+        last_graph.resize(cgraph->n_nodes);
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
+        }
+    }
+
+    std::vector<ggml_tensor> last_graph;
+};
+
 struct ggml_backend_rpc_context {
     std::string endpoint;
     uint32_t    device;
     std::string name;
+    graph_cache gc;
 };
 
 struct ggml_backend_rpc_buffer_context {
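
Note that `is_cached` byte-compares the entire `ggml_tensor` struct of every node, so any difference in op, shape, op params, source pointers, or data addresses counts as a new graph; only the bytes stored in the tensors' remote buffers may change between reuses. A minimal sketch of that invariant (illustrative fragment; `cgraph` stands for any client-built graph):

    graph_cache gc;
    gc.add(cgraph);                       // remember the last submitted graph
    GGML_ASSERT(gc.is_cached(cgraph));    // byte-identical nodes -> reuse
    cgraph->nodes[0]->op_params[0] ^= 1;  // perturb a single node field
    GGML_ASSERT(!gc.is_cached(cgraph));   // any difference -> full resend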
@@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
 static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    std::vector<uint8_t> input;
-    serialize_graph(rpc_ctx->device, cgraph, input);
-    rpc_msg_graph_compute_rsp response;
-    auto sock = get_socket(rpc_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
-    return (enum ggml_status)response.result;
+    GGML_ASSERT(cgraph->n_nodes > 0);
+    bool reuse = rpc_ctx->gc.is_cached(cgraph);
+    if (reuse) {
+        rpc_msg_graph_recompute_req request;
+        request.device = rpc_ctx->device;
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
+        RPC_STATUS_ASSERT(status);
+    } else {
+        rpc_ctx->gc.add(cgraph);
+        std::vector<uint8_t> input;
+        serialize_graph(rpc_ctx->device, cgraph, input);
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
+        RPC_STATUS_ASSERT(status);
+    }
+    return GGML_STATUS_SUCCESS;
 }
 
 static ggml_backend_i ggml_backend_rpc_interface = {
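
From the caller's perspective only the cost profile changes: the first `ggml_backend_graph_compute` call serializes and ships the whole graph, and each byte-identical graph after that (for example, repeated decode steps with unchanged shapes) degenerates to the 4-byte recompute message with no reply awaited. A sketch of that usage (endpoint, device, and loop are illustrative):

    ggml_backend_t backend = ggml_backend_rpc_init("localhost:50052", /* device */ 0);
    for (int step = 0; step < n_steps; step++) {
        // assuming the graph built for each step is byte-identical to the
        // previous one, only step 0 pays serialization + transfer
        ggml_status st = ggml_backend_graph_compute(backend, cgraph);
        GGML_ASSERT(st == GGML_STATUS_SUCCESS);
    }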
@@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
         /* .endpoint = */ endpoint,
         /* .device   = */ device,
-        /* .name     = */ dev_name
+        /* .name     = */ dev_name,
+        /* .gc       = */ {},
     };
 
     auto reg = ggml_backend_rpc_add_server(endpoint);
     ggml_backend_t backend = new ggml_backend {
@@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
 
 class rpc_server {
 public:
-    rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
-        : backends(std::move(backends)), cache_dir(cache_dir) {
+    rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
+        : backends(std::move(all_backends)), cache_dir(cache_dir) {
+        stored_graphs.resize(backends.size());
     }
     ~rpc_server();
@@ -936,11 +976,17 @@ public:
     bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
-    bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+    bool graph_compute(const std::vector<uint8_t> & input);
+    bool graph_recompute(const rpc_msg_graph_recompute_req & request);
     bool init_tensor(const rpc_msg_init_tensor_req & request);
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
     bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
 
+    struct stored_graph {
+        ggml_context_ptr ctx_ptr;
+        ggml_cgraph * graph;
+    };
+
 private:
     bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -953,6 +999,8 @@ private:
     std::vector<ggml_backend_t> backends;
     const char * cache_dir;
     std::unordered_set<ggml_backend_buffer_t> buffers;
+    // store the last computed graph for each backend
+    std::vector<stored_graph> stored_graphs;
 };
 
 void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
     return result;
 }
 
-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
     // serialization format:
     // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     if (input.size() < 2*sizeof(uint32_t)) {
@@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
         }
     }
     ggml_status status = ggml_backend_graph_compute(backends[device], graph);
-    response.result = status;
+    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
+    stored_graphs[device].ctx_ptr.swap(ctx_ptr);
+    stored_graphs[device].graph = graph;
+    return true;
+}
+
+bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
+    uint32_t device = request.device;
+    if (device >= backends.size()) {
+        return false;
+    }
+    if (stored_graphs[device].graph == nullptr) {
+        return false;
+    }
+    ggml_cgraph * graph = stored_graphs[device].graph;
+    LOG_DBG("[%s] device: %u\n", __func__, device);
+    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
+    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
     return true;
 }
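
Replaying a stored `ggml_cgraph` is sound because node structs only hold pointers into server-side buffers: the client can still update input contents (for example via SET_TENSOR) between replays and the recompute reads the fresh data, while any structural change fails the client-side byte compare and triggers a full resend. The `ctx_ptr` swap keeps the context backing the stored nodes alive until the next full GRAPH_COMPUTE replaces it. A sketch of the resulting client sequences (message names from the patch; flow is illustrative):

    // steady state: update inputs in place, then replay the stored graph
    //   SET_TENSOR(input, new_data)   // write into the remote buffer
    //   GRAPH_RECOMPUTE { device }    // re-run the kept graph on new data
    // structural change (new shapes/ops): the byte compare misses, so the
    // client falls back to GRAPH_COMPUTE with a freshly serialized graph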
@@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
                 if (!recv_msg(sockfd, input)) {
                     return;
                 }
-                rpc_msg_graph_compute_rsp response;
-                if (!server.graph_compute(input, response)) {
-                    return;
-                }
-                if (!send_msg(sockfd, &response, sizeof(response))) {
+                if (!server.graph_compute(input)) {
+                    return;
+                }
+                break;
+            }
+            case RPC_CMD_GRAPH_RECOMPUTE: {
+                rpc_msg_graph_recompute_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.graph_recompute(request)) {
                     return;
                 }
                 break;