From 4ac643ee742bb7e7a83bfbf183af2fb5c243dd62 Mon Sep 17 00:00:00 2001 From: struct Date: Thu, 1 Jan 2026 09:22:48 -0500 Subject: [PATCH] first commit of unix socket support for ggml rpc-server --- common/arg.cpp | 2 +- ggml/src/ggml-rpc/ggml-rpc.cpp | 243 ++++++++++++++++++++++++++++++--- tools/rpc/rpc-server.cpp | 14 +- 3 files changed, 238 insertions(+), 21 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 62d31393c4..ad756b0df0 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2017,7 +2017,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex if (llama_supports_rpc()) { add_opt(common_arg( {"--rpc"}, "SERVERS", - "comma separated list of RPC servers (host:port)", + "comma separated list of RPC servers (host:port or path ending in .sock)", [](common_params & params, const std::string & value) { add_rpc_devices(value); GGML_UNUSED(params); diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 164b39d01e..42d4752a04 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -25,7 +25,10 @@ # include # include # include +# include +# include #endif +#include #include #include #include @@ -314,7 +317,7 @@ static bool set_reuse_addr(sockfd_t sockfd) { return ret == 0; } -static std::shared_ptr socket_connect(const char * host, int port) { +static std::shared_ptr socket_connect_tcp(const char * host, int port) { struct sockaddr_in addr; auto sockfd = socket(AF_INET, SOCK_STREAM, 0); auto sock_ptr = make_socket(sockfd); @@ -334,25 +337,75 @@ static std::shared_ptr socket_connect(const char * host, int port) { } memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + GGML_LOG_ERROR("Failed to connect to host '%s'\n", host); return nullptr; } return sock_ptr; } -static std::shared_ptr socket_accept(sockfd_t srv_sockfd) { +#ifndef _WIN32 +static bool unlink_old_socket_path(const char * path) { + struct stat st; + if (lstat(path, &st) != 0) { + if (errno == ENOENT) { + return true; + } + GGML_LOG_ERROR("lstat('%s') failed\n", path); + return false; + } + + if (!S_ISSOCK(st.st_mode)) { + GGML_LOG_ERROR("Refusing to unlink '%s': exists but is not a unix socket\n", path); + return false; + } + + if (unlink(path) != 0) { + GGML_LOG_ERROR("unlink('%s') failed\n", path); + return false; + } + + return true; +} + +static std::shared_ptr socket_connect_unix(const char * path) { + auto sockfd = socket(AF_UNIX, SOCK_STREAM, 0); + auto sock_ptr = make_socket(sockfd); + if (sock_ptr == nullptr) { + return nullptr; + } + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + + if (strlen(path) >= sizeof(addr.sun_path)) { + GGML_LOG_ERROR("Unix socket path too long: %s, max is %d\n", path, (int)sizeof(addr.sun_path)); + return nullptr; + } + strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1); + + if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + GGML_LOG_ERROR("Failed to create socket at '%s'\n", path); + return nullptr; + } + return sock_ptr; +} +#endif + +static std::shared_ptr socket_accept(sockfd_t srv_sockfd, bool is_tcp) { auto client_socket_fd = accept(srv_sockfd, NULL, NULL); auto client_socket = make_socket(client_socket_fd); if (client_socket == nullptr) { return nullptr; } - if (!set_no_delay(client_socket_fd)) { + if (is_tcp && !set_no_delay(client_socket_fd)) { GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); return nullptr; } return client_socket; } -static std::shared_ptr create_server_socket(const char * host, int port) { +static std::shared_ptr create_server_socket_tcp(const char * host, int port) { auto sockfd = socket(AF_INET, SOCK_STREAM, 0); auto sock = make_socket(sockfd); if (sock == nullptr) { @@ -380,6 +433,44 @@ static std::shared_ptr create_server_socket(const char * host, int por return sock; } +#ifndef _WIN32 +static std::shared_ptr create_server_socket_unix(const char * path) { + auto sockfd = socket(AF_UNIX, SOCK_STREAM, 0); + auto sock = make_socket(sockfd); + if (sock == nullptr) { + return nullptr; + } + + if (!unlink_old_socket_path(path)) { + return nullptr; + } + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + + if (strlen(path) >= sizeof(addr.sun_path)) { + GGML_LOG_ERROR("Unix socket path too long: %s, max is %d\n", path, (int)sizeof(addr.sun_path)); + return nullptr; + } + strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1); + + if (bind(sockfd, (struct sockaddr *) &addr, sizeof(addr)) < 0) { + return nullptr; + } + + if (chmod(path, 0770) != 0) { + GGML_LOG_ERROR("chmod('%s') failed\n", path); + } + + if (listen(sockfd, 1) < 0) { + return nullptr; + } + + return sock; +} +#endif + static bool send_data(sockfd_t sockfd, const void * data, size_t size) { size_t bytes_sent = 0; while (bytes_sent < size) { @@ -446,16 +537,103 @@ static bool recv_msg(sockfd_t sockfd, std::vector & input) { return recv_data(sockfd, input.data(), size); } -static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { - size_t pos = endpoint.find(':'); +enum class endpoint_type { + TCP, + UNIX, + INVALID +}; + +struct endpoint_info { + endpoint_type type = endpoint_type::INVALID; + std::string host; // For TCP: hostname/IP, For Unix: socket path + int port = 0; // Only used for TCP +}; + +static bool parse_tcp_endpoint(const std::string & endpoint, endpoint_info & info) { + size_t pos = endpoint.rfind(':'); if (pos == std::string::npos) { return false; } - host = endpoint.substr(0, pos); - port = std::stoi(endpoint.substr(pos + 1)); + std::string host = endpoint.substr(0, pos); + std::string port_str = endpoint.substr(pos + 1); + + if (host.empty() || host.find('/') != std::string::npos) { + return false; + } + + if (port_str.empty()) { + return false; + } + for (char c : port_str) { + if (!std::isdigit(static_cast(c))) { + return false; + } + } + + int port; + try { + port = std::stoi(port_str); + } catch (...) { + return false; + } + if (port <= 0 || port > 65535) { + return false; + } + + info.type = endpoint_type::TCP; + info.host = host; + info.port = port; return true; } +static bool parse_unix_endpoint(const std::string & endpoint, endpoint_info & info) { + std::string path; + + const std::string prefix = "unix://"; + if (endpoint.rfind(prefix, 0) == 0) { + path = endpoint.substr(prefix.length()); + } else { + path = endpoint; + } + + if (path.empty()) { + return false; + } + + info.type = endpoint_type::UNIX; + info.host = path; + info.port = 0; + return true; +} + +static bool parse_endpoint(const std::string & endpoint, endpoint_info & info) { + if (endpoint.rfind("unix://", 0) == 0) { + return parse_unix_endpoint(endpoint, info); + } + + if (endpoint.size() >= 5 && endpoint.rfind(".sock") == endpoint.size() - 5) { +#ifdef _WIN32 + return false; +#else + return parse_unix_endpoint(endpoint, info); +#endif + } + + if (parse_tcp_endpoint(endpoint, info)) { + return true; + } + +#ifdef _WIN32 + // On Windows, we don't support Unix sockets, so if TCP parsing failed, it's invalid + return false; +#else + if (endpoint.find('/') != std::string::npos) { + return parse_unix_endpoint(endpoint, info); + } + return false; +#endif +} + // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // No response static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { @@ -521,13 +699,16 @@ static std::shared_ptr get_socket(const std::string & endpoint) { return sock; } } - std::string host; - int port; - if (!parse_endpoint(endpoint, host, port)) { + endpoint_info info; + if (!parse_endpoint(endpoint, info)) { GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str()); return nullptr; } #ifdef _WIN32 + if (info.type == endpoint_type::UNIX) { + GGML_LOG_ERROR("Unix socket endpoints are not supported on Windows\n"); + return nullptr; + } if (!initialized) { WSADATA wsaData; int res = WSAStartup(MAKEWORD(2, 2), &wsaData); @@ -539,7 +720,17 @@ static std::shared_ptr get_socket(const std::string & endpoint) { #else GGML_UNUSED(initialized); #endif - auto sock = socket_connect(host.c_str(), port); + std::shared_ptr sock; +#ifndef _WIN32 + if (info.type == endpoint_type::UNIX) { + sock = socket_connect_unix(info.host.c_str()); + } else { + sock = socket_connect_tcp(info.host.c_str(), info.port); + } +#else + sock = socket_connect_tcp(info.host.c_str(), info.port); +#endif + if (sock == nullptr) { return nullptr; } @@ -1859,12 +2050,17 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } } - std::string host; - int port; - if (!parse_endpoint(endpoint, host, port)) { + endpoint_info info; + if (!parse_endpoint(endpoint, info)) { + fprintf(stderr, "Failed to parse endpoint: %s\n", endpoint); return; } + #ifdef _WIN32 + if (info.type == endpoint_type::UNIX) { + fprintf(stderr, "Unix socket endpoints are not supported on Windows\n"); + return; + } { WSADATA wsaData; int res = WSAStartup(MAKEWORD(2, 2), &wsaData); @@ -1874,13 +2070,26 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } } #endif - auto server_socket = create_server_socket(host.c_str(), port); + + std::shared_ptr server_socket; +#ifndef _WIN32 + if (info.type == endpoint_type::UNIX) { + server_socket = create_server_socket_unix(info.host.c_str()); + } else { + server_socket = create_server_socket_tcp(info.host.c_str(), info.port); + } +#else + server_socket = create_server_socket_tcp(info.host.c_str(), info.port); +#endif + if (server_socket == nullptr) { fprintf(stderr, "Failed to create server socket\n"); return; } + + const bool is_tcp = (info.type == endpoint_type::TCP); while (true) { - auto client_socket = socket_accept(server_socket->fd); + auto client_socket = socket_accept(server_socket->fd, is_tcp); if (client_socket == nullptr) { fprintf(stderr, "Failed to accept client connection\n"); return; diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp index 58b93c7468..b1f106b82c 100644 --- a/tools/rpc/rpc-server.cpp +++ b/tools/rpc/rpc-server.cpp @@ -145,7 +145,7 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads); fprintf(stderr, " -d, --device comma-separated list of devices\n"); - fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str()); + fprintf(stderr, " -H, --host HOST host to bind to, or path to unix socket ending in .sock (default: %s)\n", params.host.c_str()); fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port); fprintf(stderr, " -c, --cache enable local file cache\n"); fprintf(stderr, "\n"); @@ -258,7 +258,10 @@ int main(int argc, char * argv[]) { return 1; } - if (params.host != "127.0.0.1") { + bool is_unix_socket = params.host.size() >= 5 && + params.host.rfind(".sock") == params.host.size() - 5; + + if (!is_unix_socket && params.host != "127.0.0.1") { fprintf(stderr, "\n"); fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"); fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str()); @@ -273,7 +276,12 @@ int main(int argc, char * argv[]) { fprintf(stderr, "No devices found\n"); return 1; } - std::string endpoint = params.host + ":" + std::to_string(params.port); + std::string endpoint; + if (is_unix_socket) { + endpoint = params.host; + } else { + endpoint = params.host + ":" + std::to_string(params.port); + } const char * cache_dir = nullptr; std::string cache_dir_str; if (params.use_cache) {