llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp

2631 lines
93 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "ggml-rpc.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-cpp.h"
#include <cinttypes>
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winsock2.h>
#else
# include <arpa/inet.h>
# include <sys/socket.h>
# include <sys/types.h>
# include <netinet/in.h>
# include <netinet/tcp.h>
# include <netdb.h>
# include <unistd.h>
#endif
#include <cstring>
#include <fstream>
#include <filesystem>
#include <algorithm>
#ifdef GGML_RPC_RDMA
# include <infiniband/verbs.h>
# include <time.h>
#endif
// Read once during static initialization from the GGML_RPC_DEBUG environment variable.
// Any non-null result (i.e. the variable is set) turns LOG_DBG output on.
static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");
// Debug logging that forwards to GGML_LOG_DEBUG only when RPC_DEBUG is non-null.
#define LOG_DBG(...) \
do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0)
namespace fs = std::filesystem;
// Largest byte count passed to a single send()/recv() call by send_data()/recv_data().
static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB
#ifdef GGML_RPC_RDMA
static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock)
static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB
#endif
// sockfd_t is the native socket handle type of the platform.
// On Windows ssize_t is defined here as well, since the code below uses it for send()/recv() results.
#ifdef _WIN32
typedef SOCKET sockfd_t;
using ssize_t = __int64;
#else
typedef int sockfd_t;
#endif
// cross-platform socket
#ifdef GGML_RPC_RDMA
// Owns every verbs object and staging buffer that backs one RDMA connection.
// All members start out null so that a partially constructed connection can be
// destroyed safely: the destructor releases only what was actually created.
struct rdma_conn {
    struct ibv_context * ctx = nullptr;
    struct ibv_pd      * pd  = nullptr;
    struct ibv_cq      * scq = nullptr; // send completions
    struct ibv_cq      * rcq = nullptr; // recv completions
    struct ibv_qp      * qp  = nullptr;
    void               * tx_buf = nullptr;
    struct ibv_mr      * tx_mr  = nullptr;
    void               * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous
    struct ibv_mr      * rx_mr  = nullptr;
    int                  rx_head    = 0;
    uint32_t             max_inline = 0;

    rdma_conn() = default;
    // The destructor frees raw handles and buffers, so a copy would release them twice.
    rdma_conn(const rdma_conn &) = delete;
    rdma_conn & operator=(const rdma_conn &) = delete;

    // Address of receive slot i inside the contiguous rx_buf ring.
    void * rx_slot(int i) { return (char *)rx_buf + (size_t)i * RDMA_CHUNK; }

    // Posts receive slot i to the queue pair; the slot index is stored in wr_id so the
    // completion can be mapped back to its buffer. Returns true when the post succeeded.
    bool post_rx(int i) {
        struct ibv_sge sge = {};
        sge.addr   = (uintptr_t)rx_slot(i);
        sge.length = RDMA_CHUNK;
        sge.lkey   = rx_mr->lkey;
        struct ibv_recv_wr wr = {}, * bad = nullptr;
        wr.wr_id   = (uint64_t)i;
        wr.sg_list = &sge;
        wr.num_sge = 1;
        return ibv_post_recv(qp, &wr, &bad) == 0;
    }

    ~rdma_conn() {
        if (tx_mr) ibv_dereg_mr(tx_mr);
        if (rx_mr) ibv_dereg_mr(rx_mr);
        free(tx_buf);
        free(rx_buf);
        if (qp)  ibv_destroy_qp(qp);
        if (scq) ibv_destroy_cq(scq);
        if (rcq) ibv_destroy_cq(rcq);
        if (pd)  ibv_dealloc_pd(pd);
        if (ctx) ibv_close_device(ctx);
    }
};
#endif
// Forward declarations for transport function pointers
struct socket_t;
static bool tcp_send_impl(socket_t * sock, const void * data, size_t size);
static bool tcp_recv_impl(socket_t * sock, void * data, size_t size);

// Owning wrapper around a connected socket descriptor.
// fn_send/fn_recv select the active transport: they default to the TCP
// implementations and are re-pointed when the connection is upgraded to RDMA.
struct socket_t {
    sockfd_t fd;
    bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl;
    bool (*fn_recv)(socket_t *, void *, size_t)       = tcp_recv_impl;
#ifdef GGML_RPC_RDMA
    rdma_conn * rdma = nullptr; // owned; deleted in the destructor
#endif
    socket_t(sockfd_t fd) : fd(fd) {}
    // The destructor closes fd (and deletes rdma), so copies would close/free twice.
    socket_t(const socket_t &) = delete;
    socket_t & operator=(const socket_t &) = delete;
    ~socket_t() {
#ifdef GGML_RPC_RDMA
        if (rdma) { delete rdma; rdma = nullptr; }
#endif
        LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
        if (fd != INVALID_SOCKET) closesocket(this->fd);
#else
        if (fd >= 0) close(this->fd);
#endif
    }
};
// macro for nicer error messages on server crash
// Wrapped in do { } while (0) so the macro behaves as a single statement and cannot
// capture an `else` that follows it at the call site.
#define RPC_STATUS_ASSERT(x) \
    do { if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response"); } while (0)
// all RPC structures must be packed
#pragma pack(push, 1)
// ggml_tensor is serialized into rpc_tensor
// serialize_tensor() fills this in; pointer members of ggml_tensor are carried as 64-bit integers.
struct rpc_tensor {
uint64_t id; // address of the client-side ggml_tensor
uint32_t type;
uint64_t buffer; // remote_ptr of the owning RPC buffer, or 0 when the tensor has none
uint32_t ne[GGML_MAX_DIMS];
uint32_t nb[GGML_MAX_DIMS];
uint32_t op;
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
int32_t flags;
uint64_t src[GGML_MAX_SRC]; // addresses of the client-side source tensors
uint64_t view_src; // address of the client-side view source tensor
uint64_t view_offs;
uint64_t data; // value of the tensor's data pointer
char name[GGML_MAX_NAME];
char padding[4]; // explicit padding, zeroed by serialize_tensor()
};
static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
// RPC commands
// send_rpc_cmd() transmits the enumerator as a single byte, so the numeric values are
// part of the wire protocol. The static_assert below pins RPC_CMD_HELLO to 14.
enum rpc_cmd {
RPC_CMD_ALLOC_BUFFER = 0,
RPC_CMD_GET_ALIGNMENT,
RPC_CMD_GET_MAX_SIZE,
RPC_CMD_BUFFER_GET_BASE,
RPC_CMD_FREE_BUFFER,
RPC_CMD_BUFFER_CLEAR,
RPC_CMD_SET_TENSOR,
RPC_CMD_SET_TENSOR_HASH,
RPC_CMD_GET_TENSOR,
RPC_CMD_COPY_TENSOR,
RPC_CMD_GRAPH_COMPUTE,
RPC_CMD_GET_DEVICE_MEMORY,
RPC_CMD_INIT_TENSOR,
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
RPC_CMD_DEVICE_COUNT,
RPC_CMD_GRAPH_RECOMPUTE,
RPC_CMD_COUNT,
};
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
// The structs below are the fixed-size request/response payloads of the RPC commands.
// They are declared inside the pack(push, 1) region, so they contain no implicit padding.

// RPC_CMD_HELLO response (legacy form): protocol version of the server.
struct rpc_msg_hello_rsp {
uint8_t major;
uint8_t minor;
uint8_t patch;
};
// RPC_CMD_HELLO request sent by a client that probed an RDMA device (see negotiate_hello).
struct rpc_msg_hello_rdma_req {
uint32_t rdma_qpn; // 0 = no RDMA capability
uint32_t rdma_psn;
uint8_t rdma_gid[16];
};
// RPC_CMD_HELLO response (extended form): server version plus its RDMA endpoint info.
struct rpc_msg_hello_rdma_rsp {
uint8_t major;
uint8_t minor;
uint8_t patch;
uint32_t rdma_qpn; // 0 = no RDMA capability
uint32_t rdma_psn;
uint8_t rdma_gid[16];
};
// NOTE(review): presumably the RPC_CMD_DEVICE_COUNT response; not used in the visible code.
struct rpc_msg_device_count_rsp {
uint32_t device_count;
};
// NOTE(review): presumably the RPC_CMD_GET_ALLOC_SIZE request/response; confirm against the caller.
struct rpc_msg_get_alloc_size_req {
uint32_t device;
rpc_tensor tensor;
rpc_tensor srcs[GGML_MAX_SRC];
};
struct rpc_msg_get_alloc_size_rsp {
uint64_t alloc_size;
};
// RPC_CMD_INIT_TENSOR request (the client expects an empty response).
struct rpc_msg_init_tensor_req {
rpc_tensor tensor;
};
// RPC_CMD_ALLOC_BUFFER request/response.
struct rpc_msg_alloc_buffer_req {
uint32_t device;
uint64_t size;
};
struct rpc_msg_alloc_buffer_rsp {
uint64_t remote_ptr; // 0 is treated by the client as allocation failure
uint64_t remote_size;
};
// RPC_CMD_GET_ALIGNMENT request/response.
struct rpc_msg_get_alignment_req {
uint32_t device;
};
struct rpc_msg_get_alignment_rsp {
uint64_t alignment;
};
// RPC_CMD_GET_MAX_SIZE request/response.
struct rpc_msg_get_max_size_req {
uint32_t device;
};
struct rpc_msg_get_max_size_rsp {
uint64_t max_size;
};
// RPC_CMD_BUFFER_GET_BASE request/response.
struct rpc_msg_buffer_get_base_req {
uint64_t remote_ptr;
};
struct rpc_msg_buffer_get_base_rsp {
uint64_t base_ptr;
};
// RPC_CMD_FREE_BUFFER request (the client expects an empty response).
struct rpc_msg_free_buffer_req {
uint64_t remote_ptr;
};
// RPC_CMD_BUFFER_CLEAR request (the client expects an empty response).
struct rpc_msg_buffer_clear_req {
uint64_t remote_ptr;
uint8_t value;
};
// RPC_CMD_SET_TENSOR_HASH request/response; hash is the fnv_hash() of the tensor data.
struct rpc_msg_set_tensor_hash_req {
rpc_tensor tensor;
uint64_t offset;
uint64_t hash;
};
struct rpc_msg_set_tensor_hash_rsp {
uint8_t result; // non-zero: the client skips sending the data
};
// RPC_CMD_GET_TENSOR request; the response is `size` raw bytes.
struct rpc_msg_get_tensor_req {
rpc_tensor tensor;
uint64_t offset;
uint64_t size;
};
// RPC_CMD_COPY_TENSOR request/response.
struct rpc_msg_copy_tensor_req {
rpc_tensor src;
rpc_tensor dst;
};
struct rpc_msg_copy_tensor_rsp {
uint8_t result;
};
// NOTE(review): presumably the RPC_CMD_GET_DEVICE_MEMORY request/response; not used in the visible code.
struct rpc_msg_get_device_memory_req {
uint32_t device;
};
struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem;
uint64_t total_mem;
};
// NOTE(review): presumably the RPC_CMD_GRAPH_RECOMPUTE request; not used in the visible code.
struct rpc_msg_graph_recompute_req {
uint32_t device;
};
#pragma pack(pop)
// RPC data structures

