diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 1c11495b66..0b60d39f39 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -7,7 +7,7 @@ extern "C" { #endif #define RPC_PROTO_MAJOR_VERSION 3 -#define RPC_PROTO_MINOR_VERSION 6 +#define RPC_PROTO_MINOR_VERSION 7 #define RPC_PROTO_PATCH_VERSION 1 #ifdef __cplusplus diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt index f5acb8ec2c..f03af6888a 100644 --- a/ggml/src/ggml-rpc/CMakeLists.txt +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -7,3 +7,11 @@ ggml_add_backend_library(ggml-rpc if (WIN32) target_link_libraries(ggml-rpc PRIVATE ws2_32) endif() + +option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC (requires libibverbs)" OFF) +if (GGML_RPC_RDMA) + find_library(IBVERBS_LIB ibverbs REQUIRED) + target_compile_definitions(ggml-rpc PRIVATE GGML_RPC_RDMA) + target_link_libraries(ggml-rpc PRIVATE ${IBVERBS_LIB}) + message(STATUS " RDMA transport enabled") +endif() diff --git a/ggml/src/ggml-rpc/feature-readme.md b/ggml/src/ggml-rpc/feature-readme.md new file mode 100644 index 0000000000..cb92ffbd1f --- /dev/null +++ b/ggml/src/ggml-rpc/feature-readme.md @@ -0,0 +1,125 @@ +# Native RDMA Transport for llama.cpp RPC + +## Overview + +This patch adds native RDMA (Remote Direct Memory Access) transport to llama.cpp's RPC backend, enabling two-node GPU inference clusters to communicate over RoCEv2 (RDMA over Converged Ethernet v2) instead of TCP. RDMA bypasses the kernel network stack entirely, delivering lower latency and higher throughput for the frequent small messages that dominate token generation. + +The transport auto-negotiates during the existing RPC HELLO handshake. No special URI scheme or command-line flags are needed -- if both client and server have RDMA-capable NICs, the connection upgrades transparently. If either side lacks RDMA, it falls back to TCP silently. 
+ +## Performance + +Tested on a two-node cluster with AMD Radeon 8060S (gfx1151) iGPUs connected via Mellanox ConnectX 25GbE NICs running RoCEv2. Model: Qwen3-Coder-Next 80B Q8_K_XL. + +| Metric | TCP | RDMA | Improvement | +| ------ | --: | ---: | ----------: | +| Prompt processing (pp2048) | 651.48 t/s | 678.42 t/s | **+4.1%** | +| Token generation (tg256) | 30.19 t/s | 32.16 t/s | **+6.5%** | + +The token generation improvement is particularly significant because tg is latency-sensitive -- each token requires a full round-trip between the two nodes. RDMA's 1.5μs latency (vs TCP's 31μs) directly reduces per-token overhead. + +## Architecture + +### Connection lifecycle + +``` +Client Server + | | + |---- TCP connect (host:port) ------>| + |<--- TCP accept --------------------| + | | + | [rdma_probe: find RDMA device, | [waiting for HELLO] + | create QP, get local QPN/GID] | + | | + |---- RPC_CMD_HELLO + RDMA req ----->| + | (QPN, PSN, GID or zeros) | + | | [rdma_probe: find RDMA device, + | | create QP, get local QPN/GID] + |<--- HELLO rsp + RDMA rsp ---------| + | (version + QPN, PSN, GID) | + | | + | [rdma_activate: INIT→RTR→RTS] | [rdma_activate: INIT→RTR→RTS] + | [swap fn_send/fn_recv to RDMA] | [swap fn_send/fn_recv to RDMA] + | | + |==== All subsequent data via RDMA ==| + | (TCP socket stays open but | + | idle for lifetime mgmt) | +``` + +### Key design decisions + +**1. HELLO-embedded negotiation** + +RDMA parameters (QPN, PSN, GID) are exchanged inside the standard RPC HELLO handshake rather than using a separate pre-HELLO TCP exchange. This means: +- No new wire protocol messages +- The server can distinguish extended HELLO (`input_size == 24`) from legacy (`input_size == 0`) +- Legacy clients/servers work unchanged -- they simply don't send/receive RDMA fields + +**2. 
Function-pointer transport dispatch** + +```cpp +struct socket_t { + sockfd_t fd; + bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl; + bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl; + rdma_conn * rdma = nullptr; // only when compiled with GGML_RPC_RDMA +}; +``` + +Transport is selected once at connection time by swapping function pointers. All call sites use `sock->fn_send(sock, data, size)` -- zero `#ifdef` guards or `if (sock->rdma)` checks on the hot path. + +**3. Two-phase RDMA setup** + +- `rdma_probe()` -- Before HELLO: opens RDMA device, creates QP (stays in RESET state), allocates and registers buffers. Returns local QPN/GID for the HELLO exchange. +- `rdma_activate()` -- After HELLO: given remote QPN/GID, transitions QP through INIT→RTR→RTS and pre-posts the receive ring. + +This split is necessary because both sides need the other's QPN before they can complete the QP state machine. + +**4. Auto-detect with override** + +`rdma_probe()` uses `getsockname()` on the TCP socket to find the local IP, then scans all RDMA devices' GID tables for a matching IPv4-mapped entry. This provides zero-config operation on most setups. The `GGML_RDMA_DEV` and `GGML_RDMA_GID` environment variables override auto-detection when the network topology is complex (e.g., Linux bridges where the IP may not appear in the expected GID slot). + +### RDMA transport internals + +The data transport layer uses several optimizations developed through iterative benchmarking: + +| Optimization | What it does | Why it matters | +| ------------ | ------------ | -------------- | +| Pre-posted receive ring | 24 receive buffers (256 KiB each) are posted before any data flows | Eliminates RNR (Receiver Not Ready) retries. Without pre-posted buffers, the sender fires faster than the receiver can post, causing 640μs RNR retry delays per event. 
| +| Separate send/recv CQs | Send completions go to `scq`, receive completions go to `rcq` | Simplifies polling -- `rdma_send` only polls `scq`, `rdma_recv` only polls `rcq`. No need to filter completion types. | +| Inline sends | Messages ≤316 bytes are sent inline (no DMA from registered memory) | RPC command headers and small responses bypass the memcpy-to-registered-buffer step. Most token generation messages are <100 bytes. | +| min_rnr_timer=1 | Sets the RNR retry delay to 0.01ms (the minimum) | Even if an RNR does occur, the retry happens in 10μs instead of the default 640μs. This was the single largest tg improvement. | +| 256 KiB chunk size | Data is sent/received in 256 KiB chunks | Fits within the default Linux locked memory limit (8 MiB). The 24-slot × 256 KiB = 6 MiB receive ring stays under `ulimit -l`. | +| Time-based CQ poll timeout | Uses `clock_gettime(CLOCK_MONOTONIC_COARSE)` with 30s timeout | Replaces spin-loop iteration counting which was inaccurate on fast hardware (Mellanox `ibv_poll_cq` returns in ~10ns). | + +### Backwards compatibility + +| Scenario | Behavior | +| -------- | -------- | +| New client → old server | Client sends extended HELLO (24 bytes input). Old server treats extra bytes as unknown input, responds with standard 3-byte HELLO. Client sees no RDMA fields, stays on TCP. | +| Old client → new server | Client sends standard HELLO (0 bytes input). Server detects `input_size == 0`, responds with standard HELLO. No RDMA negotiation attempted. | +| New client → new server, no RDMA hardware | `rdma_probe()` returns nullptr. Client sends standard HELLO. Normal TCP operation. | +| New client → new server, RDMA available | Full auto-negotiation. RDMA transport activated after HELLO. 
| + +## Files changed + +| File | Changes | +| ---- | ------- | +| `ggml/include/ggml-rpc.h` | Bumped `RPC_PROTO_MINOR_VERSION` from 6 to 7 | +| `ggml/src/ggml-rpc/CMakeLists.txt` | Added `GGML_RPC_RDMA` option with `libibverbs` linking | +| `ggml/src/ggml-rpc/ggml-rpc.cpp` | All transport logic (~550 lines added) | + +## Building + +```bash +cmake -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_HIP=ON \ + -DGGML_RPC_RDMA=ON \ + -DAMDGPU_TARGETS=gfx1151 + +cmake --build build --target rpc-server llama-bench -j$(nproc) +``` + +Requires `libibverbs-dev` (Ubuntu: `apt install libibverbs-dev`). + +Without `-DGGML_RPC_RDMA=ON`, the build produces a standard TCP-only binary with no RDMA code compiled in. diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index d7c8ad8c16..6feb37379d 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -31,6 +31,11 @@ #include #include +#ifdef GGML_RPC_RDMA +# include +# include +#endif + static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); #define LOG_DBG(...) 
\ @@ -41,6 +46,11 @@ namespace fs = std::filesystem; static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +#endif + #ifdef _WIN32 typedef SOCKET sockfd_t; using ssize_t = __int64; @@ -49,15 +59,74 @@ typedef int sockfd_t; #endif // cross-platform socket + +#ifdef GGML_RPC_RDMA +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + void * rx_slot(int i) { return (char *)rx_buf + (size_t)i * RDMA_CHUNK; } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; +#endif + +// Forward declarations for transport function pointers +struct socket_t; +static bool tcp_send_impl(socket_t * sock, const void * data, size_t size); +static bool tcp_recv_impl(socket_t * sock, void * data, size_t size); + struct socket_t { sockfd_t fd; + bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl; + bool (*fn_recv)(socket_t *, void 
*, size_t) = tcp_recv_impl; +#ifdef GGML_RPC_RDMA + rdma_conn * rdma = nullptr; +#endif socket_t(sockfd_t fd) : fd(fd) {} ~socket_t() { +#ifdef GGML_RPC_RDMA + if (rdma) { delete rdma; rdma = nullptr; } +#endif LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); #ifdef _WIN32 - closesocket(this->fd); + if (fd != INVALID_SOCKET) closesocket(this->fd); #else - close(this->fd); + if (fd >= 0) close(this->fd); #endif } }; @@ -121,6 +190,21 @@ struct rpc_msg_hello_rsp { uint8_t patch; }; +struct rpc_msg_hello_rdma_req { + uint32_t rdma_qpn; // 0 = no RDMA capability + uint32_t rdma_psn; + uint8_t rdma_gid[16]; +}; + +struct rpc_msg_hello_rdma_rsp { + uint8_t major; + uint8_t minor; + uint8_t patch; + uint32_t rdma_qpn; // 0 = no RDMA capability + uint32_t rdma_psn; + uint8_t rdma_gid[16]; +}; + struct rpc_msg_device_count_rsp { uint32_t device_count; }; @@ -414,27 +498,316 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) { return true; } -static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) { - if (!send_data(sockfd, &msg_size, sizeof(msg_size))) { - return false; - } - return send_data(sockfd, msg, msg_size); +// ── TCP transport implementations (for function-pointer dispatch) ──────────── + +static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) { + return send_data(sock->fd, data, size); } -static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) { +static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) { + return recv_data(sock->fd, data, size); +} + +// ── RDMA transport (performance-optimized, auto-negotiated) ───────────────── + +#ifdef GGML_RPC_RDMA + +static bool rdma_send_impl(socket_t * sock, const void * data, size_t size); +static bool rdma_recv_impl(socket_t * sock, void * data, size_t size); + +static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) { + struct timespec t0; + clock_gettime(CLOCK_MONOTONIC_COARSE, &t0); + for (uint64_t s = 0; ; s++) { + int 
n = ibv_poll_cq(cq, 1, wc); + if (n > 0) return wc->status == IBV_WC_SUCCESS; + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + struct timespec now; + clock_gettime(CLOCK_MONOTONIC_COARSE, &now); + if (now.tv_sec - t0.tv_sec >= 30) { + GGML_LOG_ERROR("RDMA CQ poll timeout\n"); + return false; + } + } + } +} + +static bool rdma_send(rdma_conn * c, const void * data, size_t size) { + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + +static bool rdma_recv(rdma_conn * c, void * data, size_t size) { + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) { + return rdma_send(sock->rdma, data, size); +} + +static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) { + return rdma_recv(sock->rdma, data, size); +} + +// Phase 1: Probe for RDMA device, create QP (in RESET state), return local info. +// Returns rdma_conn with QP created but NOT connected. 
Caller gets local QPN/PSN/GID +// to send to the remote side via the HELLO exchange. +// If RDMA is not available, returns nullptr (caller stays on TCP). +struct rdma_local_info { + uint32_t qpn; + uint32_t psn; + uint8_t gid[16]; + uint8_t ib_port; + int gid_idx; + enum ibv_mtu path_mtu; +}; + +static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + struct sockaddr_in local_addr; + socklen_t addr_len = sizeof(local_addr); + if (getsockname(tcp_fd, (struct sockaddr *)&local_addr, &addr_len) != 0) { + return nullptr; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + struct ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return nullptr; + + struct ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? atoi(gid_env) : -1; + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + struct ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + struct ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + if (found_gid < 0) { + for (int i = 0; i < pa.gid_tbl_len; i++) { + union ibv_gid g; + if (ibv_query_gid(ctx, ib_port, i, &g) != 0) continue; + if (g.raw[10] != 0xff || g.raw[11] != 0xff) continue; + uint32_t ip; + memcpy(&ip, &g.raw[12], 4); + if (dev_env) { + if (ip != 0) { found_gid = i; break; } + } else { + if (ip == local_addr.sin_addr.s_addr) { found_gid = i; break; } + } + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + matched_dev = dn; + out->path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return nullptr; + + out->ib_port = ib_port; + out->gid_idx = gid_idx; + + auto * c = new rdma_conn(); 
+ c->ctx = ibctx; + + c->pd = ibv_alloc_pd(ibctx); + if (!c->pd) { delete c; return nullptr; } + + c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!c->scq || !c->rcq) { delete c; return nullptr; } + + struct ibv_qp_init_attr qia = {}; + qia.send_cq = c->scq; + qia.recv_cq = c->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + c->qp = ibv_create_qp(c->pd, &qia); + if (!c->qp) { delete c; return nullptr; } + c->max_inline = qia.cap.max_inline_data; + + c->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + c->rx_buf = aligned_alloc(4096, (size_t)RDMA_RX_DEPTH * RDMA_CHUNK); + if (!c->tx_buf || !c->rx_buf) { delete c; return nullptr; } + + c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, (size_t)RDMA_RX_DEPTH * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!c->tx_mr || !c->rx_mr) { delete c; return nullptr; } + + union ibv_gid local_gid; + ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid); + + out->qpn = c->qp->qp_num; + out->psn = c->qp->qp_num & 0xffffff; + memcpy(out->gid, &local_gid, 16); + + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d qpn=%u inline=%u\n", + matched_dev, gid_idx, out->qpn, c->max_inline); + return c; +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET→INIT→pre-post→RTR→RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. 
+static bool rdma_activate(rdma_conn * c, const rdma_local_info * local, + uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET → INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = local->ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(c->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!c->post_rx(i)) return false; + } + + // INIT → RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = local->path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, 16); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = local->gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = local->ib_port; + if (ibv_modify_qp(c->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR → RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = local->psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(c->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u→%u mtu=%d rx_depth=%d\n", + local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH); + return true; +} + +#endif // GGML_RPC_RDMA + +// ── Unified transport dispatch (via function pointers) ────────────────────── + +static bool send_data(socket_t * sock, const void * data, size_t size) { + return sock->fn_send(sock, data, size); +} + +static bool recv_data(socket_t * 
sock, void * data, size_t size) {
+    return sock->fn_recv(sock, data, size);
+}
+
+static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) {
+    if (!send_data(sock, &msg_size, sizeof(msg_size))) {
+        return false;
+    }
+    return send_data(sock, msg, msg_size);
+}
+
+static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) {
     uint64_t size;
-    if (!recv_data(sockfd, &size, sizeof(size))) {
+    if (!recv_data(sock, &size, sizeof(size))) {
         return false;
     }
     if (size != msg_size) {
         return false;
     }
-    return recv_data(sockfd, msg, msg_size);
+    return recv_data(sock, msg, msg_size);
 }

-static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
+static bool recv_msg(socket_t * sock, std::vector<uint8_t> & input) {
     uint64_t size;
-    if (!recv_data(sockfd, &size, sizeof(size))) {
+    if (!recv_data(sock, &size, sizeof(size))) {
         return false;
     }
     try {
@@ -443,7 +816,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
         GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
         return false;
     }
-    return recv_data(sockfd, input.data(), size);
+    return recv_data(sock, input.data(), size);
 }

 static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
@@ -452,7 +825,11 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
         return false;
     }
     host = endpoint.substr(0, pos);
-    port = std::stoi(endpoint.substr(pos + 1));
+    try {
+        port = std::stoi(endpoint.substr(pos + 1));
+    } catch (...)
{
+        return false;
+    }
     return true;
 }

