diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index effa20e230..74baed9c60 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -9,6 +9,16 @@ #include #include #include +#include + +#ifdef _WIN32 +#include +#else +#include +#include +#include +#include +#endif #if defined(__APPLE__) && defined(__MACH__) // macOS: use _NSGetExecutablePath to get the executable path @@ -112,10 +122,63 @@ std::optional server_models::get_meta(const std::string & nam return std::nullopt; } -static int get_free_port(std::string host) { - httplib::Server s; - int port = s.bind_to_any_port(host.c_str()); - s.stop(); +static int get_free_port() { +#ifdef _WIN32 + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { + return -1; + } + typedef SOCKET native_socket_t; +#define INVALID_SOCKET_VAL INVALID_SOCKET +#define CLOSE_SOCKET(s) closesocket(s) +#else + typedef int native_socket_t; +#define INVALID_SOCKET_VAL -1 +#define CLOSE_SOCKET(s) close(s) +#endif + + native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock == INVALID_SOCKET_VAL) { +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + + struct sockaddr_in serv_addr; + std::memset(&serv_addr, 0, sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = htonl(INADDR_ANY); + serv_addr.sin_port = htons(0); + + if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) { + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + +#ifdef _WIN32 + int namelen = sizeof(serv_addr); +#else + socklen_t namelen = sizeof(serv_addr); +#endif + if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) { + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + return -1; + } + + int port = ntohs(serv_addr.sin_port); + + CLOSE_SOCKET(sock); +#ifdef _WIN32 + WSACleanup(); +#endif + return port; } @@ -154,7 +217,7 @@ void server_models::load(const std::string & name) { instance_t inst; inst.meta = meta.value(); - inst.meta.port = get_free_port(base_params.hostname); + inst.meta.port = get_free_port(); inst.meta.status = SERVER_MODEL_STATUS_LOADING; subprocess_s child_proc; @@ -218,7 +281,7 @@ void server_models::load(const std::string & name) { SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); }); if (inst.th.joinable()) { - inst.th.detach(); + inst.th.detach(); // TODO: remove this because it makes joining impossible } // start a logging thread to read stdout/stderr @@ -324,7 +387,7 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co return proxy; } -void server_models::notify_router_server_ready(const std::string & name) { +void server_models::notify_router_server_ready(const std::string & host, const std::string & name) { // send a notification to the router server that a model instance is ready const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); if (router_port == nullptr) { @@ -332,7 +395,7 @@ void server_models::notify_router_server_ready(const std::string & name) { return; } - httplib::Client cli("localhost", std::atoi(router_port)); + httplib::Client cli(host, std::atoi(router_port)); cli.set_connection_timeout(0, 200000); // 200 milliseconds httplib::Request req; @@ -345,9 +408,13 @@ void server_models::notify_router_server_ready(const std::string & name) { body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED); req.body = body.dump(); - SRV_INF("notifying router server that model %s is ready\n", name.c_str()); - cli.send(std::move(req)); - // discard response + SRV_INF("notifying router server (port=%s) that model %s is ready\n", router_port, name.c_str()); + auto result = cli.send(std::move(req)); + if (result.error() != httplib::Error::Success) { + auto err_str = httplib::to_string(result.error()); + SRV_ERR("failed to notify router server: %s\n", err_str.c_str()); + // TODO: maybe force shutdown here? + } } diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 664193f434..3e159c18e3 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -98,7 +98,7 @@ public: server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name); // notify the router server that a model instance is ready - static void notify_router_server_ready(const std::string & name); + static void notify_router_server_ready(const std::string & host, const std::string & name); }; /** diff --git a/tools/server/server.cpp b/tools/server/server.cpp index dfa0d07ce9..9aa0cb9f50 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -5825,7 +5825,7 @@ int main(int argc, char ** argv, char ** envp) { LOG_INF("%s: starting the main loop...\n", __func__); // optionally, notify router server that this instance is ready - server_models::notify_router_server_ready(params.model_alias); + server_models::notify_router_server_ready(params.hostname, params.model_alias); // this call blocks the main thread until queue_tasks.terminate() is called ctx_server.queue_tasks.start_loop(); diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index 3b1a811c3e..e6f3c6485c 100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -16,7 +16,7 @@ def create_server(): ("non-existent/model", False), ] ) -def test_chat_completion_stream(model: str, success: bool): +def test_router_chat_completion_stream(model: str, success: bool): # TODO: make sure the model is in cache (ie. ServerProcess.load_all()) before starting the router server global server server.start() diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index a9eec74822..3eeb830fdd 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -528,7 +528,7 @@ class ServerPreset: server.n_predict = 4 server.seed = 42 return server - + @staticmethod def router() -> ServerProcess: server = ServerProcess()