// Returns a pointer to the process-wide GUID that identifies the RPC backend.
static ggml_guid_t ggml_backend_rpc_guid() {
    static ggml_guid rpc_guid = {
        0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24,
        0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03,
    };
    ggml_guid_t handle = &rpc_guid;
    return handle;
}
// Per-buffer-type state stored in ggml_backend_buffer_type::context.
struct ggml_backend_rpc_buffer_type_context {
std::string endpoint; // passed to get_socket() when a buffer is allocated
uint32_t device; // device index sent in the RPC_CMD_ALLOC_BUFFER request
std::string name; // returned by ggml_backend_rpc_buffer_type_name()
size_t alignment; // returned by ggml_backend_rpc_buffer_type_get_alignment()
size_t max_size; // returned by ggml_backend_rpc_get_max_size()
};
struct graph_cache {
bool is_cached(const ggml_cgraph * cgraph) {
if ((int)last_graph.size() != cgraph->n_nodes) {
return false;
}
for (int i = 0; i < cgraph->n_nodes; i++) {
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
return false;
}
}
return true;
}
void add(const ggml_cgraph * cgraph) {
last_graph.resize(cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; i++) {
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
}
}
std::vector<ggml_tensor> last_graph;
};
// State of one RPC backend instance.
// NOTE(review): the fields are not used in the visible part of the file; they presumably
// mirror ggml_backend_rpc_buffer_type_context (endpoint/device/name) — confirm against the users.
struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
graph_cache gc;
};
// Per-buffer state stored in ggml_backend_buffer::context.
// The field order matters: the struct is aggregate-initialized positionally when a buffer is allocated.
struct ggml_backend_rpc_buffer_context {
std::shared_ptr<socket_t> sock; // connection used for every command on this buffer
void * base_ptr; // cached RPC_CMD_BUFFER_GET_BASE result; nullptr until first queried
uint64_t remote_ptr; // server-side handle returned by RPC_CMD_ALLOC_BUFFER
};
// RPC helper functions

// 64-bit FNV-1a hash over `len` bytes starting at `data`.
// An empty range hashes to the FNV offset basis.
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
    constexpr uint64_t k_offset_basis = 0xcbf29ce484222325ULL;
    constexpr uint64_t k_prime        = 0x100000001b3ULL;

    uint64_t digest = k_offset_basis;
    const uint8_t * const end = data + len;
    for (const uint8_t * cur = data; cur != end; ++cur) {
        // xor the byte in first, then multiply (the "1a" ordering)
        digest = (digest ^ *cur) * k_prime;
    }
    return digest;
}
// Wraps a raw descriptor in a shared socket_t.
// Returns nullptr when the descriptor is the platform's "invalid socket" value.
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
    const bool is_valid = !(fd == INVALID_SOCKET);
#else
    const bool is_valid = !(fd < 0);
#endif
    if (!is_valid) {
        return nullptr;
    }
    return std::make_shared<socket_t>(fd);
}
// Enables TCP_NODELAY on the socket (disables Nagle's algorithm).
// Returns true when setsockopt succeeded.
static bool set_no_delay(sockfd_t sockfd) {
    int enable = 1;
    const int rc = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&enable, sizeof(enable));
    return rc == 0;
}
// Enables SO_REUSEADDR on the socket. Returns true when setsockopt succeeded.
static bool set_reuse_addr(sockfd_t sockfd) {
    int enable = 1;
    const int rc = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&enable, sizeof(enable));
    return rc == 0;
}
static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
struct sockaddr_in addr;
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
auto sock_ptr = make_socket(sockfd);
if (sock_ptr == nullptr) {
return nullptr;
}
if (!set_no_delay(sockfd)) {
GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
return nullptr;
}
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
struct hostent * server = gethostbyname(host);
if (server == NULL) {
GGML_LOG_ERROR("Cannot resolve host '%s'\n", host);
return nullptr;
}
memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
return nullptr;
}
return sock_ptr;
}
// Accepts one pending connection on the listening socket and enables TCP_NODELAY on it.
// Returns nullptr when accept() fails or the option cannot be set.
static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
    auto conn_fd = accept(srv_sockfd, NULL, NULL);
    std::shared_ptr<socket_t> conn = make_socket(conn_fd);
    if (!conn) {
        return nullptr;
    }
    if (set_no_delay(conn_fd)) {
        return conn;
    }
    GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
    return nullptr;
}
static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
auto sock = make_socket(sockfd);
if (sock == nullptr) {
return nullptr;
}
if (!set_reuse_addr(sockfd)) {
GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n");
return nullptr;
}
if (inet_addr(host) == INADDR_NONE) {
GGML_LOG_ERROR("Invalid host address: %s\n", host);
return nullptr;
}
struct sockaddr_in serv_addr;
serv_addr.sin_family = AF_INET;
serv_addr.sin_addr.s_addr = inet_addr(host);
serv_addr.sin_port = htons(port);
if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
return nullptr;
}
if (listen(sockfd, 1) < 0) {
return nullptr;
}
return sock;
}
// Writes exactly `size` bytes from `data` to the socket, issuing send() calls of at most
// MAX_CHUNK_SIZE bytes until everything is written. Returns false as soon as send() fails.
static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
    const char * cursor    = (const char *)data;
    size_t       remaining = size;
    while (remaining > 0) {
        const size_t  chunk = std::min(remaining, MAX_CHUNK_SIZE);
        const ssize_t sent  = send(sockfd, cursor, chunk, 0);
        if (sent < 0) {
            GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n",
                           size - remaining, chunk);
            return false;
        }
        cursor    += sent;
        remaining -= (size_t)sent;
    }
    return true;
}
// Reads exactly `size` bytes from the socket into `data`, issuing recv() calls of at most
// MAX_CHUNK_SIZE bytes. Returns false when recv() fails or returns 0 before all bytes arrived.
static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
    char * cursor    = (char *)data;
    size_t remaining = size;
    while (remaining > 0) {
        const size_t  chunk = std::min(remaining, MAX_CHUNK_SIZE);
        const ssize_t got   = recv(sockfd, cursor, chunk, 0);
        if (got < 0) {
            GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n",
                           size - remaining, chunk);
            return false;
        }
        if (got == 0) {
            LOG_DBG("recv returned 0 (peer closed?)\n");
            return false;
        }
        cursor    += got;
        remaining -= (size_t)got;
    }
    return true;
}
// ── TCP transport implementations (for function-pointer dispatch) ────────────

// socket_t::fn_send target for plain TCP: forwards to the descriptor-based send_data().
static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) {
    const sockfd_t fd = sock->fd;
    return send_data(fd, data, size);
}
// socket_t::fn_recv target for plain TCP: forwards to the descriptor-based recv_data().
static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) {
    const sockfd_t fd = sock->fd;
    return recv_data(fd, data, size);
}
// ── RDMA transport (performance-optimized, auto-negotiated) ─────────────────
#ifdef GGML_RPC_RDMA
static bool rdma_send_impl(socket_t * sock, const void * data, size_t size);
static bool rdma_recv_impl(socket_t * sock, void * data, size_t size);

