Merge 4ac643ee74 into 2634ed207a
This commit is contained in:
commit
5d525178bd
|
|
@ -2183,7 +2183,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
if (llama_supports_rpc()) {
|
||||
add_opt(common_arg(
|
||||
{"--rpc"}, "SERVERS",
|
||||
"comma separated list of RPC servers (host:port)",
|
||||
"comma separated list of RPC servers (host:port or path ending in .sock)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
add_rpc_devices(value);
|
||||
GGML_UNUSED(params);
|
||||
|
|
|
|||
|
|
@ -25,7 +25,10 @@
|
|||
# include <netinet/tcp.h>
|
||||
# include <netdb.h>
|
||||
# include <unistd.h>
|
||||
# include <sys/un.h>
|
||||
# include <sys/stat.h>
|
||||
#endif
|
||||
#include <cctype>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
|
|
@ -314,7 +317,7 @@ static bool set_reuse_addr(sockfd_t sockfd) {
|
|||
return ret == 0;
|
||||
}
|
||||
|
||||
static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
|
||||
static std::shared_ptr<socket_t> socket_connect_tcp(const char * host, int port) {
|
||||
struct sockaddr_in addr;
|
||||
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
||||
auto sock_ptr = make_socket(sockfd);
|
||||
|
|
@ -334,25 +337,75 @@ static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
|
|||
}
|
||||
memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
|
||||
if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
|
||||
GGML_LOG_ERROR("Failed to connect to host '%s'\n", host);
|
||||
return nullptr;
|
||||
}
|
||||
return sock_ptr;
|
||||
}
|
||||
|
||||
static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
|
||||
#ifndef _WIN32
|
||||
static bool unlink_old_socket_path(const char * path) {
|
||||
struct stat st;
|
||||
if (lstat(path, &st) != 0) {
|
||||
if (errno == ENOENT) {
|
||||
return true;
|
||||
}
|
||||
GGML_LOG_ERROR("lstat('%s') failed\n", path);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!S_ISSOCK(st.st_mode)) {
|
||||
GGML_LOG_ERROR("Refusing to unlink '%s': exists but is not a unix socket\n", path);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (unlink(path) != 0) {
|
||||
GGML_LOG_ERROR("unlink('%s') failed\n", path);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::shared_ptr<socket_t> socket_connect_unix(const char * path) {
|
||||
auto sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
|
||||
auto sock_ptr = make_socket(sockfd);
|
||||
if (sock_ptr == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
struct sockaddr_un addr;
|
||||
memset(&addr, 0, sizeof(addr));
|
||||
addr.sun_family = AF_UNIX;
|
||||
|
||||
if (strlen(path) >= sizeof(addr.sun_path)) {
|
||||
GGML_LOG_ERROR("Unix socket path too long: %s, max is %d\n", path, (int)sizeof(addr.sun_path));
|
||||
return nullptr;
|
||||
}
|
||||
strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1);
|
||||
|
||||
if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
|
||||
GGML_LOG_ERROR("Failed to create socket at '%s'\n", path);
|
||||
return nullptr;
|
||||
}
|
||||
return sock_ptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Accepts one pending connection on the listening socket `srv_sockfd`.
//
// `is_tcp` tells whether the listening socket is a TCP socket; TCP_NODELAY is a TCP-level
// option, so it is only applied in that case.
//
// Returns the accepted client socket, or nullptr when accept() did not yield a usable
// socket or when TCP_NODELAY could not be set on a TCP connection.
static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd, bool is_tcp) {
    auto client_socket_fd = accept(srv_sockfd, nullptr, nullptr);
    auto client_socket    = make_socket(client_socket_fd);
    if (client_socket == nullptr) {
        return nullptr;
    }
    if (is_tcp && !set_no_delay(client_socket_fd)) {
        GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
        return nullptr;
    }
    return client_socket;
}
|
||||
|
||||
static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
|
||||
static std::shared_ptr<socket_t> create_server_socket_tcp(const char * host, int port) {
|
||||
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
||||
auto sock = make_socket(sockfd);
|
||||
if (sock == nullptr) {
|
||||
|
|
@ -380,6 +433,44 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
|
|||
return sock;
|
||||
}
|
||||
|
||||
#ifndef _WIN32
|
||||
static std::shared_ptr<socket_t> create_server_socket_unix(const char * path) {
|
||||
auto sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
|
||||
auto sock = make_socket(sockfd);
|
||||
if (sock == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!unlink_old_socket_path(path)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
struct sockaddr_un addr;
|
||||
memset(&addr, 0, sizeof(addr));
|
||||
addr.sun_family = AF_UNIX;
|
||||
|
||||
if (strlen(path) >= sizeof(addr.sun_path)) {
|
||||
GGML_LOG_ERROR("Unix socket path too long: %s, max is %d\n", path, (int)sizeof(addr.sun_path));
|
||||
return nullptr;
|
||||
}
|
||||
strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1);
|
||||
|
||||
if (bind(sockfd, (struct sockaddr *) &addr, sizeof(addr)) < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (chmod(path, 0770) != 0) {
|
||||
GGML_LOG_ERROR("chmod('%s') failed\n", path);
|
||||
}
|
||||
|
||||
if (listen(sockfd, 1) < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return sock;
|
||||
}
|
||||
#endif
|
||||
|
||||
static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
|
||||
size_t bytes_sent = 0;
|
||||
while (bytes_sent < size) {
|
||||
|
|
@ -446,16 +537,103 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
|
|||
return recv_data(sockfd, input.data(), size);
|
||||
}
|
||||
|
||||
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
|
||||
size_t pos = endpoint.find(':');
|
||||
// Transport an RPC endpoint string refers to.
enum class endpoint_type {
    TCP,
    UNIX,
    INVALID
};

// Result of parsing an endpoint string.
struct endpoint_info {
    endpoint_type type = endpoint_type::INVALID;
    std::string   host;      // For TCP: hostname/IP, For Unix: socket path
    int           port = 0;  // Only used for TCP
};

// URI-style prefix that explicitly marks an endpoint as a unix socket path.
static const char   RPC_UNIX_ENDPOINT_PREFIX[]   = "unix://";
static const size_t RPC_UNIX_ENDPOINT_PREFIX_LEN = sizeof(RPC_UNIX_ENDPOINT_PREFIX) - 1;

// Parses a "host:port" endpoint; the split happens at the LAST ':'.
//
// Accepted only when the host part is non-empty and contains no '/', and the port part
// consists solely of decimal digits with a value in [1, 65535].
// On success fills `info` (type TCP) and returns true; on failure returns false and
// leaves `info` untouched.
static bool parse_tcp_endpoint(const std::string & endpoint, endpoint_info & info) {
    const size_t pos = endpoint.rfind(':');
    if (pos == std::string::npos) {
        return false;
    }

    const std::string host = endpoint.substr(0, pos);
    if (host.empty() || host.find('/') != std::string::npos) {
        return false;
    }

    if (pos + 1 == endpoint.size()) {
        return false;  // empty port
    }

    // Accumulate the port digit by digit; bailing out as soon as the value leaves the
    // valid range keeps the arithmetic free of overflow without relying on exceptions.
    int port = 0;
    for (size_t i = pos + 1; i < endpoint.size(); ++i) {
        const unsigned char c = static_cast<unsigned char>(endpoint[i]);
        if (!std::isdigit(c)) {
            return false;
        }
        port = port * 10 + (c - '0');
        if (port > 65535) {
            return false;
        }
    }
    if (port == 0) {
        return false;
    }

    info.type = endpoint_type::TCP;
    info.host = host;
    info.port = port;
    return true;
}

// Parses a unix socket endpoint: either "unix://<path>" or a bare path.
//
// On success stores the path in `info.host` (type UNIX, port 0) and returns true.
// Returns false, leaving `info` untouched, when the resulting path is empty.
static bool parse_unix_endpoint(const std::string & endpoint, endpoint_info & info) {
    const bool has_prefix =
        endpoint.compare(0, RPC_UNIX_ENDPOINT_PREFIX_LEN, RPC_UNIX_ENDPOINT_PREFIX) == 0;

    const std::string path = has_prefix ? endpoint.substr(RPC_UNIX_ENDPOINT_PREFIX_LEN) : endpoint;
    if (path.empty()) {
        return false;
    }

    info.type = endpoint_type::UNIX;
    info.host = path;
    info.port = 0;
    return true;
}

// Classifies and parses an endpoint string. Rules are applied in this order:
//   1. "unix://..."            -> unix socket path
//   2. ends in ".sock"         -> unix socket path (rejected on Windows)
//   3. valid "host:port"       -> TCP
//   4. contains a '/'          -> unix socket path (rejected on Windows)
// Returns true and fills `info` on success, false otherwise.
static bool parse_endpoint(const std::string & endpoint, endpoint_info & info) {
    if (endpoint.compare(0, RPC_UNIX_ENDPOINT_PREFIX_LEN, RPC_UNIX_ENDPOINT_PREFIX) == 0) {
        return parse_unix_endpoint(endpoint, info);
    }

    const size_t suffix_len      = 5;  // strlen(".sock")
    const bool   has_sock_suffix = endpoint.size() >= suffix_len &&
                                   endpoint.compare(endpoint.size() - suffix_len, suffix_len, ".sock") == 0;
    if (has_sock_suffix) {
#ifdef _WIN32
        return false;
#else
        return parse_unix_endpoint(endpoint, info);
#endif
    }

    if (parse_tcp_endpoint(endpoint, info)) {
        return true;
    }

#ifdef _WIN32
    // On Windows, we don't support Unix sockets, so if TCP parsing failed, it's invalid
    return false;
#else
    if (endpoint.find('/') != std::string::npos) {
        return parse_unix_endpoint(endpoint, info);
    }
    return false;
#endif
}
|
||||
|
||||
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
|
||||
// No response
|
||||
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
|
||||
|
|
@ -521,13 +699,16 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|||
return sock;
|
||||
}
|
||||
}
|
||||
std::string host;
|
||||
int port;
|
||||
if (!parse_endpoint(endpoint, host, port)) {
|
||||
endpoint_info info;
|
||||
if (!parse_endpoint(endpoint, info)) {
|
||||
GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
#ifdef _WIN32
|
||||
if (info.type == endpoint_type::UNIX) {
|
||||
GGML_LOG_ERROR("Unix socket endpoints are not supported on Windows\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (!initialized) {
|
||||
WSADATA wsaData;
|
||||
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
||||
|
|
@ -539,7 +720,17 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|||
#else
|
||||
GGML_UNUSED(initialized);
|
||||
#endif
|
||||
auto sock = socket_connect(host.c_str(), port);
|
||||
std::shared_ptr<socket_t> sock;
|
||||
#ifndef _WIN32
|
||||
if (info.type == endpoint_type::UNIX) {
|
||||
sock = socket_connect_unix(info.host.c_str());
|
||||
} else {
|
||||
sock = socket_connect_tcp(info.host.c_str(), info.port);
|
||||
}
|
||||
#else
|
||||
sock = socket_connect_tcp(info.host.c_str(), info.port);
|
||||
#endif
|
||||
|
||||
if (sock == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
@ -1861,12 +2052,17 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
|
|||
}
|
||||
}
|
||||
|
||||
std::string host;
|
||||
int port;
|
||||
if (!parse_endpoint(endpoint, host, port)) {
|
||||
endpoint_info info;
|
||||
if (!parse_endpoint(endpoint, info)) {
|
||||
fprintf(stderr, "Failed to parse endpoint: %s\n", endpoint);
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
if (info.type == endpoint_type::UNIX) {
|
||||
fprintf(stderr, "Unix socket endpoints are not supported on Windows\n");
|
||||
return;
|
||||
}
|
||||
{
|
||||
WSADATA wsaData;
|
||||
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
||||
|
|
@ -1876,13 +2072,26 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
|
|||
}
|
||||
}
|
||||
#endif
|
||||
auto server_socket = create_server_socket(host.c_str(), port);
|
||||
|
||||
std::shared_ptr<socket_t> server_socket;
|
||||
#ifndef _WIN32
|
||||
if (info.type == endpoint_type::UNIX) {
|
||||
server_socket = create_server_socket_unix(info.host.c_str());
|
||||
} else {
|
||||
server_socket = create_server_socket_tcp(info.host.c_str(), info.port);
|
||||
}
|
||||
#else
|
||||
server_socket = create_server_socket_tcp(info.host.c_str(), info.port);
|
||||
#endif
|
||||
|
||||
if (server_socket == nullptr) {
|
||||
fprintf(stderr, "Failed to create server socket\n");
|
||||
return;
|
||||
}
|
||||
|
||||
const bool is_tcp = (info.type == endpoint_type::TCP);
|
||||
while (true) {
|
||||
auto client_socket = socket_accept(server_socket->fd);
|
||||
auto client_socket = socket_accept(server_socket->fd, is_tcp);
|
||||
if (client_socket == nullptr) {
|
||||
fprintf(stderr, "Failed to accept client connection\n");
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
|
|||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads);
|
||||
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
|
||||
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
|
||||
fprintf(stderr, " -H, --host HOST host to bind to, or path to unix socket ending in .sock (default: %s)\n", params.host.c_str());
|
||||
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
|
||||
fprintf(stderr, " -c, --cache enable local file cache\n");
|
||||
fprintf(stderr, "\n");
|
||||
|
|
@ -258,7 +258,10 @@ int main(int argc, char * argv[]) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
if (params.host != "127.0.0.1") {
|
||||
bool is_unix_socket = params.host.size() >= 5 &&
|
||||
params.host.rfind(".sock") == params.host.size() - 5;
|
||||
|
||||
if (!is_unix_socket && params.host != "127.0.0.1") {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
|
||||
fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str());
|
||||
|
|
@ -273,7 +276,12 @@ int main(int argc, char * argv[]) {
|
|||
fprintf(stderr, "No devices found\n");
|
||||
return 1;
|
||||
}
|
||||
std::string endpoint = params.host + ":" + std::to_string(params.port);
|
||||
std::string endpoint;
|
||||
if (is_unix_socket) {
|
||||
endpoint = params.host;
|
||||
} else {
|
||||
endpoint = params.host + ":" + std::to_string(params.port);
|
||||
}
|
||||
const char * cache_dir = nullptr;
|
||||
std::string cache_dir_str;
|
||||
if (params.use_cache) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue