This commit is contained in:
Matthias Fahsold 2026-03-16 12:39:21 +11:00 committed by GitHub
commit f03bc3dd4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 39 additions and 11 deletions

View File

@ -220,6 +220,15 @@ struct rpc_msg_graph_recompute_req {
uint32_t device;
};
// Response structs for graph compute operations
struct rpc_msg_graph_compute_rsp {
int32_t status; // ggml_status
};
struct rpc_msg_graph_recompute_rsp {
int32_t status; // ggml_status
};
#pragma pack(pop)
// RPC data structures
@ -874,17 +883,20 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
rpc_msg_graph_recompute_req request;
request.device = rpc_ctx->device;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
rpc_msg_graph_recompute_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return (enum ggml_status)response.status;
} else {
rpc_ctx->gc.add(cgraph);
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
rpc_msg_graph_compute_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return (enum ggml_status)response.status;
}
return GGML_STATUS_SUCCESS;
}
static ggml_backend_i ggml_backend_rpc_interface = {
@ -1001,8 +1013,8 @@ public:
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input);
bool graph_recompute(const rpc_msg_graph_recompute_req & request);
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
@ -1474,7 +1486,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
return result;
}
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
// serialization format:
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < 2*sizeof(uint32_t)) {
@ -1537,24 +1549,32 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
}
}
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
if (status != GGML_STATUS_SUCCESS) {
GGML_LOG_ERROR("[%s] graph compute failed with status %d\n", __func__, (int)status);
}
response.status = (int32_t)status;
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
stored_graphs[device].graph = graph;
return true;
}
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response) {
uint32_t device = request.device;
if (device >= backends.size()) {
response.status = (int32_t)GGML_STATUS_FAILED;
return false;
}
if (stored_graphs[device].graph == nullptr) {
response.status = (int32_t)GGML_STATUS_FAILED;
return false;
}
ggml_cgraph * graph = stored_graphs[device].graph;
LOG_DBG("[%s] device: %u\n", __func__, device);
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
if (status != GGML_STATUS_SUCCESS) {
GGML_LOG_ERROR("[%s] graph recompute failed with status %d\n", __func__, (int)status);
}
response.status = (int32_t)status;
return true;
}
@ -1789,7 +1809,11 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!recv_msg(sockfd, input)) {
return;
}
if (!server.graph_compute(input)) {
rpc_msg_graph_compute_rsp response;
if (!server.graph_compute(input, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;
@ -1799,7 +1823,11 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
if (!server.graph_recompute(request)) {
rpc_msg_graph_recompute_rsp response;
if (!server.graph_recompute(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;