// Busy-polls `cq` until one work completion is available and stores it in `wc`.
// Returns true only for a completion with status IBV_WC_SUCCESS; returns false when
// polling fails, the completion carries an error status, or about 30 seconds pass
// without any completion.
static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) {
    struct timespec started;
    clock_gettime(CLOCK_MONOTONIC_COARSE, &started);

    uint64_t spins = 0;
    while (true) {
        const int polled = ibv_poll_cq(cq, 1, wc);
        if (polled > 0) {
            return wc->status == IBV_WC_SUCCESS;
        }
        if (polled < 0) {
            return false;
        }
        // consult the clock only once every 2^20 empty polls to keep the spin loop cheap
        if (spins > 0 && (spins & 0xFFFFF) == 0) {
            struct timespec current;
            clock_gettime(CLOCK_MONOTONIC_COARSE, &current);
            if (current.tv_sec - started.tv_sec >= 30) {
                GGML_LOG_ERROR("RDMA CQ poll timeout\n");
                return false;
            }
        }
        spins++;
    }
}
// Sends `size` bytes as a sequence of RDMA SEND work requests of at most RDMA_CHUNK bytes.
// Chunks that fit into max_inline are sent inline straight from the caller's memory;
// larger chunks are first copied into the registered tx_buf. Each request is signaled
// and its completion is awaited before the next chunk is posted.
static bool rdma_send(rdma_conn * c, const void * data, size_t size) {
    const uint8_t * cursor = (const uint8_t *)data;
    for (size_t left = size; left > 0; ) {
        const size_t chunk = std::min(left, RDMA_CHUNK);

        struct ibv_sge     sge = {};
        struct ibv_send_wr wr  = {};
        struct ibv_send_wr * bad = nullptr;
        wr.opcode  = IBV_WR_SEND;
        wr.sg_list = &sge;
        wr.num_sge = 1;
        sge.length = chunk;

        if (chunk > c->max_inline) {
            // too large for inline data: stage through the registered buffer
            memcpy(c->tx_buf, cursor, chunk);
            sge.addr      = (uintptr_t)c->tx_buf;
            sge.lkey      = c->tx_mr->lkey;
            wr.send_flags = IBV_SEND_SIGNALED;
        } else {
            sge.addr      = (uintptr_t)cursor;
            wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE;
        }

        if (ibv_post_send(c->qp, &wr, &bad) != 0) {
            return false;
        }
        struct ibv_wc wc;
        if (!rdma_poll(c->scq, &wc)) {
            return false;
        }
        cursor += chunk;
        left   -= chunk;
    }
    return true;
}
// Receives exactly `size` bytes into `data` by consuming completed receive slots.
// Each completion identifies its slot through wr_id and its payload length through
// byte_len; the payload is copied out and the slot is re-posted immediately.
// Returns false when polling fails, a slot cannot be re-posted, or the peer delivered
// more data than the caller asked for (which would otherwise overflow `data`).
static bool rdma_recv(rdma_conn * c, void * data, size_t size) {
    uint8_t * dst = (uint8_t *)data;
    size_t    rem = size;
    while (rem > 0) {
        struct ibv_wc wc;
        if (!rdma_poll(c->rcq, &wc)) return false;

        // wr_id is the slot index set by post_rx(); refuse anything outside the ring
        if (wc.wr_id >= (uint64_t)RDMA_RX_DEPTH) {
            GGML_LOG_ERROR("RDMA recv: unexpected completion id %" PRIu64 "\n", (uint64_t)wc.wr_id);
            return false;
        }
        const int    slot = (int)wc.wr_id;
        const size_t got  = wc.byte_len;

        // byte_len is controlled by the remote side: never copy past the caller's buffer
        // and never let `rem` wrap around
        if (got > rem) {
            GGML_LOG_ERROR("RDMA recv: received %zu bytes but only %zu were expected\n", got, rem);
            return false;
        }
        memcpy(dst, c->rx_slot(slot), got);
        if (!c->post_rx(slot)) return false;
        dst += got;
        rem -= got;
    }
    return true;
}
// socket_t::fn_send target once the connection runs over RDMA: forwards to rdma_send().
static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) {
    rdma_conn * conn = sock->rdma;
    return rdma_send(conn, data, size);
}
// socket_t::fn_recv target once the connection runs over RDMA: forwards to rdma_recv().
static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) {
    rdma_conn * conn = sock->rdma;
    return rdma_recv(conn, data, size);
}
// Phase 1: Probe for RDMA device, create QP (in RESET state), return local info.
// Returns rdma_conn with QP created but NOT connected. Caller gets local QPN/PSN/GID
// to send to the remote side via the HELLO exchange.
// If RDMA is not available, returns nullptr (caller stays on TCP).
// Local endpoint parameters produced by rdma_probe(); qpn/psn/gid are sent to the peer
// in the HELLO exchange and the remaining fields are consumed by rdma_activate().
struct rdma_local_info {
uint32_t qpn; // queue pair number of the local QP
uint32_t psn; // initial send PSN (rdma_probe uses the low 24 bits of qpn)
uint8_t gid[16]; // local GID at index gid_idx
uint8_t ib_port; // port number used for the QP
int gid_idx; // index into the port's GID table
enum ibv_mtu path_mtu; // active MTU reported by the port
};
// Looks for a usable RDMA device and creates (but does not connect) a queue pair on it.
//
// Device selection:
//   - GGML_RDMA_DEV, when set, restricts the search to the device with that name.
//   - GGML_RDMA_GID, when set, is used as the GID index without scanning the GID table.
//   - otherwise the GID table of port 1 is scanned for an IPv4-mapped GID: with
//     GGML_RDMA_DEV set the first one with a non-zero address is taken, without it the
//     GID must match the local address of the TCP socket `tcp_fd`.
//
// On success `out` receives the local QPN/PSN/GID, port, GID index and path MTU, and the
// caller owns the returned connection. Returns nullptr when no device matches or any
// resource cannot be created; everything acquired up to that point is released.
static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) {
    const char * dev_env = std::getenv("GGML_RDMA_DEV");
    const char * gid_env = std::getenv("GGML_RDMA_GID");

    struct sockaddr_in local_addr;
    socklen_t addr_len = sizeof(local_addr);
    if (getsockname(tcp_fd, (struct sockaddr *)&local_addr, &addr_len) != 0) {
        return nullptr;
    }

    const uint8_t ib_port = 1;
    int num_devs = 0;
    struct ibv_device ** devs = ibv_get_device_list(&num_devs);
    if (!devs) return nullptr;
    if (num_devs == 0) {
        // an empty list is still an allocation that has to be released
        ibv_free_device_list(devs);
        return nullptr;
    }

    struct ibv_context * ibctx = nullptr;
    const char * matched_dev = nullptr;
    int gid_idx = gid_env ? atoi(gid_env) : -1;
    for (int d = 0; d < num_devs; d++) {
        const char * dn = ibv_get_device_name(devs[d]);
        if (dev_env && strcmp(dev_env, dn) != 0) continue;
        struct ibv_context * ctx = ibv_open_device(devs[d]);
        if (!ctx) continue;
        struct ibv_port_attr pa;
        if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; }
        int found_gid = gid_idx;
        if (found_gid < 0) {
            for (int i = 0; i < pa.gid_tbl_len; i++) {
                union ibv_gid g;
                if (ibv_query_gid(ctx, ib_port, i, &g) != 0) continue;
                // only IPv4-mapped GIDs (::ffff:a.b.c.d) are considered
                if (g.raw[10] != 0xff || g.raw[11] != 0xff) continue;
                uint32_t ip;
                memcpy(&ip, &g.raw[12], 4);
                if (dev_env) {
                    if (ip != 0) { found_gid = i; break; }
                } else {
                    if (ip == local_addr.sin_addr.s_addr) { found_gid = i; break; }
                }
            }
        }
        if (found_gid >= 0) {
            ibctx = ctx;
            gid_idx = found_gid;
            matched_dev = dn;
            out->path_mtu = pa.active_mtu;
            break;
        }
        ibv_close_device(ctx);
    }
    ibv_free_device_list(devs);
    if (!ibctx) return nullptr;

    out->ib_port = ib_port;
    out->gid_idx = gid_idx;

    // from here on the connection object owns ibctx; its destructor cleans up on every early return
    auto c = std::make_unique<rdma_conn>();
    c->ctx = ibctx;
    c->pd = ibv_alloc_pd(ibctx);
    if (!c->pd) return nullptr;
    c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0);
    c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0);
    if (!c->scq || !c->rcq) return nullptr;

    struct ibv_qp_init_attr qia = {};
    qia.send_cq = c->scq;
    qia.recv_cq = c->rcq;
    qia.qp_type = IBV_QPT_RC;
    qia.cap.max_send_wr = 4;
    qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4;
    qia.cap.max_send_sge = 1;
    qia.cap.max_recv_sge = 1;
    qia.cap.max_inline_data = 256;
    c->qp = ibv_create_qp(c->pd, &qia);
    if (!c->qp) return nullptr;
    c->max_inline = qia.cap.max_inline_data;

    c->tx_buf = aligned_alloc(4096, RDMA_CHUNK);
    c->rx_buf = aligned_alloc(4096, (size_t)RDMA_RX_DEPTH * RDMA_CHUNK);
    if (!c->tx_buf || !c->rx_buf) return nullptr;
    c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE);
    c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, (size_t)RDMA_RX_DEPTH * RDMA_CHUNK,
                          IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
    if (!c->tx_mr || !c->rx_mr) return nullptr;

    // the GID index may come straight from GGML_RDMA_GID, so the query can fail;
    // never advertise an uninitialized GID to the peer
    union ibv_gid local_gid;
    if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr;

    out->qpn = c->qp->qp_num;
    out->psn = c->qp->qp_num & 0xffffff;
    memcpy(out->gid, &local_gid, 16);
    GGML_LOG_INFO("RDMA probed: dev=%s gid=%d qpn=%u inline=%u\n",
                  matched_dev, gid_idx, out->qpn, c->max_inline);
    return c.release();
}
// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET→INIT→pre-post→RTR→RTS.
// On success, the connection is live and ready for rdma_send/rdma_recv.
// Returns false as soon as any ibv_modify_qp() call or receive post fails; the QP is
// then left in whatever state it had reached.
static bool rdma_activate(rdma_conn * c, const rdma_local_info * local,
uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) {
// RESET → INIT
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_INIT;
a.port_num = local->ib_port;
a.pkey_index = 0;
a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) {
return false;
}
}
// pre-post the whole receive ring before the QP is moved to RTR
for (int i = 0; i < RDMA_RX_DEPTH; i++) {
if (!c->post_rx(i)) return false;
}
// INIT → RTR
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_RTR;
a.path_mtu = local->path_mtu;
a.dest_qp_num = remote_qpn;
a.rq_psn = remote_psn;
a.max_dest_rd_atomic = 1;
a.min_rnr_timer = 1;
// address the peer by GID (global routing header) using the local GID index found by rdma_probe
a.ah_attr.is_global = 1;
memcpy(&a.ah_attr.grh.dgid, remote_gid, 16);
a.ah_attr.grh.hop_limit = 1;
a.ah_attr.grh.sgid_index = local->gid_idx;
a.ah_attr.dlid = 0;
a.ah_attr.port_num = local->ib_port;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN |
IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) {
return false;
}
}
// RTR → RTS
{
struct ibv_qp_attr a = {};
a.qp_state = IBV_QPS_RTS;
a.timeout = 14;
a.retry_cnt = 7;
a.rnr_retry = 7;
a.sq_psn = local->psn;
a.max_rd_atomic = 1;
if (ibv_modify_qp(c->qp, &a,
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY |
IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) {
return false;
}
}
// the logged MTU is computed as 128 << path_mtu from the ibv_mtu enumerator value
GGML_LOG_INFO("RDMA activated: qpn=%u→%u mtu=%d rx_depth=%d\n",
local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH);
return true;
}
#endif // GGML_RPC_RDMA
// ── Unified transport dispatch (via function pointers) ──────────────────────