@@ -460,13 +837,13 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
 // No response
 static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
     uint8_t cmd_byte = cmd;
-    if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
+    if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) {
         return false;
     }
-    if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
+    if (!send_data(sock.get(), &input_size, sizeof(input_size))) {
         return false;
     }
-    if (!send_data(sock->fd, input, input_size)) {
+    if (!send_data(sock.get(), input, input_size)) {
         return false;
     }
     return true;
@@ -478,16 +855,14 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
     if (!send_rpc_cmd(sock, cmd, input, input_size)) {
         return false;
     }
-    // TODO: currently the output_size is always known, do we need support for commands with variable output size?
-    // even if we do, we can skip sending output_size from the server for commands with known output size
     uint64_t out_size;
-    if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
+    if (!recv_data(sock.get(), &out_size, sizeof(out_size))) {
         return false;
     }
     if (out_size != output_size) {
         return false;
     }
-    if (!recv_data(sock->fd, output, output_size)) {
+    if (!recv_data(sock.get(), output, output_size)) {
         return false;
     }
     return true;
@@ -495,7 +870,67 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm

 // RPC client-side implementation

-static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
+// Performs HELLO handshake with optional RDMA auto-negotiation.
+// If both sides have RDMA, the socket is upgraded transparently.
+static bool negotiate_hello(const std::shared_ptr & sock) { +#ifdef GGML_RPC_RDMA + rdma_local_info local_info = {}; + rdma_conn * probe = rdma_probe(sock->fd, &local_info); + + if (probe) { + rpc_msg_hello_rdma_req req = {}; + req.rdma_qpn = local_info.qpn; + req.rdma_psn = local_info.psn; + memcpy(req.rdma_gid, local_info.gid, 16); + + // Send extended HELLO: cmd + input_size + input_data + if (!send_rpc_cmd(sock, RPC_CMD_HELLO, &req, sizeof(req))) { + delete probe; + return false; + } + + // Read response size -- server may respond with legacy or extended size + uint64_t out_size = 0; + if (!recv_data(sock.get(), &out_size, sizeof(out_size))) { + delete probe; + return false; + } + + if (out_size == sizeof(rpc_msg_hello_rdma_rsp)) { + rpc_msg_hello_rdma_rsp rsp = {}; + if (!recv_data(sock.get(), &rsp, sizeof(rsp))) { delete probe; return false; } + if (rsp.major != RPC_PROTO_MAJOR_VERSION || rsp.minor > RPC_PROTO_MINOR_VERSION) { + GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", rsp.major, rsp.minor, rsp.patch); + delete probe; + return false; + } + if (rsp.rdma_qpn != 0) { + if (rdma_activate(probe, &local_info, rsp.rdma_qpn, rsp.rdma_psn, rsp.rdma_gid)) { + sock->rdma = probe; + sock->fn_send = rdma_send_impl; + sock->fn_recv = rdma_recv_impl; + return true; + } + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + } + delete probe; + return true; + } else if (out_size == sizeof(rpc_msg_hello_rsp)) { + // Legacy server responded with standard HELLO (ignored our RDMA req) + rpc_msg_hello_rsp rsp = {}; + if (!recv_data(sock.get(), &rsp, sizeof(rsp))) { delete probe; return false; } + delete probe; + if (rsp.major != RPC_PROTO_MAJOR_VERSION || rsp.minor > RPC_PROTO_MINOR_VERSION) { + GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", rsp.major, rsp.minor, rsp.patch); + return false; + } + return true; + } else { + delete probe; + return false; + } + } +#endif rpc_msg_hello_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, 
nullptr, 0, &response, sizeof(response)); RPC_STATUS_ASSERT(status); @@ -527,6 +962,7 @@ static std::shared_ptr get_socket(const std::string & endpoint) { GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str()); return nullptr; } + #ifdef _WIN32 if (!initialized) { WSADATA wsaData; @@ -543,10 +979,10 @@ static std::shared_ptr get_socket(const std::string & endpoint) { if (sock == nullptr) { return nullptr; } - if (!check_server_version(sock)) { + if (!negotiate_hello(sock)) { return nullptr; } - LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); + LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str()); sockets[endpoint] = sock; return sock; } @@ -1579,24 +2015,94 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector & backends, const char * cache_dir, - sockfd_t sockfd) { + socket_t * sockfd) { rpc_server server(backends, cache_dir); uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { return; } - // the first command sent by the client must be HELLO if (cmd != RPC_CMD_HELLO) { GGML_LOG_ERROR("Expected HELLO command, update client\n"); return; } - if (!recv_msg(sockfd, nullptr, 0)) { + + // Read input_size to determine legacy vs extended HELLO + uint64_t hello_input_size; + if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) { return; } - rpc_msg_hello_rsp response; - server.hello(response); - if (!send_msg(sockfd, &response, sizeof(response))) { - return; + + if (hello_input_size == sizeof(rpc_msg_hello_rdma_req)) { + // Extended HELLO with RDMA fields + rpc_msg_hello_rdma_req req = {}; + if (!recv_data(sockfd, &req, sizeof(req))) { + return; + } + +#ifdef GGML_RPC_RDMA + rdma_local_info local_info = {}; + rdma_conn * probe = rdma_probe(sockfd->fd, &local_info); + + rpc_msg_hello_rdma_rsp rsp = {}; + rpc_msg_hello_rsp base_rsp; + server.hello(base_rsp); + rsp.major = base_rsp.major; + rsp.minor = base_rsp.minor; + rsp.patch = base_rsp.patch; + + if (probe && req.rdma_qpn != 0) 
{ + rsp.rdma_qpn = local_info.qpn; + rsp.rdma_psn = local_info.psn; + memcpy(rsp.rdma_gid, local_info.gid, 16); + } + + if (!send_msg(sockfd, &rsp, sizeof(rsp))) { + delete probe; + return; + } + + if (probe && req.rdma_qpn != 0 && rsp.rdma_qpn != 0) { + if (rdma_activate(probe, &local_info, req.rdma_qpn, req.rdma_psn, req.rdma_gid)) { + sockfd->rdma = probe; + sockfd->fn_send = rdma_send_impl; + sockfd->fn_recv = rdma_recv_impl; + } else { + GGML_LOG_ERROR("RDMA activate failed on server, staying on TCP\n"); + delete probe; + } + } else { + delete probe; + } +#else + // Not compiled with RDMA -- respond with zeros for RDMA fields + rpc_msg_hello_rdma_rsp rsp = {}; + rpc_msg_hello_rsp base_rsp; + server.hello(base_rsp); + rsp.major = base_rsp.major; + rsp.minor = base_rsp.minor; + rsp.patch = base_rsp.patch; + if (!send_msg(sockfd, &rsp, sizeof(rsp))) { + return; + } +#endif + } else if (hello_input_size == 0) { + // Legacy HELLO (no RDMA) + rpc_msg_hello_rsp response; + server.hello(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + } else { + // Unknown HELLO size -- consume and respond with legacy + std::vector discard(hello_input_size); + if (!recv_data(sockfd, discard.data(), hello_input_size)) { + return; + } + rpc_msg_hello_rsp response; + server.hello(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } } while (true) { if (!recv_data(sockfd, &cmd, 1)) { @@ -1866,6 +2372,12 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir if (!parse_endpoint(endpoint, host, port)) { return; } + +#ifdef GGML_RPC_RDMA + printf(" transport : TCP (RDMA auto-negotiate enabled)\n"); +#else + printf(" transport : TCP\n"); +#endif #ifdef _WIN32 { WSADATA wsaData; @@ -1889,7 +2401,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket->fd); + 
rpc_serve_client(backends, cache_dir, client_socket.get()); printf("Client connection closed\n"); fflush(stdout); }