This commit is contained in:
Chris Rohlf 2026-02-01 08:59:31 -05:00 committed by GitHub
commit 5d525178bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 238 additions and 21 deletions

View File

@ -2183,7 +2183,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
"comma separated list of RPC servers (host:port)",
"comma separated list of RPC servers (host:port or path ending in .sock)",
[](common_params & params, const std::string & value) {
add_rpc_devices(value);
GGML_UNUSED(params);

View File

@ -25,7 +25,10 @@
# include <netinet/tcp.h>
# include <netdb.h>
# include <unistd.h>
# include <sys/un.h>
# include <sys/stat.h>
#endif
#include <cctype>
#include <cstring>
#include <fstream>
#include <filesystem>
@ -314,7 +317,7 @@ static bool set_reuse_addr(sockfd_t sockfd) {
return ret == 0;
}
static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
static std::shared_ptr<socket_t> socket_connect_tcp(const char * host, int port) {
struct sockaddr_in addr;
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
auto sock_ptr = make_socket(sockfd);
@ -334,25 +337,75 @@ static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
}
memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
GGML_LOG_ERROR("Failed to connect to host '%s'\n", host);
return nullptr;
}
return sock_ptr;
}
static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
#ifndef _WIN32
static bool unlink_old_socket_path(const char * path) {
struct stat st;
if (lstat(path, &st) != 0) {
if (errno == ENOENT) {
return true;
}
GGML_LOG_ERROR("lstat('%s') failed\n", path);
return false;
}
if (!S_ISSOCK(st.st_mode)) {
GGML_LOG_ERROR("Refusing to unlink '%s': exists but is not a unix socket\n", path);
return false;
}
if (unlink(path) != 0) {
GGML_LOG_ERROR("unlink('%s') failed\n", path);
return false;
}
return true;
}
static std::shared_ptr<socket_t> socket_connect_unix(const char * path) {
auto sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
auto sock_ptr = make_socket(sockfd);
if (sock_ptr == nullptr) {
return nullptr;
}
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
if (strlen(path) >= sizeof(addr.sun_path)) {
GGML_LOG_ERROR("Unix socket path too long: %s, max is %d\n", path, (int)sizeof(addr.sun_path));
return nullptr;
}
strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1);
if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
GGML_LOG_ERROR("Failed to create socket at '%s'\n", path);
return nullptr;
}
return sock_ptr;
}
#endif
static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd, bool is_tcp) {
auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
auto client_socket = make_socket(client_socket_fd);
if (client_socket == nullptr) {
return nullptr;
}
if (!set_no_delay(client_socket_fd)) {
if (is_tcp && !set_no_delay(client_socket_fd)) {
GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
return nullptr;
}
return client_socket;
}
static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
static std::shared_ptr<socket_t> create_server_socket_tcp(const char * host, int port) {
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
auto sock = make_socket(sockfd);
if (sock == nullptr) {
@ -380,6 +433,44 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
return sock;
}
#ifndef _WIN32
static std::shared_ptr<socket_t> create_server_socket_unix(const char * path) {
auto sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
auto sock = make_socket(sockfd);
if (sock == nullptr) {
return nullptr;
}
if (!unlink_old_socket_path(path)) {
return nullptr;
}
struct sockaddr_un addr;
memset(&addr, 0, sizeof(addr));
addr.sun_family = AF_UNIX;
if (strlen(path) >= sizeof(addr.sun_path)) {
GGML_LOG_ERROR("Unix socket path too long: %s, max is %d\n", path, (int)sizeof(addr.sun_path));
return nullptr;
}
strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1);
if (bind(sockfd, (struct sockaddr *) &addr, sizeof(addr)) < 0) {
return nullptr;
}
if (chmod(path, 0770) != 0) {
GGML_LOG_ERROR("chmod('%s') failed\n", path);
}
if (listen(sockfd, 1) < 0) {
return nullptr;
}
return sock;
}
#endif
static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
size_t bytes_sent = 0;
while (bytes_sent < size) {
@ -446,16 +537,103 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
return recv_data(sockfd, input.data(), size);
}
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
size_t pos = endpoint.find(':');
enum class endpoint_type {
TCP,
UNIX,
INVALID
};
struct endpoint_info {
endpoint_type type = endpoint_type::INVALID;
std::string host; // For TCP: hostname/IP, For Unix: socket path
int port = 0; // Only used for TCP
};
static bool parse_tcp_endpoint(const std::string & endpoint, endpoint_info & info) {
size_t pos = endpoint.rfind(':');
if (pos == std::string::npos) {
return false;
}
host = endpoint.substr(0, pos);
port = std::stoi(endpoint.substr(pos + 1));
std::string host = endpoint.substr(0, pos);
std::string port_str = endpoint.substr(pos + 1);
if (host.empty() || host.find('/') != std::string::npos) {
return false;
}
if (port_str.empty()) {
return false;
}
for (char c : port_str) {
if (!std::isdigit(static_cast<unsigned char>(c))) {
return false;
}
}
int port;
try {
port = std::stoi(port_str);
} catch (...) {
return false;
}
if (port <= 0 || port > 65535) {
return false;
}
info.type = endpoint_type::TCP;
info.host = host;
info.port = port;
return true;
}
static bool parse_unix_endpoint(const std::string & endpoint, endpoint_info & info) {
std::string path;
const std::string prefix = "unix://";
if (endpoint.rfind(prefix, 0) == 0) {
path = endpoint.substr(prefix.length());
} else {
path = endpoint;
}
if (path.empty()) {
return false;
}
info.type = endpoint_type::UNIX;
info.host = path;
info.port = 0;
return true;
}
static bool parse_endpoint(const std::string & endpoint, endpoint_info & info) {
if (endpoint.rfind("unix://", 0) == 0) {
return parse_unix_endpoint(endpoint, info);
}
if (endpoint.size() >= 5 && endpoint.rfind(".sock") == endpoint.size() - 5) {
#ifdef _WIN32
return false;
#else
return parse_unix_endpoint(endpoint, info);
#endif
}
if (parse_tcp_endpoint(endpoint, info)) {
return true;
}
#ifdef _WIN32
// On Windows, we don't support Unix sockets, so if TCP parsing failed, it's invalid
return false;
#else
if (endpoint.find('/') != std::string::npos) {
return parse_unix_endpoint(endpoint, info);
}
return false;
#endif
}
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// No response
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
@ -521,13 +699,16 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
return sock;
}
}
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
endpoint_info info;
if (!parse_endpoint(endpoint, info)) {
GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
return nullptr;
}
#ifdef _WIN32
if (info.type == endpoint_type::UNIX) {
GGML_LOG_ERROR("Unix socket endpoints are not supported on Windows\n");
return nullptr;
}
if (!initialized) {
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
@ -539,7 +720,17 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
#else
GGML_UNUSED(initialized);
#endif
auto sock = socket_connect(host.c_str(), port);
std::shared_ptr<socket_t> sock;
#ifndef _WIN32
if (info.type == endpoint_type::UNIX) {
sock = socket_connect_unix(info.host.c_str());
} else {
sock = socket_connect_tcp(info.host.c_str(), info.port);
}
#else
sock = socket_connect_tcp(info.host.c_str(), info.port);
#endif
if (sock == nullptr) {
return nullptr;
}
@ -1861,12 +2052,17 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
}
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
endpoint_info info;
if (!parse_endpoint(endpoint, info)) {
fprintf(stderr, "Failed to parse endpoint: %s\n", endpoint);
return;
}
#ifdef _WIN32
if (info.type == endpoint_type::UNIX) {
fprintf(stderr, "Unix socket endpoints are not supported on Windows\n");
return;
}
{
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
@ -1876,13 +2072,26 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
}
#endif
auto server_socket = create_server_socket(host.c_str(), port);
std::shared_ptr<socket_t> server_socket;
#ifndef _WIN32
if (info.type == endpoint_type::UNIX) {
server_socket = create_server_socket_unix(info.host.c_str());
} else {
server_socket = create_server_socket_tcp(info.host.c_str(), info.port);
}
#else
server_socket = create_server_socket_tcp(info.host.c_str(), info.port);
#endif
if (server_socket == nullptr) {
fprintf(stderr, "Failed to create server socket\n");
return;
}
const bool is_tcp = (info.type == endpoint_type::TCP);
while (true) {
auto client_socket = socket_accept(server_socket->fd);
auto client_socket = socket_accept(server_socket->fd, is_tcp);
if (client_socket == nullptr) {
fprintf(stderr, "Failed to accept client connection\n");
return;

View File

@ -145,7 +145,7 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads);
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -H, --host HOST host to bind to, or path to unix socket ending in .sock (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, "\n");
@ -258,7 +258,10 @@ int main(int argc, char * argv[]) {
return 1;
}
if (params.host != "127.0.0.1") {
bool is_unix_socket = params.host.size() >= 5 &&
params.host.rfind(".sock") == params.host.size() - 5;
if (!is_unix_socket && params.host != "127.0.0.1") {
fprintf(stderr, "\n");
fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str());
@ -273,7 +276,12 @@ int main(int argc, char * argv[]) {
fprintf(stderr, "No devices found\n");
return 1;
}
std::string endpoint = params.host + ":" + std::to_string(params.port);
std::string endpoint;
if (is_unix_socket) {
endpoint = params.host;
} else {
endpoint = params.host + ":" + std::to_string(params.port);
}
const char * cache_dir = nullptr;
std::string cache_dir_str;
if (params.use_cache) {