// Sends `size` bytes through whichever transport the socket currently uses.
static bool send_data(socket_t * sock, const void * data, size_t size) {
    const auto transport_send = sock->fn_send;
    return transport_send(sock, data, size);
}
// Receives `size` bytes through whichever transport the socket currently uses.
static bool recv_data(socket_t * sock, void * data, size_t size) {
    const auto transport_recv = sock->fn_recv;
    return transport_recv(sock, data, size);
}
// Sends a length-prefixed message: an 8-byte size followed by `msg_size` payload bytes.
// The prefix is always written as uint64_t so that it matches what recv_msg() reads,
// independent of the width of size_t on the sending platform.
static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) {
    const uint64_t wire_size = msg_size;
    if (!send_data(sock, &wire_size, sizeof(wire_size))) {
        return false;
    }
    return send_data(sock, msg, msg_size);
}
// Receives a length-prefixed message of a known size into `msg`.
// Fails when the 8-byte size prefix cannot be read, when it differs from `msg_size`,
// or when the payload cannot be read.
static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) {
    uint64_t announced;
    const bool ok = recv_data(sock, &announced, sizeof(announced))
                 && announced == msg_size
                 && recv_data(sock, msg, msg_size);
    return ok;
}
// Receives a length-prefixed message of arbitrary size into `input`, resizing it to the
// size announced by the peer. The announced size is untrusted: a value the vector can
// never hold is rejected up front (resize() would throw std::length_error, which is not
// a bad_alloc), and an allocation failure is reported instead of propagating.
static bool recv_msg(socket_t * sock, std::vector<uint8_t> & input) {
    uint64_t size;
    if (!recv_data(sock, &size, sizeof(size))) {
        return false;
    }
    if (size > input.max_size()) {
        GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
        return false;
    }
    try {
        input.resize((size_t) size);
    } catch (const std::bad_alloc &) {
        GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
        return false;
    }
    return recv_data(sock, input.data(), (size_t) size);
}
// Splits "host:port" at the first ':' into `host` and `port`.
// Returns false when there is no ':', when the port part is not a complete decimal
// number (trailing characters are rejected), or when it lies outside 0..65535 — a larger
// value would be truncated silently by htons() later on.
// `host` is assigned as soon as the ':' is found; `port` is written only on success.
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
    const size_t pos = endpoint.find(':');
    if (pos == std::string::npos) {
        return false;
    }
    host = endpoint.substr(0, pos);
    const std::string port_str = endpoint.substr(pos + 1);
    try {
        size_t consumed = 0;
        const int value = std::stoi(port_str, &consumed);
        if (consumed != port_str.size() || value < 0 || value > 65535) {
            return false;
        }
        port = value;
    } catch (...) {
        // std::stoi throws on non-numeric or out-of-int-range input
        return false;
    }
    return true;
}
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// No response
// The size field is written as uint64_t so that it is 8 bytes on the wire regardless of
// the width of size_t on the client.
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
    const uint8_t cmd_byte = cmd;
    if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) {
        return false;
    }
    const uint64_t wire_size = input_size;
    if (!send_data(sock.get(), &wire_size, sizeof(wire_size))) {
        return false;
    }
    if (!send_data(sock.get(), input, input_size)) {
        return false;
    }
    return true;
}
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
if (!send_rpc_cmd(sock, cmd, input, input_size)) {
return false;
}
uint64_t out_size;
if (!recv_data(sock.get(), &out_size, sizeof(out_size))) {
return false;
}
if (out_size != output_size) {
return false;
}
if (!recv_data(sock.get(), output, output_size)) {
return false;
}
return true;
}
// RPC client-side implementation
// Performs HELLO handshake with optional RDMA auto-negotiation.
// If both sides have RDMA, the socket is upgraded transparently.
//
// With RDMA compiled in and a local device found, an extended HELLO carrying the local
// QPN/PSN/GID is sent. The server may answer with the extended response (RDMA is then
// activated when the server advertises a non-zero QPN) or with the legacy response (the
// connection stays on TCP). Without a local device the legacy HELLO is used.
// Returns false on a transport error in the RDMA path, an unexpected response size, or
// an incompatible server version; a failed RDMA activation is not an error.
static bool negotiate_hello(const std::shared_ptr<socket_t> & sock) {
    // Shared version check: rejects an incompatible server, warns about a minor/patch difference.
    const auto version_ok = [](uint8_t major, uint8_t minor, uint8_t patch) {
        if (major != RPC_PROTO_MAJOR_VERSION || minor > RPC_PROTO_MINOR_VERSION) {
            GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", major, minor, patch);
            return false;
        }
        if (minor != RPC_PROTO_MINOR_VERSION || patch != RPC_PROTO_PATCH_VERSION) {
            GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", major, minor, patch);
        }
        return true;
    };
#ifdef GGML_RPC_RDMA
    rdma_local_info local_info = {};
    // owned here until it is handed over to the socket; released automatically on every other path
    std::unique_ptr<rdma_conn> probe(rdma_probe(sock->fd, &local_info));
    if (probe) {
        rpc_msg_hello_rdma_req req = {};
        req.rdma_qpn = local_info.qpn;
        req.rdma_psn = local_info.psn;
        memcpy(req.rdma_gid, local_info.gid, 16);
        // Send extended HELLO: cmd + input_size + input_data
        if (!send_rpc_cmd(sock, RPC_CMD_HELLO, &req, sizeof(req))) {
            return false;
        }
        // Read response size -- server may respond with legacy or extended size
        uint64_t out_size = 0;
        if (!recv_data(sock.get(), &out_size, sizeof(out_size))) {
            return false;
        }
        if (out_size == sizeof(rpc_msg_hello_rdma_rsp)) {
            rpc_msg_hello_rdma_rsp rsp = {};
            if (!recv_data(sock.get(), &rsp, sizeof(rsp))) {
                return false;
            }
            if (!version_ok(rsp.major, rsp.minor, rsp.patch)) {
                return false;
            }
            if (rsp.rdma_qpn != 0) {
                if (rdma_activate(probe.get(), &local_info, rsp.rdma_qpn, rsp.rdma_psn, rsp.rdma_gid)) {
                    // the socket takes ownership of the live connection
                    sock->rdma    = probe.release();
                    sock->fn_send = rdma_send_impl;
                    sock->fn_recv = rdma_recv_impl;
                    return true;
                }
                GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n");
            }
            return true;
        }
        if (out_size == sizeof(rpc_msg_hello_rsp)) {
            // Legacy server responded with standard HELLO (ignored our RDMA req)
            rpc_msg_hello_rsp rsp = {};
            if (!recv_data(sock.get(), &rsp, sizeof(rsp))) {
                return false;
            }
            return version_ok(rsp.major, rsp.minor, rsp.patch);
        }
        return false;
    }
#endif
    rpc_msg_hello_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    return version_ok(response.major, response.minor, response.patch);
}
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
static bool initialized = false;
auto it = sockets.find(endpoint);
if (it != sockets.end()) {
if (auto sock = it->second.lock()) {
return sock;
}
}
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
return nullptr;
}
#ifdef _WIN32
if (!initialized) {
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (res != 0) {
return nullptr;
}
initialized = true;
}
#else
GGML_UNUSED(initialized);
#endif
auto sock = socket_connect(host.c_str(), port);
if (sock == nullptr) {
return nullptr;
}
if (!negotiate_hello(sock)) {
return nullptr;
}
LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str());
sockets[endpoint] = sock;
return sock;
}
// Buffer interface: asks the server to free the remote buffer, then destroys the
// client-side context. Aborts when the command fails.
static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    auto * buf_ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    rpc_msg_free_buffer_req req;
    req.remote_ptr = buf_ctx->remote_ptr;
    const bool ok = send_rpc_cmd(buf_ctx->sock, RPC_CMD_FREE_BUFFER, &req, sizeof(req), nullptr, 0);
    RPC_STATUS_ASSERT(ok);
    delete buf_ctx;
}
// Buffer interface: returns the base address of the remote buffer.
// The address is fetched from the server on the first call and cached in the context;
// a cached non-null value is returned without any network traffic.
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
    auto * buf_ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    if (buf_ctx->base_ptr == nullptr) {
        rpc_msg_buffer_get_base_req req;
        req.remote_ptr = buf_ctx->remote_ptr;
        rpc_msg_buffer_get_base_rsp rsp;
        const bool ok = send_rpc_cmd(buf_ctx->sock, RPC_CMD_BUFFER_GET_BASE, &req, sizeof(req), &rsp, sizeof(rsp));
        RPC_STATUS_ASSERT(ok);
        buf_ctx->base_ptr = reinterpret_cast<void *>(rsp.base_ptr);
    }
    return buf_ctx->base_ptr;
}
// A buffer belongs to the RPC backend when its free_buffer callback is the RPC one.
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
    const auto free_fn = buffer->iface.free_buffer;
    return free_fn == ggml_backend_rpc_buffer_free_buffer;
}
// Converts a ggml_tensor into its wire representation.
// A null tensor becomes an all-zero rpc_tensor. Pointers are transmitted as 64-bit
// integers, and `buffer` carries the remote handle of the tensor's RPC buffer (0 when
// the tensor has no buffer, the buffer is not an RPC buffer, or it has no context).
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
    rpc_tensor out;
    // start from all zeroes so that name/padding never carry uninitialized bytes over the wire
    memset(&out, 0, sizeof(out));
    if (tensor == nullptr) {
        return out;
    }

    out.id   = reinterpret_cast<uint64_t>(tensor);
    out.type = tensor->type;

    uint64_t remote_buffer = 0;
    if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
        auto * buf_ctx = (ggml_backend_rpc_buffer_context *)tensor->buffer->context;
        if (buf_ctx != nullptr) {
            remote_buffer = buf_ctx->remote_ptr;
        }
    }
    out.buffer = remote_buffer;

    for (uint32_t dim = 0; dim < GGML_MAX_DIMS; dim++) {
        out.ne[dim] = tensor->ne[dim];
        out.nb[dim] = tensor->nb[dim];
    }

    out.op = tensor->op;
    for (uint32_t k = 0; k < GGML_MAX_OP_PARAMS / sizeof(int32_t); k++) {
        out.op_params[k] = tensor->op_params[k];
    }
    out.flags = tensor->flags;

    for (uint32_t s = 0; s < GGML_MAX_SRC; s++) {
        out.src[s] = reinterpret_cast<uint64_t>(tensor->src[s]);
    }
    out.view_src  = reinterpret_cast<uint64_t>(tensor->view_src);
    out.view_offs = tensor->view_offs;
    out.data      = reinterpret_cast<uint64_t>(tensor->data);

    snprintf(out.name, GGML_MAX_NAME, "%s", tensor->name);
    return out;
}
// Buffer interface: initializes a tensor on the server, but only when it is required.
// Always returns GGML_STATUS_SUCCESS; a failed command aborts.
static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    auto * buf_ctx = (ggml_backend_rpc_buffer_context *)buffer->context;

    // CUDA backend on the server pads everything to 512 due to CUDA limitations.
    // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
    // In particular, only quantized tensors need padding
    const bool needs_remote_init =
        ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
    if (!needs_remote_init) {
        return GGML_STATUS_SUCCESS;
    }

    rpc_msg_init_tensor_req req;
    req.tensor = serialize_tensor(tensor);
    const bool ok = send_rpc_cmd(buf_ctx->sock, RPC_CMD_INIT_TENSOR, &req, sizeof(req), nullptr, 0);
    RPC_STATUS_ASSERT(ok);
    return GGML_STATUS_SUCCESS;
}
// Buffer interface: uploads `size` bytes of tensor data at `offset`.
// For payloads larger than HASH_THRESHOLD the hash of the data is sent first; when the
// server reports that it already has matching data, the upload is skipped.
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    auto * buf_ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    const rpc_tensor wire_tensor = serialize_tensor(tensor);

    if (size > HASH_THRESHOLD) {
        rpc_msg_set_tensor_hash_req hash_req;
        hash_req.tensor = wire_tensor;
        hash_req.offset = offset;
        hash_req.hash   = fnv_hash((const uint8_t*)data, size);
        rpc_msg_set_tensor_hash_rsp hash_rsp;
        const bool hash_ok = send_rpc_cmd(buf_ctx->sock, RPC_CMD_SET_TENSOR_HASH, &hash_req, sizeof(hash_req), &hash_rsp, sizeof(hash_rsp));
        RPC_STATUS_ASSERT(hash_ok);
        if (hash_rsp.result) {
            // the server has the same data, no need to send it
            return;
        }
    }

    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
    const size_t header_size = sizeof(wire_tensor) + sizeof(uint64_t);
    std::vector<uint8_t> payload(header_size + size, 0);
    uint8_t * cursor = payload.data();
    memcpy(cursor, &wire_tensor, sizeof(wire_tensor));
    cursor += sizeof(wire_tensor);
    memcpy(cursor, &offset, sizeof(offset));
    cursor += sizeof(offset);
    memcpy(cursor, data, size);

    const bool ok = send_rpc_cmd(buf_ctx->sock, RPC_CMD_SET_TENSOR, payload.data(), payload.size());
    RPC_STATUS_ASSERT(ok);
}
// Buffer interface: downloads `size` bytes of tensor data starting at `offset` into `data`.
// The response must be exactly `size` bytes; any failure aborts.
static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    auto * buf_ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    rpc_msg_get_tensor_req req;
    req.tensor = serialize_tensor(tensor);
    req.offset = offset;
    req.size   = size;
    const bool ok = send_rpc_cmd(buf_ctx->sock, RPC_CMD_GET_TENSOR, &req, sizeof(req), data, size);
    RPC_STATUS_ASSERT(ok);
}
// Buffer interface: copies `src` into `dst` on the server.
// Only possible when `src` lives in an RPC buffer that uses the same connection as the
// buffer of `dst`; in every other case false is returned without contacting the server.
// Otherwise the result reported by the server is returned.
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    if (!ggml_backend_buffer_is_rpc(src->buffer)) {
        return false;
    }

    // check if src and dst are on the same server
    auto * src_ctx = (ggml_backend_rpc_buffer_context *)src->buffer->context;
    auto * dst_ctx = (ggml_backend_rpc_buffer_context *)dst->buffer->context;
    if (src_ctx->sock != dst_ctx->sock) {
        return false;
    }

    auto * buf_ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    rpc_msg_copy_tensor_req req;
    req.src = serialize_tensor(src);
    req.dst = serialize_tensor(dst);
    rpc_msg_copy_tensor_rsp rsp;
    const bool ok = send_rpc_cmd(buf_ctx->sock, RPC_CMD_COPY_TENSOR, &req, sizeof(req), &rsp, sizeof(rsp));
    RPC_STATUS_ASSERT(ok);
    return rsp.result;
}
// Buffer interface: asks the server to fill the remote buffer with `value`.
// Aborts when the command fails.
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    auto * buf_ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    rpc_msg_buffer_clear_req req;
    req.remote_ptr = buf_ctx->remote_ptr;
    req.value      = value;
    const bool ok = send_rpc_cmd(buf_ctx->sock, RPC_CMD_BUFFER_CLEAR, &req, sizeof(req), nullptr, 0);
    RPC_STATUS_ASSERT(ok);
}
// Callback table installed on every RPC buffer (see ggml_backend_rpc_buffer_type_alloc_buffer).
// The initializer is positional; the comments name the slot each entry fills.
// memset_tensor and reset are not provided (NULL).
static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
/* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
/* .get_base = */ ggml_backend_rpc_buffer_get_base,
/* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
/* .memset_tensor = */ NULL,
/* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
/* .clear = */ ggml_backend_rpc_buffer_clear,
/* .reset = */ NULL,
};
// Buffer type interface: returns the name stored in the buffer type context.
// The pointer refers to the context's std::string storage.
static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
    const auto * type_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    const std::string & type_name = type_ctx->name;
    return type_name.c_str();
}
// Buffer type interface: allocates `size` bytes on the remote device and wraps the
// returned handle in a ggml buffer that uses the RPC buffer interface.
// Returns nullptr when the server cannot be reached or reports a null remote pointer.
// A failure of the command itself (after the connection is up) aborts.
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    auto sock = get_socket(buft_ctx->endpoint);
    if (sock == nullptr) {
        // get_socket() fails on a bad endpoint, connection error or failed handshake;
        // sending a command through a null socket would dereference it
        GGML_LOG_ERROR("Failed to connect to %s\n", buft_ctx->endpoint.c_str());
        return nullptr;
    }
    rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
    rpc_msg_alloc_buffer_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
    RPC_STATUS_ASSERT(status);
    if (response.remote_ptr == 0) {
        return nullptr;
    }
    // the buffer context keeps the connection alive for as long as the buffer exists
    ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
        ggml_backend_rpc_buffer_interface,
        new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
        response.remote_size);
    return buffer;
}
// Query the server reachable through `sock` for the alignment of `device`
// (RPC_CMD_GET_ALIGNMENT). Aborts via RPC_STATUS_ASSERT if the command fails.
static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
    rpc_msg_get_alignment_rsp response;
    rpc_msg_get_alignment_req request = {device};

    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT,
                               &request,  sizeof(request),
                               &response, sizeof(response));
    RPC_STATUS_ASSERT(status);

    return response.alignment;
}
// Return the alignment stored in the buffer type context; no RPC call is made here.
static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    const auto * type_ctx = static_cast<const ggml_backend_rpc_buffer_type_context *>(buft->context);
    return type_ctx->alignment;
}
// Query the server reachable through `sock` for the maximum buffer size of
// `device` (RPC_CMD_GET_MAX_SIZE). Aborts via RPC_STATUS_ASSERT if the command fails.
static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
    rpc_msg_get_max_size_rsp response;
    rpc_msg_get_max_size_req request = {device};

    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE,
                               &request,  sizeof(request),
                               &response, sizeof(response));
    RPC_STATUS_ASSERT(status);

    return response.max_size;
}
// Return the maximum buffer size stored in the buffer type context; no RPC call is made here.
static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
    const auto * type_ctx = static_cast<const ggml_backend_rpc_buffer_type_context *>(buft->context);
    return type_ctx->max_size;
}
// Compute the allocation size for `tensor`.
// Most tensors are sized locally with ggml_nbytes(); the remote server is
// queried (RPC_CMD_GET_ALLOC_SIZE) only for the cases listed below.
static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    // See comments in init_tensor.
    const bool quantized_case =
        ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);

    // ops that require additional memory for fleeting data on certain backends
    // ref: https://github.com/ggml-org/llama.cpp/pull/15966
    const bool extra_mem_op =
        tensor->op == GGML_OP_FLASH_ATTN_EXT ||
        tensor->op == GGML_OP_MUL_MAT_ID;

    if (!quantized_case && !extra_mem_op) {
        // the local size is good enough
        return ggml_nbytes(tensor);
    }

    auto * type_ctx = static_cast<ggml_backend_rpc_buffer_type_context *>(buft->context);
    auto sock = get_socket(type_ctx->endpoint);

    rpc_msg_get_alloc_size_req request = {
        /*.device =*/ type_ctx->device,
        /*.tensor =*/ serialize_tensor(tensor),
        /*.srcs   =*/ {},
    };

    // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
    for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
        request.srcs[src_idx] = serialize_tensor(tensor->src[src_idx]);
    }

    // TODO: cache the alloc responses to avoid extra RPC calls?
    rpc_msg_get_alloc_size_rsp response;
    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
    RPC_STATUS_ASSERT(status);

    return response.alloc_size;
}
// Function table for the RPC buffer type. is_host is left as nullptr.
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_rpc_buffer_type_name,
    /* .alloc_buffer     = */ ggml_backend_rpc_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_rpc_buffer_type_get_alignment,
    /* .get_max_size     = */ ggml_backend_rpc_get_max_size,
    /* .get_alloc_size   = */ ggml_backend_rpc_buffer_type_get_alloc_size,
    /* .is_host          = */ nullptr,
};
// Return the name stored in the backend context. The pointer refers to the
// context's std::string storage.
static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
    const auto * backend_ctx = static_cast<const ggml_backend_rpc_context *>(backend->context);
    return backend_ctx->name.c_str();
}
// Destroy an RPC backend: delete its context first, then the backend object itself.
static void ggml_backend_rpc_free(ggml_backend_t backend) {
    delete static_cast<ggml_backend_rpc_context *>(backend->context);
    delete backend;
}
// Synchronization hook for the RPC backend.
static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
    // nothing to wait for: the RPC backend has no async operations
    GGML_UNUSED(backend);
}
// Append `tensor` and everything it depends on (srcs and view_src) to `tensors`.
// Dependencies are serialized before the tensor itself; `visited` makes sure
// each tensor is serialized at most once.
static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
    if (tensor == nullptr) {
        return;
    }
    // insert() reports whether the tensor was newly added; bail out if it was already seen
    const bool first_visit = visited.insert(tensor).second;
    if (!first_visit) {
        return;
    }

    for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
        add_tensor(tensor->src[src_idx], tensors, visited);
    }
    add_tensor(tensor->view_src, tensors, visited);

    tensors.push_back(serialize_tensor(tensor));
}
// Serialize `cgraph` for `device` into `output`.
//
// serialization format:
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
//
// Each node entry is the node's address widened to 64 bits. `tensors` holds
// every tensor reachable from the graph nodes (see add_tensor).
static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
    const uint32_t n_nodes = cgraph->n_nodes;

    std::vector<rpc_tensor> tensors;
    std::unordered_set<ggml_tensor*> visited;
    for (uint32_t i = 0; i < n_nodes; i++) {
        add_tensor(cgraph->nodes[i], tensors, visited);
    }

    const uint32_t n_tensors = tensors.size();

    // compute the payload size in size_t: an int would narrow for large graphs
    const size_t nodes_bytes   = (size_t) n_nodes   * sizeof(uint64_t);
    const size_t tensors_bytes = (size_t) n_tensors * sizeof(rpc_tensor);
    const size_t output_size   = sizeof(device) + sizeof(n_nodes) + nodes_bytes + sizeof(n_tensors) + tensors_bytes;
    output.resize(output_size, 0);

    uint8_t * dest = output.data();
    memcpy(dest, &device, sizeof(device));
    dest += sizeof(device);
    memcpy(dest, &n_nodes, sizeof(n_nodes));
    dest += sizeof(n_nodes);

    for (uint32_t i = 0; i < n_nodes; i++) {
        // widen the pointer explicitly instead of copying 8 bytes out of the pointer
        // object, which would read past it on platforms with 4-byte pointers
        const uint64_t node_id = reinterpret_cast<uint64_t>(cgraph->nodes[i]);
        memcpy(dest + i * sizeof(uint64_t), &node_id, sizeof(node_id));
    }
    dest += nodes_bytes;

    memcpy(dest, &n_tensors, sizeof(n_tensors));
    dest += sizeof(n_tensors);

    if (n_tensors > 0) {
        memcpy(dest, tensors.data(), tensors_bytes);
    }
}
// Run `cgraph` on the remote device.
// If the backend's graph cache reports the graph as cached, only a
// RPC_CMD_GRAPH_RECOMPUTE request is sent; otherwise the graph is added to the
// cache, serialized and sent with RPC_CMD_GRAPH_COMPUTE.
// Always returns GGML_STATUS_SUCCESS; a failed command aborts via RPC_STATUS_ASSERT.
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    auto * backend_ctx = static_cast<ggml_backend_rpc_context *>(backend->context);

    GGML_ASSERT(cgraph->n_nodes > 0);

    if (backend_ctx->gc.is_cached(cgraph)) {
        // short path: tell the server to re-run the graph for this device
        rpc_msg_graph_recompute_req request;
        request.device = backend_ctx->device;

        auto sock = get_socket(backend_ctx->endpoint);
        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
        RPC_STATUS_ASSERT(status);

        return GGML_STATUS_SUCCESS;
    }

    // full path: remember the graph, then ship its serialized form
    backend_ctx->gc.add(cgraph);

    std::vector<uint8_t> payload;
    serialize_graph(backend_ctx->device, cgraph, payload);

    auto sock = get_socket(backend_ctx->endpoint);
    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, payload.data(), payload.size());
    RPC_STATUS_ASSERT(status);

    return GGML_STATUS_SUCCESS;
}
// Function table for the RPC backend. Only name, free, synchronize and
// graph_compute are provided; every other slot is nullptr.
static ggml_backend_i ggml_backend_rpc_interface = {
    /* .get_name                = */ ggml_backend_rpc_name,
    /* .free                    = */ ggml_backend_rpc_free,
    /* .set_tensor_async        = */ nullptr,
    /* .get_tensor_async        = */ nullptr,
    /* .cpy_tensor_async        = */ nullptr,
    /* .synchronize             = */ ggml_backend_rpc_synchronize,
    /* .graph_plan_create       = */ nullptr,
    /* .graph_plan_free         = */ nullptr,
    /* .graph_plan_update       = */ nullptr,
    /* .graph_plan_compute      = */ nullptr,
    /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
    /* .event_record            = */ nullptr,
    /* .event_wait              = */ nullptr,
    /* .graph_optimize          = */ nullptr,
};
// Return the buffer type for (`endpoint`, `device`), creating it on first use.
// Buffer types are kept in a function-local map keyed by "RPC<device>[<endpoint>]";
// the whole function runs under a function-local mutex.
// Returns nullptr if no socket to `endpoint` could be obtained.
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    std::string buft_name = "RPC";
    buft_name += std::to_string(device);
    buft_name += "[";
    buft_name += std::string(endpoint);
    buft_name += "]";

    // NOTE: buffer types are allocated and never freed; this is by design
    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;

    const auto cached = buft_map.find(buft_name);
    if (cached != buft_map.end()) {
        return cached->second;
    }

    auto sock = get_socket(endpoint);
    if (sock == nullptr) {
        GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
        return nullptr;
    }

    // query the static properties once and store them in the context
    const size_t alignment = get_alignment(sock, device);
    const size_t max_size  = get_max_size(sock, device);

    auto * type_ctx = new ggml_backend_rpc_buffer_type_context {
        /* .endpoint  = */ endpoint,
        /* .device    = */ device,
        /* .name      = */ buft_name,
        /* .alignment = */ alignment,
        /* .max_size  = */ max_size
    };

    auto reg = ggml_backend_rpc_add_server(endpoint);

    ggml_backend_buffer_type_t new_buft = new ggml_backend_buffer_type {
        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
        /* .context = */ type_ctx
    };

    buft_map[buft_name] = new_buft;
    return new_buft;
}
// Create an RPC backend object for `device` on `endpoint`.
// The backend is named "RPC<device>[<endpoint>]" and starts with an empty graph cache.
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
    std::string dev_name = "RPC";
    dev_name += std::to_string(device);
    dev_name += "[";
    dev_name += std::string(endpoint);
    dev_name += "]";

    auto * backend_ctx = new ggml_backend_rpc_context {
        /* .endpoint = */ endpoint,
        /* .device   = */ device,
        /* .name     = */ dev_name,
        /* .gc       = */ {},
    };

    auto reg = ggml_backend_rpc_add_server(endpoint);

    return new ggml_backend {
        /* .guid    = */ ggml_backend_rpc_guid(),
        /* .iface   = */ ggml_backend_rpc_interface,
        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
        /* .context = */ backend_ctx
    };
}
// True when `backend` is non-null and its guid matches the RPC backend guid.
bool ggml_backend_is_rpc(ggml_backend_t backend) {
    if (backend == nullptr) {
        return false;
    }
    return ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}
