diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index d7c8ad8c16..a9bb87eca6 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -220,6 +220,15 @@ struct rpc_msg_graph_recompute_req { uint32_t device; }; +// Response structs for graph compute operations +struct rpc_msg_graph_compute_rsp { + int32_t status; // ggml_status +}; + +struct rpc_msg_graph_recompute_rsp { + int32_t status; // ggml_status +}; + #pragma pack(pop) // RPC data structures @@ -874,17 +883,20 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g rpc_msg_graph_recompute_req request; request.device = rpc_ctx->device; auto sock = get_socket(rpc_ctx->endpoint); - bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); + rpc_msg_graph_recompute_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); + return (enum ggml_status)response.status; } else { rpc_ctx->gc.add(cgraph); std::vector input; serialize_graph(rpc_ctx->device, cgraph, input); auto sock = get_socket(rpc_ctx->endpoint); - bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size()); + rpc_msg_graph_compute_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); RPC_STATUS_ASSERT(status); + return (enum ggml_status)response.status; } - return GGML_STATUS_SUCCESS; } static ggml_backend_i ggml_backend_rpc_interface = { @@ -1001,8 +1013,8 @@ public: bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response); bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); - bool graph_compute(const std::vector & input); - bool graph_recompute(const rpc_msg_graph_recompute_req & request); + bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); + bool graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response); bool init_tensor(const rpc_msg_init_tensor_req & request); bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); @@ -1474,7 +1486,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id, return result; } -bool rpc_server::graph_compute(const std::vector & input) { +bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response) { // serialization format: // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | if (input.size() < 2*sizeof(uint32_t)) { @@ -1537,24 +1549,32 @@ bool rpc_server::graph_compute(const std::vector & input) { } } ggml_status status = ggml_backend_graph_compute(backends[device], graph); - GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); + if (status != GGML_STATUS_SUCCESS) { + GGML_LOG_ERROR("[%s] graph compute failed with status %d\n", __func__, (int)status); + } + response.status = (int32_t)status; stored_graphs[device].ctx_ptr.swap(ctx_ptr); stored_graphs[device].graph = graph; return true; } -bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) { +bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response) { uint32_t device = request.device; if (device >= backends.size()) { + response.status = (int32_t)GGML_STATUS_FAILED; return false; } if (stored_graphs[device].graph == nullptr) { + response.status = (int32_t)GGML_STATUS_FAILED; return false; } ggml_cgraph * graph = stored_graphs[device].graph; LOG_DBG("[%s] device: %u\n", __func__, device); ggml_status status = ggml_backend_graph_compute(backends[device], graph); - GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); + if (status != GGML_STATUS_SUCCESS) { + GGML_LOG_ERROR("[%s] graph recompute failed with status %d\n", __func__, (int)status); + } + response.status = (int32_t)status; return true; } @@ -1789,7 +1809,11 @@ static void rpc_serve_client(const std::vector & backends, const if (!recv_msg(sockfd, input)) { return; } - if (!server.graph_compute(input)) { + rpc_msg_graph_compute_rsp response; + if (!server.graph_compute(input, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { return; } break; @@ -1799,7 +1823,11 @@ static void rpc_serve_client(const std::vector & backends, const if (!recv_msg(sockfd, &request, sizeof(request))) { return; } - if (!server.graph_recompute(request)) { + rpc_msg_graph_recompute_rsp response; + if (!server.graph_recompute(request, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { return; } break;