Merge 3b9b01a5b6 into 9e2e2198b0
This commit is contained in:
commit
f03bc3dd4b
|
|
@ -220,6 +220,15 @@ struct rpc_msg_graph_recompute_req {
|
|||
uint32_t device;
|
||||
};
|
||||
|
||||
// Response structs for graph compute operations
|
||||
// Reply payload for RPC_CMD_GRAPH_COMPUTE.
// The server-side handler stores the ggml_status returned by
// ggml_backend_graph_compute() here as an int32_t, and the client casts the
// field back to `enum ggml_status` after the round trip.
// NOTE(review): a `#pragma pack(pop)` follows these response structs, so this
// is presumably declared inside a packed region — confirm the matching push.
struct rpc_msg_graph_compute_rsp {
|
||||
int32_t status; // ggml_status
|
||||
};
|
||||
|
||||
// Reply payload for RPC_CMD_GRAPH_RECOMPUTE.
// Carries a single status word: the server writes a ggml_status value cast to
// int32_t, and the client converts it back to `enum ggml_status`.
// NOTE(review): a `#pragma pack(pop)` follows these response structs, so this
// is presumably declared inside a packed region — confirm the matching push.
struct rpc_msg_graph_recompute_rsp {
|
||||
int32_t status; // ggml_status
|
||||
};
|
||||
|
||||
#pragma pack(pop)
|
||||
|
||||
// RPC data structures
|
||||
|
|
@ -874,17 +883,20 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
|
|||
rpc_msg_graph_recompute_req request;
|
||||
request.device = rpc_ctx->device;
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
|
||||
rpc_msg_graph_recompute_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request), &response, sizeof(response));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
return (enum ggml_status)response.status;
|
||||
} else {
|
||||
rpc_ctx->gc.add(cgraph);
|
||||
std::vector<uint8_t> input;
|
||||
serialize_graph(rpc_ctx->device, cgraph, input);
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
|
||||
rpc_msg_graph_compute_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
return (enum ggml_status)response.status;
|
||||
}
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
static ggml_backend_i ggml_backend_rpc_interface = {
|
||||
|
|
@ -1001,8 +1013,8 @@ public:
|
|||
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
|
||||
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
||||
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
||||
bool graph_compute(const std::vector<uint8_t> & input);
|
||||
bool graph_recompute(const rpc_msg_graph_recompute_req & request);
|
||||
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
||||
bool graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response);
|
||||
bool init_tensor(const rpc_msg_init_tensor_req & request);
|
||||
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
||||
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
|
||||
|
|
@ -1474,7 +1486,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
|||
return result;
|
||||
}
|
||||
|
||||
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
||||
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
|
||||
// serialization format:
|
||||
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
||||
if (input.size() < 2*sizeof(uint32_t)) {
|
||||
|
|
@ -1537,24 +1549,32 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
|||
}
|
||||
}
|
||||
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
||||
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
||||
if (status != GGML_STATUS_SUCCESS) {
|
||||
GGML_LOG_ERROR("[%s] graph compute failed with status %d\n", __func__, (int)status);
|
||||
}
|
||||
response.status = (int32_t)status;
|
||||
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
|
||||
stored_graphs[device].graph = graph;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
|
||||
// Re-runs the graph previously stored for `request.device` and reports the
// outcome through `response.status` (a ggml_status cast to int32_t).
//
// Returns false when the device index is out of range or no graph has been
// stored for that device; returns true once the compute call has been made,
// whether or not it succeeded.
//
// NOTE(review): the visible caller in rpc_serve_client returns without calling
// send_msg() when this function returns false, so the GGML_STATUS_FAILED
// values written on the two early-exit paths are not transmitted to the
// client on that path — confirm whether that is intended.
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response) {
|
||||
uint32_t device = request.device;
|
||||
// Reject device indices that do not map to a backend.
if (device >= backends.size()) {
|
||||
response.status = (int32_t)GGML_STATUS_FAILED;
|
||||
return false;
|
||||
}
|
||||
// Nothing to recompute unless a graph was stored for this device earlier.
if (stored_graphs[device].graph == nullptr) {
|
||||
response.status = (int32_t)GGML_STATUS_FAILED;
|
||||
return false;
|
||||
}
|
||||
ggml_cgraph * graph = stored_graphs[device].graph;
|
||||
LOG_DBG("[%s] device: %u\n", __func__, device);
|
||||
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
||||
// NOTE(review): this assert looks like the pre-change line that the diff
// removes; this rendered view interleaves old and new lines without +/-
// markers. If it were kept, the failure branch below could never run.
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
||||
// Log the failure but still hand the status back to the caller.
if (status != GGML_STATUS_SUCCESS) {
|
||||
GGML_LOG_ERROR("[%s] graph recompute failed with status %d\n", __func__, (int)status);
|
||||
}
|
||||
response.status = (int32_t)status;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -1789,7 +1809,11 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
|
|||
if (!recv_msg(sockfd, input)) {
|
||||
return;
|
||||
}
|
||||
if (!server.graph_compute(input)) {
|
||||
rpc_msg_graph_compute_rsp response;
|
||||
if (!server.graph_compute(input, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
|
|
@ -1799,7 +1823,11 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
|
|||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
if (!server.graph_recompute(request)) {
|
||||
rpc_msg_graph_recompute_rsp response;
|
||||
if (!server.graph_recompute(request, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
|
|
|
|||
Loading…
Reference in New Issue