// Query free and total memory of `device` from the server reachable through
// `sock` (RPC_CMD_GET_DEVICE_MEMORY) and store the results in *free / *total.
// Aborts via RPC_STATUS_ASSERT if the command fails.
static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
    rpc_msg_get_device_memory_rsp response;
    rpc_msg_get_device_memory_req request;
    request.device = device;

    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY,
                               &request,  sizeof(request),
                               &response, sizeof(response));
    RPC_STATUS_ASSERT(status);

    *free  = response.free_mem;
    *total = response.total_mem;
}
// Report free/total memory of `device` on `endpoint`.
// If no socket to the endpoint is available, both outputs are set to 0.
void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
    auto sock = get_socket(endpoint);
    if (sock != nullptr) {
        get_device_memory(sock, device, free, total);
        return;
    }
    // no connection: report an empty device
    *free  = 0;
    *total = 0;
}
// RPC server-side implementation
//
// Holds the set of backends served to a client, the buffers allocated on its
// behalf, and the last graph computed per backend. Each public method handles
// one RPC command; the bool-returning handlers return false to report a failed
// request.
class rpc_server {
public:
    // `all_backends` is taken by value and moved into the server;
    // `cache_dir` is stored as given (handlers treat a null pointer as "no cache").
    rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
        : backends(std::move(all_backends))
        , cache_dir(cache_dir) {
        // one stored-graph slot per backend
        stored_graphs.resize(backends.size());
    }

    ~rpc_server();

    // protocol / capability queries
    void hello(rpc_msg_hello_rsp & response);
    bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
    bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
    bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);

    // buffer management
    bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
    bool free_buffer(const rpc_msg_free_buffer_req & request);
    bool buffer_clear(const rpc_msg_buffer_clear_req & request);

    // tensor data transfer
    bool set_tensor(const std::vector<uint8_t> & input);
    bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
    bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
    bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);

    // graph execution
    bool graph_compute(const std::vector<uint8_t> & input);
    bool graph_recompute(const rpc_msg_graph_recompute_req & request);

    // misc
    bool init_tensor(const rpc_msg_init_tensor_req & request);
    bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
    bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);

    // a graph together with the context that owns its tensors
    struct stored_graph {
        ggml_context_ptr ctx_ptr;
        ggml_cgraph *    graph;
    };

private:
    bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);

    ggml_tensor * deserialize_tensor(ggml_context * ctx, const rpc_tensor * tensor);

    ggml_tensor * create_node(uint64_t id,
                              ggml_context * ctx,
                              const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                              std::unordered_map<uint64_t, ggml_tensor*> & tensor_map);

    std::vector<ggml_backend_t> backends;
    const char * cache_dir;
    // buffers handed out to the client; remote handles are validated against this set
    std::unordered_set<ggml_backend_buffer_t> buffers;
    // store the last computed graph for each backend
    std::vector<stored_graph> stored_graphs;
};
// Fill `response` with the protocol version this server was built with.
void rpc_server::hello(rpc_msg_hello_rsp & response) {
    response.major = RPC_PROTO_MAJOR_VERSION;
    response.minor = RPC_PROTO_MINOR_VERSION;
    response.patch = RPC_PROTO_PATCH_VERSION;

    LOG_DBG("[%s] version: %d.%d.%d\n", __func__,
            response.major, response.minor, response.patch);
}
// Handle a get-alloc-size request: rebuild the tensor (and its srcs) in a
// scratch context and ask the matching buffer type for the allocation size.
// Returns false for an out-of-range device or a tensor that fails to deserialize.
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
    const uint32_t dev_id = request.device;
    if (dev_id >= backends.size()) {
        return false;
    }

    // room for the tensor itself plus all of its sources
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();

    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
    if (tensor == nullptr) {
        GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
        return false;
    }

    // sources with id 0 are treated as absent
    for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
        if (request.srcs[src_idx].id == 0) {
            continue;
        }
        tensor->src[src_idx] = deserialize_tensor(ctx, &request.srcs[src_idx]);
    }

    LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);

    // without an allocated buffer, fall back to the device's default buffer type
    ggml_backend_buffer_type_t buft = tensor->buffer == nullptr
        ? ggml_backend_get_default_buffer_type(backends[dev_id])
        : tensor->buffer->buft;

    response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);
    return true;
}
bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
response.remote_ptr = 0;
response.remote_size = 0;
if (buffer != nullptr) {
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
response.remote_size = buffer->size;
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
__func__, dev_id, request.size, response.remote_ptr, response.remote_size);
buffers.insert(buffer);
} else {
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
}
return true;
}
// Handle a get-alignment request: report the alignment of the device's default
// buffer type. Returns false for an out-of-range device.
bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
    uint32_t dev_id = request.device;
    if (dev_id >= backends.size()) {
        return false;
    }
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
    size_t alignment = ggml_backend_buft_get_alignment(buft);
    // %zu matches size_t on every platform (%lu is too narrow on LLP64 targets such as Windows);
    // %u matches the unsigned device index
    LOG_DBG("[%s] device: %u, alignment: %zu\n", __func__, dev_id, alignment);
    response.alignment = alignment;
    return true;
}
// Handle a get-max-size request: report the maximum buffer size of the device's
// default buffer type. Returns false for an out-of-range device.
bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
    uint32_t dev_id = request.device;
    if (dev_id >= backends.size()) {
        return false;
    }
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
    size_t max_size = ggml_backend_buft_get_max_size(buft);
    // %zu matches size_t on every platform (%lu is too narrow on LLP64 targets such as Windows);
    // %u matches the unsigned device index
    LOG_DBG("[%s] device: %u, max_size: %zu\n", __func__, dev_id, max_size);
    response.max_size = max_size;
    return true;
}
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
if (buffers.find(buffer) == buffers.end()) {
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
return false;
}
void * base = ggml_backend_buffer_get_base(buffer);
response.base_ptr = reinterpret_cast<uint64_t>(base);
return true;
}
bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
if (buffers.find(buffer) == buffers.end()) {
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
return false;
}
ggml_backend_buffer_free(buffer);
buffers.erase(buffer);
return true;
}
bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
LOG_DBG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
if (buffers.find(buffer) == buffers.end()) {
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
return false;
}
ggml_backend_buffer_clear(buffer, request.value);
return true;
}
// Rebuild a ggml_tensor inside `ctx` from its wire representation.
//
// Returns nullptr if the type or op is out of range or the tensor cannot be
// created. A buffer handle that is not tracked by this server is replaced with
// nullptr. For tracked buffers the tensor data range must lie inside the
// buffer; a violation trips GGML_ASSERT.
ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
    // Validate tensor type before using it
    if (tensor->type >= GGML_TYPE_COUNT) {
        GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
        return nullptr;
    }
    // The op comes from the client as well; reject values outside the ggml_op enum
    // before it is cast below
    if ((uint32_t) tensor->op >= (uint32_t) GGML_OP_COUNT) {
        GGML_LOG_ERROR("[%s] invalid tensor op received: %u\n", __func__, (unsigned) tensor->op);
        return nullptr;
    }

    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);

    // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
    if (result == nullptr) {
        // the message previously ended in a literal backslash-n instead of a newline
        GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type);
        return nullptr;
    }

    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
        result->nb[i] = tensor->nb[i];
    }

    // only buffers handed out by this server are accepted as-is
    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
        result->buffer = nullptr;
    }

    if (result->buffer) {
        // require that the tensor data does not go beyond the buffer end
        uint64_t tensor_size  = (uint64_t) ggml_nbytes(result);
        uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
        uint64_t buffer_size  = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
        GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
        GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
    }

    result->op = (ggml_op) tensor->op;
    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
        result->op_params[i] = tensor->op_params[i];
    }
    result->flags = tensor->flags;
    result->data = reinterpret_cast<void *>(tensor->data);
    ggml_set_name(result, tensor->name);
    return result;
}
bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
// serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
return false;
}
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
uint64_t offset;
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
GGML_ASSERT(ctx_ptr != nullptr);
ggml_context * ctx = ctx_ptr.get();
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
if (tensor == nullptr || tensor->buffer == nullptr) {
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
return false;
}
LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
// sanitize tensor->data
{
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
__func__, in_tensor->data, offset, size, p0, p1);
return false;
}
}
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
if (cache_dir && size > HASH_THRESHOLD) {
uint64_t hash = fnv_hash((const uint8_t*)data, size);
char hash_str[17];
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
// save to cache_dir/hash_str
fs::path cache_file = fs::path(cache_dir) / hash_str;
std::ofstream ofs(cache_file, std::ios::binary);
ofs.write((const char *)data, size);
GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
}
ggml_backend_tensor_set(tensor, data, offset, size);
return true;
}
bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
if (!cache_dir) {
return false;
}
char hash_str[17];
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
fs::path cache_file = fs::path(cache_dir) / hash_str;
std::error_code ec;
if (!fs::exists(cache_file, ec)) {
return false;
}
std::ifstream ifs(cache_file, std::ios::binary);
ifs.seekg(0, std::ios::end);
size_t size = ifs.tellg();
ifs.seekg(0, std::ios::beg);
data.resize(size);
ifs.read((char *)data.data(), size);
return true;
}
bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
{
std::vector<uint8_t> cached_file;
if (!get_cached_file(request.hash, cached_file)) {
response.result = 0;
return true;
}
size_t size = cached_file.size();
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
GGML_ASSERT(ctx_ptr != nullptr);
ggml_context * ctx = ctx_ptr.get();
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
if (tensor == nullptr || tensor->buffer == nullptr) {
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
return false;
}
LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
__func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
// sanitize tensor->data
{
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
if (request.tensor.data + request.offset < p0
|| request.tensor.data + request.offset >= p1
|| size > (p1 - request.tensor.data - request.offset)) {
GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
__func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
return false;
}
}
ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
response.result = 1;
return true;
}
bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
GGML_ASSERT(ctx_ptr != nullptr);
ggml_context * ctx = ctx_ptr.get();
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
if (tensor == nullptr) {
GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
return false;
}
LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
// Call the backend's buffer_init_tensor function
ggml_backend_buffer_t buffer = tensor->buffer;
if (buffer && buffer->iface.init_tensor) {
buffer->iface.init_tensor(buffer, tensor);
} else {
GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
}
if (tensor->extra != nullptr) {
// This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
// Currently unimplemented.
GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
return false;
}
return true;
}
bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
GGML_ASSERT(ctx_ptr != nullptr);
ggml_context * ctx = ctx_ptr.get();
ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
if (tensor == nullptr || tensor->buffer == nullptr) {
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
return false;
}
LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
// sanitize tensor->data
{
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
if (request.tensor.data + request.offset < p0 ||
request.tensor.data + request.offset >= p1 ||
request.size > (p1 - request.tensor.data - request.offset)) {
GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
__func__, request.tensor.data, request.offset, request.size, p0, p1);
return false;
}
}
response.resize(request.size, 0);
ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
return true;
}
// Handle a copy-tensor request: copy `src` into `dst` on the server side and
// store the outcome of ggml_backend_buffer_copy_tensor in response.result.
// Returns false if either tensor cannot be deserialized into a tracked buffer
// or the copy would write past the end of the destination buffer.
bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
    // scratch context with room for both tensors
    struct ggml_init_params params {
        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();

    ggml_tensor * src = deserialize_tensor(ctx, &request.src);
    ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);

    const bool src_ok = src != nullptr && src->buffer != nullptr;
    const bool dst_ok = dst != nullptr && dst->buffer != nullptr;
    if (!src_ok || !dst_ok) {
        GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
        return false;
    }

    // the write covers src_size bytes starting at dst->data; it must end inside the dst buffer
    const uint64_t src_size    = (uint64_t) ggml_nbytes(src);
    const uint64_t write_begin = (uint64_t) dst->data;
    const uint64_t write_end   = write_begin + src_size;
    const uint64_t buf_begin   = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
    const uint64_t buf_end     = buf_begin + (uint64_t) ggml_backend_buffer_get_size(dst->buffer);

    if (write_end > buf_end) {
        GGML_LOG_ERROR("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
                       " write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
                       " buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
                       __func__,
                       write_begin,
                       write_end,
                       buf_begin,
                       buf_end);
        return false;
    }

    LOG_DBG("[%s] src->buffer: %p, dst->buffer: %p\n",
            __func__, (void*) src->buffer, (void*) dst->buffer);

    response.result = ggml_backend_buffer_copy_tensor(src, dst);
    return true;
}
// Recursively materialize the tensor with the given `id` inside `ctx`.
//
// `tensor_ptrs` maps ids to their wire descriptions; `tensor_map` memoizes the
// tensors created so far, so each id is built at most once. Returns nullptr if
// the id is unknown, the tensor cannot be deserialized, or any non-zero
// src / view_src dependency fails to build.
ggml_tensor * rpc_server::create_node(uint64_t id,
                                      struct ggml_context * ctx,
                                      const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                                      std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
    // already built?
    const auto known = tensor_map.find(id);
    if (known != tensor_map.end()) {
        return known->second;
    }

    // look up the wire description without inserting into the map
    const auto desc_it = tensor_ptrs.find(id);
    if (desc_it == tensor_ptrs.end()) {
        return nullptr;
    }
    const rpc_tensor * desc = desc_it->second;

    struct ggml_tensor * node = deserialize_tensor(ctx, desc);
    if (node == nullptr) {
        return nullptr;
    }

    // register before recursing so that repeated references resolve to this node
    tensor_map[id] = node;

    for (int i = 0; i < GGML_MAX_SRC; i++) {
        const auto src_id = desc->src[i];
        // id 0 means "no source"
        if (src_id == 0) {
            node->src[i] = nullptr;
            continue;
        }
        node->src[i] = create_node(src_id, ctx, tensor_ptrs, tensor_map);
        if (node->src[i] == nullptr) {
            // a non-zero id that cannot be built is an error: propagate it
            GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
                           __func__, i, src_id, id);
            return nullptr;
        }
    }

    // view_src follows the same rules as the sources
    const auto view_id = desc->view_src;
    if (view_id == 0) {
        node->view_src = nullptr;
    } else {
        node->view_src = create_node(view_id, ctx, tensor_ptrs, tensor_map);
        if (node->view_src == nullptr) {
            GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
                           __func__, view_id, id);
            return nullptr;
        }
    }
    node->view_offs = desc->view_offs;

    return node;
}
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
// serialization format:
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < 2*sizeof(uint32_t)) {
return false;
}
const uint8_t * src = input.data();
uint32_t device;
memcpy(&device, src, sizeof(device));
src += sizeof(device);
if (device >= backends.size()) {
return false;
}
uint32_t n_nodes;
memcpy(&n_nodes, src, sizeof(n_nodes));
src += sizeof(n_nodes);
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
return false;
}
const uint64_t * nodes = (const uint64_t *)src;
src += n_nodes*sizeof(uint64_t);
uint32_t n_tensors;
memcpy(&n_tensors, src, sizeof(n_tensors));
src += sizeof(n_tensors);
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
return false;
}
const rpc_tensor * tensors = (const rpc_tensor *)src;
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
struct ggml_init_params params = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
GGML_ASSERT(ctx_ptr != nullptr);
ggml_context * ctx = ctx_ptr.get();
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
graph->n_nodes = n_nodes;
std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
tensor_ptrs.reserve(n_tensors);
for (uint32_t i = 0; i < n_tensors; i++) {
tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
}
std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
tensor_map.reserve(n_nodes);
for (uint32_t i = 0; i < n_nodes; i++) {
int64_t id;
memcpy(&id, &nodes[i], sizeof(id));
graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
// Check if create_node failed for a *non-zero* ID.
// If id was 0, create_node returning nullptr is expected.
// If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
if (graph->nodes[i] == nullptr && id != 0) {
GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
return false;
}
}
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
stored_graphs[device].graph = graph;
return true;
}
// Handle a graph-recompute request: run the graph stored for the device again.
// Returns false if the device is out of range or no graph has been stored for
// it. A failed computation trips GGML_ASSERT.
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
    const uint32_t device = request.device;
    if (device >= backends.size()) {
        return false;
    }

    ggml_cgraph * graph = stored_graphs[device].graph;
    if (graph == nullptr) {
        return false;
    }

    LOG_DBG("[%s] device: %u\n", __func__, device);

    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
    return true;
}
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
uint32_t dev_id = request.device;
if (dev_id >= backends.size()) {
return false;
}
size_t free, total;
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
ggml_backend_dev_memory(dev, &free, &total);
response.free_mem = free;
response.total_mem = total;
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
return true;
}
// Free every buffer that is still tracked by this server.
rpc_server::~rpc_server() {
    for (ggml_backend_buffer_t tracked : buffers) {
        ggml_backend_buffer_free(tracked);
    }
}
// Serves one client connection until the peer disconnects or a protocol/handler error occurs.
//
// The first byte received must be RPC_CMD_HELLO, followed by a 64-bit payload size:
//   - sizeof(rpc_msg_hello_rdma_req): extended HELLO carrying RDMA fields; answered with
//     rpc_msg_hello_rdma_rsp (RDMA fields stay zero when RDMA is unavailable or not requested)
//   - 0: legacy HELLO; answered with rpc_msg_hello_rsp
//   - anything else: the payload is consumed and the legacy response is sent
// Afterwards commands are read one byte at a time and dispatched to the rpc_server instance.
// Any receive/send/handler failure ends the session by returning.
//
// backends  - backends handed to the rpc_server constructor; also used for the device count
// cache_dir - handed to the rpc_server constructor
// sockfd    - connected client socket; switched to the RDMA send/recv callbacks when RDMA
//             activation succeeds during HELLO
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
                             socket_t * sockfd) {
    // Upper bound for the payload of a HELLO with an unrecognized size. The size comes straight
    // from the wire, so it must not be used for an allocation without a limit.
    static constexpr uint64_t MAX_UNKNOWN_HELLO_SIZE = 64 * 1024;

    rpc_server server(backends, cache_dir);
    uint8_t cmd;
    if (!recv_data(sockfd, &cmd, 1)) {
        return;
    }
    // the first command sent by the client must be HELLO
    if (cmd != RPC_CMD_HELLO) {
        GGML_LOG_ERROR("Expected HELLO command, update client\n");
        return;
    }
    // Read input_size to determine legacy vs extended HELLO
    uint64_t hello_input_size;
    if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) {
        return;
    }
    if (hello_input_size == sizeof(rpc_msg_hello_rdma_req)) {
        // Extended HELLO with RDMA fields
        rpc_msg_hello_rdma_req req = {};
        if (!recv_data(sockfd, &req, sizeof(req))) {
            return;
        }
#ifdef GGML_RPC_RDMA
        rdma_local_info local_info = {};
        rdma_conn * probe = rdma_probe(sockfd->fd, &local_info);
        rpc_msg_hello_rdma_rsp rsp = {};
        rpc_msg_hello_rsp base_rsp;
        server.hello(base_rsp);
        rsp.major = base_rsp.major;
        rsp.minor = base_rsp.minor;
        rsp.patch = base_rsp.patch;
        // advertise local RDMA parameters only if the probe worked and the client asked for RDMA
        if (probe && req.rdma_qpn != 0) {
            rsp.rdma_qpn = local_info.qpn;
            rsp.rdma_psn = local_info.psn;
            memcpy(rsp.rdma_gid, local_info.gid, 16);
        }
        if (!send_msg(sockfd, &rsp, sizeof(rsp))) {
            delete probe;
            return;
        }
        if (probe && req.rdma_qpn != 0 && rsp.rdma_qpn != 0) {
            if (rdma_activate(probe, &local_info, req.rdma_qpn, req.rdma_psn, req.rdma_gid)) {
                // from here on all traffic on this socket goes through the RDMA callbacks
                sockfd->rdma    = probe;
                sockfd->fn_send = rdma_send_impl;
                sockfd->fn_recv = rdma_recv_impl;
            } else {
                GGML_LOG_ERROR("RDMA activate failed on server, staying on TCP\n");
                delete probe;
            }
        } else {
            delete probe;
        }
#else
        // Not compiled with RDMA -- respond with zeros for RDMA fields
        rpc_msg_hello_rdma_rsp rsp = {};
        rpc_msg_hello_rsp base_rsp;
        server.hello(base_rsp);
        rsp.major = base_rsp.major;
        rsp.minor = base_rsp.minor;
        rsp.patch = base_rsp.patch;
        if (!send_msg(sockfd, &rsp, sizeof(rsp))) {
            return;
        }
#endif
    } else {
        if (hello_input_size != 0) {
            // Unknown HELLO size -- consume the payload, then respond with legacy
            if (hello_input_size > MAX_UNKNOWN_HELLO_SIZE) {
                GGML_LOG_ERROR("HELLO payload too large: %" PRIu64 " bytes\n", hello_input_size);
                return;
            }
            std::vector<uint8_t> discard(static_cast<size_t>(hello_input_size));
            if (!recv_data(sockfd, discard.data(), discard.size())) {
                return;
            }
        }
        // Legacy HELLO (no RDMA)
        rpc_msg_hello_rsp response;
        server.hello(response);
        if (!send_msg(sockfd, &response, sizeof(response))) {
            return;
        }
    }
    while (true) {
        if (!recv_data(sockfd, &cmd, 1)) {
            break;
        }
        if (cmd >= RPC_CMD_COUNT) {
            // fail fast if the command is invalid
            GGML_LOG_ERROR("Unknown command: %d\n", cmd);
            break;
        }
        switch (cmd) {
            case RPC_CMD_HELLO: {
                // HELLO command is handled above
                return;
            }
            case RPC_CMD_DEVICE_COUNT: {
                if (!recv_msg(sockfd, nullptr, 0)) {
                    return;
                }
                rpc_msg_device_count_rsp response;
                response.device_count = backends.size();
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_ALLOC_BUFFER: {
                rpc_msg_alloc_buffer_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_alloc_buffer_rsp response;
                if (!server.alloc_buffer(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_ALLOC_SIZE: {
                rpc_msg_get_alloc_size_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_alloc_size_rsp response;
                if (!server.get_alloc_size(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_ALIGNMENT: {
                rpc_msg_get_alignment_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_alignment_rsp response;
                if (!server.get_alignment(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_MAX_SIZE: {
                rpc_msg_get_max_size_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_max_size_rsp response;
                if (!server.get_max_size(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_BUFFER_GET_BASE: {
                rpc_msg_buffer_get_base_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_buffer_get_base_rsp response;
                if (!server.buffer_get_base(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_FREE_BUFFER: {
                rpc_msg_free_buffer_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.free_buffer(request)) {
                    return;
                }
                // empty message acknowledges completion
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_BUFFER_CLEAR: {
                rpc_msg_buffer_clear_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.buffer_clear(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_SET_TENSOR: {
                // variable-sized request; no response is sent for this command
                std::vector<uint8_t> input;
                if (!recv_msg(sockfd, input)) {
                    return;
                }
                if (!server.set_tensor(input)) {
                    return;
                }
                break;
            }
            case RPC_CMD_SET_TENSOR_HASH: {
                rpc_msg_set_tensor_hash_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_set_tensor_hash_rsp response;
                if (!server.set_tensor_hash(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_INIT_TENSOR: {
                rpc_msg_init_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.init_tensor(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_TENSOR: {
                rpc_msg_get_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                std::vector<uint8_t> response;
                if (!server.get_tensor(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, response.data(), response.size())) {
                    return;
                }
                break;
            }
            case RPC_CMD_COPY_TENSOR: {
                rpc_msg_copy_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_copy_tensor_rsp response;
                if (!server.copy_tensor(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GRAPH_COMPUTE: {
                // variable-sized request; no response is sent for this command
                std::vector<uint8_t> input;
                if (!recv_msg(sockfd, input)) {
                    return;
                }
                if (!server.graph_compute(input)) {
                    return;
                }
                break;
            }
            case RPC_CMD_GRAPH_RECOMPUTE: {
                // no response is sent for this command
                rpc_msg_graph_recompute_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.graph_recompute(request)) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_DEVICE_MEMORY: {
                rpc_msg_get_device_memory_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_device_memory_rsp response;
                if (!server.get_device_memory(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            default: {
                GGML_LOG_ERROR("Unknown command: %d\n", cmd);
                return;
            }
        }
    }
}
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
if (n_devices == 0 || devices == nullptr) {
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
return;
}
std::vector<ggml_backend_t> backends;
printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION,
RPC_PROTO_PATCH_VERSION);
printf(" endpoint : %s\n", endpoint);
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
printf("Devices:\n");
for (size_t i = 0; i < n_devices; i++) {
auto dev = devices[i];
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
total / 1024 / 1024, free / 1024 / 1024);
auto backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
return;
}
backends.push_back(backend);
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
if (reg) {
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (ggml_backend_set_n_threads_fn) {
ggml_backend_set_n_threads_fn(backend, n_threads);
}
}
}
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
return;
}
#ifdef GGML_RPC_RDMA
printf(" transport : TCP (RDMA auto-negotiate enabled)\n");
#else
printf(" transport : TCP\n");
#endif
#ifdef _WIN32
{
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (res != 0) {
fprintf(stderr, "WSAStartup failed: %d\n", res);
return;
}
}
#endif
auto server_socket = create_server_socket(host.c_str(), port);
if (server_socket == nullptr) {
fprintf(stderr, "Failed to create server socket\n");
return;
}
while (true) {
auto client_socket = socket_accept(server_socket->fd);
if (client_socket == nullptr) {
fprintf(stderr, "Failed to accept client connection\n");
return;
}
printf("Accepted client connection\n");
fflush(stdout);
rpc_serve_client(backends, cache_dir, client_socket.get());
printf("Client connection closed\n");
fflush(stdout);
}
#ifdef _WIN32
WSACleanup();
#endif
for (auto backend : backends) {
ggml_backend_free(backend);
}
}
// device interface
// Per-device state stored in ggml_backend_device::context for RPC devices.
// Instances are aggregate-initialized positionally in ggml_backend_rpc_add_server,
// so the member order below must not change.
struct ggml_backend_rpc_device_context {
// endpoint string of the server this device belongs to
std::string endpoint;
// index of the device on that server (0 .. device_count-1)
uint32_t device;
// device name; ggml_backend_rpc_add_server sets it to "RPC<n>" using a process-wide counter
std::string name;
// device description; ggml_backend_rpc_add_server sets it to the endpoint string
std::string description;
};
// Returns the name stored in the device's context; the pointer refers to that stored string.
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return dev_ctx->name.c_str();
}
// Returns the description stored in the device's context; the pointer refers to that stored string.
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return dev_ctx->description.c_str();
}
// Queries free/total memory for this device by forwarding the context's endpoint and
// device index to ggml_backend_rpc_get_device_memory.
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    const char * server_endpoint = dev_ctx->endpoint.c_str();
    ggml_backend_rpc_get_device_memory(server_endpoint, dev_ctx->device, free, total);
}
// Reports every RPC device as a GPU device; `dev` is not consulted.
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);
    // TODO: obtain value from the server
    return GGML_BACKEND_DEVICE_TYPE_GPU;
}
// Fills *props for the device using the individual getters of this interface.
// All capability flags are reported as false.
static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    struct ggml_backend_dev_props & out = *props;

    out.name        = ggml_backend_rpc_device_get_name(dev);
    out.description = ggml_backend_rpc_device_get_description(dev);
    out.type        = ggml_backend_rpc_device_get_type(dev);
    ggml_backend_rpc_device_get_memory(dev, &out.memory_free, &out.memory_total);
    out.caps = {
        /* .async                 = */ false,
        /* .host_buffer           = */ false,
        /* .buffer_from_host_ptr  = */ false,
        /* .events                = */ false,
    };
}
// Creates a backend for this device via ggml_backend_rpc_init using the endpoint and
// device index from the device context; `params` is ignored.
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
    GGML_UNUSED(params);
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    return ggml_backend_rpc_init(dev_ctx->endpoint.c_str(), dev_ctx->device);
}
// Returns the buffer type that ggml_backend_rpc_buffer_type yields for this device's
// endpoint and device index.
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    const char * server_endpoint = dev_ctx->endpoint.c_str();
    return ggml_backend_rpc_buffer_type(server_endpoint, dev_ctx->device);
}
// Currently claims support for every operation; neither argument is inspected.
static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    //TODO: call the remote backend and cache the results
    GGML_UNUSED(op);
    GGML_UNUSED(dev);
    return true;
}
// Returns true only when `buft` is an RPC buffer type (recognized by its get_name callback
// being ggml_backend_rpc_buffer_type_name) that refers to the same endpoint and device
// index as `dev`. A null `buft` yields false.
static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    const bool is_rpc_buft = buft && buft->iface.get_name == ggml_backend_rpc_buffer_type_name;
    if (!is_rpc_buft) {
        return false;
    }
    const auto * buft_ctx = static_cast<const ggml_backend_rpc_buffer_type_context *>(buft->context);
    const auto * dev_ctx  = static_cast<const ggml_backend_rpc_device_context *>(dev->context);
    const bool same_endpoint = buft_ctx->endpoint == dev_ctx->endpoint;
    return same_endpoint && buft_ctx->device == dev_ctx->device;
}
// Device interface table used for every RPC device created in ggml_backend_rpc_add_server.
// The initializer is positional: entries must stay in the order of the members of
// ggml_backend_device_i (named in the inline comments). Entries set to NULL are callbacks
// this backend does not provide.
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
/* .get_name = */ ggml_backend_rpc_device_get_name,
/* .get_description = */ ggml_backend_rpc_device_get_description,
/* .get_memory = */ ggml_backend_rpc_device_get_memory,
/* .get_type = */ ggml_backend_rpc_device_get_type,
/* .get_props = */ ggml_backend_rpc_device_get_props,
/* .init_backend = */ ggml_backend_rpc_device_init,
/* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
/* .get_host_buffer_type = */ NULL,
/* .buffer_from_host_ptr = */ NULL,
/* .supports_op = */ ggml_backend_rpc_device_supports_op,
/* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
/* .offload_op = */ NULL,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_synchronize = */ NULL,
};
// backend reg interface
// State attached to a per-endpoint registry created by ggml_backend_rpc_add_server.
// The registry returned by ggml_backend_rpc_reg() has a null context instead.
struct ggml_backend_rpc_reg_context {
// registry name; ggml_backend_rpc_add_server sets it to "RPC[<endpoint>]"
std::string name;
// devices created for the endpoint, indexed by ggml_backend_rpc_reg_get_device
std::vector<ggml_backend_dev_t> devices;
};
// Returns the name stored in the registry context, or "RPC" when the registry has no context.
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
    const auto * reg_ctx = static_cast<const ggml_backend_rpc_reg_context *>(reg->context);
    if (reg_ctx == nullptr) {
        return "RPC";
    }
    return reg_ctx->name.c_str();
}
// Returns the number of devices held by the registry context, or 0 when there is no context.
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
    const auto * reg_ctx = static_cast<const ggml_backend_rpc_reg_context *>(reg->context);
    if (reg_ctx == nullptr) {
        return 0;
    }
    return reg_ctx->devices.size();
}
// Returns the device at `index` from the registry context.
// Aborts when the registry has no context; asserts that `index` is within range.
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    auto * ctx = static_cast<ggml_backend_rpc_reg_context *>(reg->context);
    if (ctx != nullptr) {
        GGML_ASSERT(index < ctx->devices.size());
        return ctx->devices[index];
    } else {
        GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
    }
}
// Looks up one of the functions this backend exports by name.
// Returns the function's address for "ggml_backend_rpc_add_server" and
// "ggml_backend_rpc_start_server", and a null pointer for any other name. `reg` is unused.
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    GGML_UNUSED(reg);
    const struct {
        const char * symbol;
        void *       address;
    } exported[] = {
        { "ggml_backend_rpc_add_server",   (void *)ggml_backend_rpc_add_server   },
        { "ggml_backend_rpc_start_server", (void *)ggml_backend_rpc_start_server },
    };
    for (const auto & entry : exported) {
        if (std::strcmp(name, entry.symbol) == 0) {
            return entry.address;
        }
    }
    return nullptr;
}
// Registry interface table used by the registry object returned from ggml_backend_rpc_reg().
// The initializer is positional; keep the entries in the order named by the inline comments.
static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
/* .get_name = */ ggml_backend_rpc_reg_get_name,
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
/* .get_device = */ ggml_backend_rpc_reg_get_device,
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
};
// Returns the address of a function-local static registry object for the RPC backend.
// It has no context, so it reports the name "RPC" and zero devices through the
// registry callbacks above.
ggml_backend_reg_t ggml_backend_rpc_reg(void) {
    static struct ggml_backend_reg rpc_reg_instance = {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ ggml_backend_rpc_reg_i,
        /* .context     = */ nullptr,
    };
    ggml_backend_reg_t result = &rpc_reg_instance;
    return result;
}
// Sends RPC_CMD_DEVICE_COUNT to the server at `endpoint` and returns the device count
// from its response. Returns 0 (after logging an error) when no socket could be obtained.
static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
    auto sock = get_socket(endpoint);
    if (sock != nullptr) {
        rpc_msg_device_count_rsp response;
        const bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
        RPC_STATUS_ASSERT(status);
        return response.device_count;
    }
    GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
    return 0;
}
// Registry interface table used for the per-endpoint registries created by
// ggml_backend_rpc_add_server. Its entries are the same as those of ggml_backend_rpc_reg_i
// defined earlier in this file. The initializer is positional; keep the entry order.
static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
/* .get_name = */ ggml_backend_rpc_reg_get_name,
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
/* .get_device = */ ggml_backend_rpc_reg_get_device,
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
};
// Returns the registry for the RPC server at `endpoint`, creating it on first use.
//
// Registries are remembered per endpoint string in a function-local map guarded by a
// function-local mutex, so a repeated call with the same endpoint returns the same object.
// On first use the server is asked for its device count; a count of 0 yields nullptr and
// nothing is remembered. Otherwise one device object is created per server device, named
// "RPC<n>" with a counter that keeps increasing across endpoints.
ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
    static std::unordered_map<std::string, ggml_backend_reg_t> known_servers;
    static std::mutex known_servers_mutex;
    static uint32_t next_dev_id = 0;

    std::lock_guard<std::mutex> guard(known_servers_mutex);

    const std::string endpoint_str(endpoint);
    const auto known = known_servers.find(endpoint_str);
    if (known != known_servers.end()) {
        return known->second;
    }

    const uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
    if (dev_count == 0) {
        return nullptr;
    }

    auto * reg_ctx = new ggml_backend_rpc_reg_context;
    reg_ctx->name = "RPC[" + endpoint_str + "]";
    for (uint32_t ind = 0; ind < dev_count; ++ind, ++next_dev_id) {
        auto * dev_ctx = new ggml_backend_rpc_device_context {
            /* .endpoint    = */ endpoint_str,
            /* .device      = */ ind,
            /* .name        = */ "RPC" + std::to_string(next_dev_id),
            /* .description = */ endpoint_str
        };
        reg_ctx->devices.push_back(new ggml_backend_device {
            /* .iface   = */ ggml_backend_rpc_device_i,
            /* .reg     = */ ggml_backend_rpc_reg(),
            /* .context = */ dev_ctx,
        });
    }

    ggml_backend_reg_t new_reg = new ggml_backend_reg {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ ggml_backend_rpc_reg_interface,
        /* .context     = */ reg_ctx
    };
    known_servers[endpoint_str] = new_reg;
    return new_reg;
}
// NOTE(review): presumably publishes ggml_backend_rpc_reg as this backend's entry point when
// it is built as a dynamically loadable module -- confirm against the macro's definition.
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)