diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt
index 1fccfdd17f..38db7816a9 100644
--- a/tools/server/CMakeLists.txt
+++ b/tools/server/CMakeLists.txt
@@ -16,6 +16,8 @@ set(TARGET_SRCS
     utils.hpp
     server-http.cpp
     server-http.h
+    server-models.cpp
+    server-models.h
 )
 set(PUBLIC_ASSETS
     index.html.gz
diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
new file mode 100644
index 0000000000..f8a10deb5e
--- /dev/null
+++ b/tools/server/server-models.cpp
@@ -0,0 +1,441 @@
+#include "utils.hpp"
+#include "server-models.h"
+
+#include "download.h"
+
+#include <spawn.h>
+
+#include <filesystem>
+#include <string>
+#include <thread>
+#include <vector>
+#include <sys/wait.h>
+#include <signal.h> // for kill()
+
+#if defined(__APPLE__) && defined(__MACH__)
+// macOS: use _NSGetExecutablePath to get the executable path
+#include <mach-o/dyld.h>
+#include <limits.h>
+#endif
+
+static std::filesystem::path get_server_exec_path() {
+#if defined(_MSC_VER)
+    wchar_t path[FILENAME_MAX] = { 0 };
+    GetModuleFileNameW(nullptr, path, FILENAME_MAX);
+    return std::filesystem::path(path);
+#elif defined(__APPLE__) && defined(__MACH__)
+    char small_path[PATH_MAX];
+    uint32_t size = sizeof(small_path);
+
+    if (_NSGetExecutablePath(small_path, &size) == 0) {
+        // resolve any symlinks to get absolute path
+        try {
+            return std::filesystem::canonical(std::filesystem::path(small_path));
+        } catch (...) {
+            return std::filesystem::path(small_path);
+        }
+    } else {
+        // buffer was too small, allocate required size and call again
+        std::vector<char> buf(size);
+        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
+            try {
+                return std::filesystem::canonical(std::filesystem::path(buf.data()));
+            } catch (...) {
+                return std::filesystem::path(buf.data());
+            }
+        }
+        return std::filesystem::path(std::string(buf.data(), (size > 0) ? size : 0));
+    }
+#else
+    char path[FILENAME_MAX];
+    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
+    return std::filesystem::path(std::string(path, (count > 0) ? count : 0));
+#endif
+}
+#endif
+}
+
+//
+// server_models
+//
+
+server_models::server_models(
+        const common_params & params,
+        int argc,
+        char ** argv,
+        char ** envp) : base_params(params) {
+    for (int i = 0; i < argc; i++) {
+        base_args.push_back(std::string(argv[i]));
+    }
+    for (char ** env = envp; *env != nullptr; env++) {
+        base_env.push_back(std::string(*env));
+    }
+    // TODO: allow refreshing cached model list
+    auto cached_models = common_list_cached_models();
+    for (const auto & model : cached_models) {
+        server_model_meta meta{
+            /* name        */ model.to_string(),
+            /* path        */ model.manifest_path,
+            /* path_mmproj */ "",
+            /* in_cache    */ true,
+            /* port        */ 0,
+            /* status      */ SERVER_MODEL_STATUS_UNLOADED
+        };
+        mapping[meta.name] = instance_t{0, std::thread(), meta};
+    }
+}
+
+void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
+    std::lock_guard lk(mutex);
+    auto it = mapping.find(name);
+    if (it != mapping.end()) {
+        it->second.meta = meta;
+    }
+    cv.notify_all(); // notify wait_until_loaded
+}
+
+bool server_models::has_model(const std::string & name) {
+    std::lock_guard lk(mutex);
+    return mapping.find(name) != mapping.end();
+}
+
+std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
+    std::lock_guard lk(mutex);
+    auto it = mapping.find(name);
+    if (it != mapping.end()) {
+        return it->second.meta;
+    }
+    return std::nullopt;
+}
+
+static int get_free_port(std::string host) {
+    httplib::Server s;
+    int port = s.bind_to_any_port(host.c_str());
+    s.stop();
+    return port;
+}
+
+// helper to convert std::vector<std::string> to char **
+// pointers are only valid as long as the original vector is valid
+static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
+    std::vector<char *> result;
+    result.reserve(vec.size() + 1);
+    for (const auto & s : vec) {
+        result.push_back(const_cast<char *>(s.c_str()));
+    }
+    result.push_back(nullptr);
+    return result;
+}
+
+std::vector<server_model_meta> server_models::get_all_meta() {
+    std::lock_guard lk(mutex);
+    std::vector<server_model_meta> result;
+    for (const auto & [name, inst] : mapping) {
+        result.push_back(inst.meta);
+    }
+    return result;
+}
+
+void server_models::load(const std::string & name) {
+    auto meta = get_meta(name);
+    if (!meta.has_value()) {
+        throw std::runtime_error("model name=" + name + " is not found");
+    }
+
+    std::lock_guard lk(mutex);
+    if (meta->status != SERVER_MODEL_STATUS_FAILED && meta->status != SERVER_MODEL_STATUS_UNLOADED) {
+        SRV_INF("model %s is not ready\n", name.c_str());
+        return;
+    }
+
+    instance_t inst;
+    inst.meta = meta.value();
+    inst.meta.port = get_free_port(base_params.hostname);
+    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
+
+    pid_t pid = 0;
+    {
+        std::string exec_path = get_server_exec_path().string();
+        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
+
+        std::vector<std::string> child_args = base_args; // copy
+        if (inst.meta.in_cache) {
+            child_args.push_back("-hf");
+            child_args.push_back(inst.meta.name);
+        } else {
+            child_args.push_back("-m");
+            child_args.push_back(inst.meta.path);
+            if (!inst.meta.path_mmproj.empty()) {
+                child_args.push_back("--mmproj");
+                child_args.push_back(inst.meta.path_mmproj);
+            }
+        }
+        child_args.push_back("--alias");
+        child_args.push_back(inst.meta.name);
+        child_args.push_back("--port");
+        child_args.push_back(std::to_string(inst.meta.port));
+
+        std::vector<std::string> child_env = base_env; // copy
+        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
+
+        // TODO: add logging
+        SRV_INF("%s", "spawning server instance with args:\n");
args:\n"); + for (const auto & arg : child_args) { + SRV_INF(" %s\n", arg.c_str()); + } + + std::vector argv = to_char_ptr_array(child_args); + std::vector envp = to_char_ptr_array(child_env); + + if (posix_spawn(&pid, exec_path.c_str(), NULL, NULL, argv.data(), envp.data()) != 0) { + perror("posix_spawn"); + throw std::runtime_error("failed to spawn server instance"); + } else { + inst.pid = pid; + SRV_INF("spawned instance with pid %d\n", pid); + } + } + + inst.th = std::thread([this, name, pid]() { + int exit_code = 0; + waitpid(pid, &exit_code, 0); + SRV_INF("instance with pid %d exited with status %d\n", pid, exit_code); + // note: if this is reached before std::move(inst) happens, + // this will be blocked until lock_guard is released (no race condition) + this->update_status(name, exit_code == 0 ? SERVER_MODEL_STATUS_UNLOADED : SERVER_MODEL_STATUS_FAILED); + }); + if (inst.th.joinable()) { + inst.th.detach(); + } + + mapping[name] = std::move(inst); + cv.notify_all(); +} + +void server_models::unload(const std::string & name) { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + if (it->second.pid != 0) { + SRV_INF("killing instance %s with pid %d\n", name.c_str(), it->second.pid); + kill(it->second.pid, SIGTERM); + } + it->second.meta.status = SERVER_MODEL_STATUS_UNLOADED; + cv.notify_all(); // notify status change + } +} + +void server_models::unload_all() { + auto all_meta = get_all_meta(); + for (const auto & meta : all_meta) { + unload(meta.name); + } +} + +void server_models::update_status(const std::string & name, server_model_status status) { + auto meta = get_meta(name); + if (meta.has_value()) { + meta->status = status; + update_meta(name, meta.value()); + } +} + +void server_models::wait_until_loaded(const std::string & name) { + std::unique_lock lk(mutex); + cv.wait(lk, [this, &name]() { + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.meta.status == SERVER_MODEL_STATUS_LOADED || + it->second.meta.status == SERVER_MODEL_STATUS_FAILED; + } + return false; + }); +} + +void server_models::ensure_model_loaded(const std::string & name) { + auto meta = get_meta(name); + if (!meta.has_value()) { + throw std::runtime_error("model name=" + name + " is not found"); + } + if (meta->status == SERVER_MODEL_STATUS_LOADED) { + return; // already loaded + } + load(name); + wait_until_loaded(name); +} + +server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name) { + auto meta = get_meta(name); + if (!meta.has_value()) { + throw std::runtime_error("model name=" + name + " is not found"); + } + ensure_model_loaded(name); // TODO: handle failure case + SRV_INF("proxying request to model %s at port %d\n", name.c_str(), meta->port); + auto proxy = std::make_unique( + method, + base_params.hostname, + meta->port, + req.path, + req.headers, + req.body, + req.should_stop); + return proxy; +} + +void server_models::notify_router_server_ready(const std::string & name) { + // send a notification to the router server that a model instance is ready + const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); + if (router_port == nullptr) { + // no router server to notify, this is a standalone server + return; + } + + httplib::Client cli("localhost", std::atoi(router_port)); + cli.set_connection_timeout(0, 200000); // 200 milliseconds + + httplib::Request req; + req.method = "POST"; + req.path = "/models/status"; + req.set_header("Content-Type", 
"application/json"); + + json body; + body["model"] = name; + body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED); + req.body = body.dump(); + + SRV_INF("notifying router server that model %s is ready\n", name.c_str()); + cli.send(std::move(req)); + // discard response +} + + +// +// server_http_proxy +// + +// simple implementation of a pipe +// used for streaming data between threads +template +struct pipe_t { + std::mutex mutex; + std::condition_variable cv; + std::queue queue; + std::atomic writer_closed{false}; + std::atomic reader_closed{false}; + void close_write() { + writer_closed.store(true); + cv.notify_all(); + } + void close_read() { + reader_closed.store(true); + cv.notify_all(); + } + bool read(T & output, const std::function & should_stop) { + std::unique_lock lk(mutex); + constexpr auto poll_interval = std::chrono::milliseconds(500); + while (true) { + if (!queue.empty()) { + output = std::move(queue.front()); + queue.pop(); + return true; + } + if (writer_closed.load()) { + return false; // clean EOF + } + if (should_stop()) { + close_read(); // signal broken pipe to writer + return false; // cancelled / reader no longer alive + } + cv.wait_for(lk, poll_interval); + } + } + bool write(T && data) { + std::lock_guard lk(mutex); + if (reader_closed.load()) { + return false; // broken pipe + } + queue.push(std::move(data)); + cv.notify_one(); + return true; + } +}; + +server_http_proxy::server_http_proxy( + const std::string & method, + const std::string & host, + int port, + const std::string & path, + const std::map & headers, + const std::string & body, + const std::function should_stop) { + // shared between reader and writer threads + auto cli = std::make_shared(host, port); + auto pipe = std::make_shared>(); + + // setup Client + cli->set_connection_timeout(0, 200000); // 200 milliseconds + this->status = 500; // to be overwritten upon response + this->cleanup = [pipe]() { + pipe->close_read(); + pipe->close_write(); + }; + + // wire up the receive end of the pipe + this->next = [pipe, should_stop](std::string & out) -> bool { + msg_t msg; + bool has_next = pipe->read(msg, should_stop); + if (!msg.data.empty()) { + out = std::move(msg.data); + } + return has_next; // false if EOF or pipe broken + }; + + // wire up the HTTP client + // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends + httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) { + msg_t msg; + msg.status = response.status; + for (const auto & [key, value] : response.headers) { + msg.headers[key] = value; + } + return pipe->write(std::move(msg)); // send headers first + }; + httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) { + // send data chunks + // returns false if pipe is closed / broken (signal to stop receiving) + return pipe->write({{}, 0, std::string(data, data_length)}); + }; + + // prepare the request to destination server + httplib::Request req; + { + req.method = method; + req.path = path; + for (const auto & [key, value] : headers) { + req.set_header(key, value); + } + req.body = body; + req.response_handler = response_handler; + req.content_receiver = content_receiver; + } + + // start the proxy thread + SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str()); + this->thread = std::thread([cli, pipe, req]() { + auto result = cli->send(std::move(req)); + if (result.error() != httplib::Error::Success) { + auto err_str 
+            SRV_ERR("http client error: %s\n", err_str.c_str());
+            pipe->write({{}, 500, ""}); // header
+            pipe->write({{}, 0, "proxy error: " + err_str}); // body
+        }
+        pipe->close_write(); // signal EOF to reader
+        SRV_DBG("%s", "client request thread ended\n");
+    });
+    this->thread.detach();
+
+    // wait for the first chunk (headers)
+    msg_t header;
+    pipe->read(header, should_stop);
+    SRV_DBG("%s", "received response headers\n");
+    this->status = header.status;
+    this->headers = header.headers;
+}
diff --git a/tools/server/server-models.h b/tools/server/server-models.h
new file mode 100644
index 0000000000..094b287b56
--- /dev/null
+++ b/tools/server/server-models.h
@@ -0,0 +1,126 @@
+#pragma once
+
+#include "common.h"
+#include "server-http.h"
+
+#include <condition_variable>
+#include <optional>
+#include <thread>
+
+enum server_model_status {
+    SERVER_MODEL_STATUS_UNLOADED,
+    SERVER_MODEL_STATUS_LOADING,
+    SERVER_MODEL_STATUS_LOADED,
+    SERVER_MODEL_STATUS_FAILED
+};
+
+static server_model_status server_model_status_from_string(const std::string & status_str) {
+    if (status_str == "unloaded") {
+        return SERVER_MODEL_STATUS_UNLOADED;
+    } else if (status_str == "loading") {
+        return SERVER_MODEL_STATUS_LOADING;
+    } else if (status_str == "loaded") {
+        return SERVER_MODEL_STATUS_LOADED;
+    } else if (status_str == "failed") {
+        return SERVER_MODEL_STATUS_FAILED;
+    } else {
+        throw std::runtime_error("invalid server model status");
+    }
+}
+
+static std::string server_model_status_to_string(server_model_status status) {
+    switch (status) {
+        case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
+        case SERVER_MODEL_STATUS_LOADING: return "loading";
+        case SERVER_MODEL_STATUS_LOADED: return "loaded";
+        case SERVER_MODEL_STATUS_FAILED: return "failed";
+    }
+}
+
+struct server_model_meta {
+    std::string name;
+    std::string path;
+    std::string path_mmproj; // only available if in_cache=false
+    bool in_cache = false; // if true, use -hf; use -m otherwise
+    int port = 0;
+    server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
+};
+
+struct server_models {
+private:
+    struct instance_t {
+        pid_t pid = 0;
+        std::thread th;
+        server_model_meta meta;
+    };
+
+    std::mutex mutex;
+    std::condition_variable cv;
+    std::map<std::string, instance_t> mapping;
+
+    common_params base_params;
+    std::vector<std::string> base_args;
+    std::vector<std::string> base_env;
+
+    void update_meta(const std::string & name, const server_model_meta & meta);
+
+public:
+    server_models(const common_params & params, int argc, char ** argv, char ** envp);
+
+    // check if a model instance exists
+    bool has_model(const std::string & name);
+
+    // return a copy of model metadata
+    std::optional<server_model_meta> get_meta(const std::string & name);
+
+    // return a copy of all model metadata
+    std::vector<server_model_meta> get_all_meta();
+
+    void load(const std::string & name);
+    void unload(const std::string & name);
+    void unload_all();
+
+    // update the status of a model instance
+    void update_status(const std::string & name, server_model_status status);
+
+    // wait until the model instance is fully loaded
+    // return when the model is loaded or failed to load
+    void wait_until_loaded(const std::string & name);
+
+    // load the model if not loaded, otherwise do nothing
+    void ensure_model_loaded(const std::string & name);
+
+    // proxy an HTTP request to the model instance
+    server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name);
+
+    // notify the router server that a model instance is ready
+    static void notify_router_server_ready(const std::string & name);
+};
+
+/**
+ * A simple HTTP proxy that forwards requests to another server
+ * and relays the responses back.
+ */
+struct server_http_proxy : server_http_res {
+    std::function<void()> cleanup = nullptr;
+public:
+    server_http_proxy(const std::string & method,
+                      const std::string & host,
+                      int port,
+                      const std::string & path,
+                      const std::map<std::string, std::string> & headers,
+                      const std::string & body,
+                      const std::function<bool()> should_stop);
+    ~server_http_proxy() {
+        if (cleanup) {
+            cleanup();
+        }
+    }
+private:
+    std::thread thread;
+    struct msg_t {
+        std::map<std::string, std::string> headers;
+        int status = 0;
+        std::string data;
+    };
+};
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 3750c8fdb6..5377d79812 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1,6 +1,7 @@
 #include "chat.h"
 #include "utils.hpp"
 #include "server-http.h"
+#include "server-models.h"
 #include "arg.h"
 #include "common.h"
@@ -4452,6 +4453,8 @@ struct server_routes {
     const common_params & params;
     server_context & ctx_server;
     server_http_context & ctx_http; // for reading is_ready
+    std::unique_ptr<server_models> models = nullptr;
+
     server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {}
@@ -5109,6 +5112,115 @@ public:
         return res;
     };
+
+    //
+    // endpoints for model management (aka router server)
+    //
+
+    server_http_context::handler_t get_router_props = [this](const server_http_req & req) {
+        std::string name = req.get_param("model");
+        if (name.empty()) {
+            // main instance
+            auto res = std::make_unique(ctx_server);
+            res->ok({
+                // TODO: add support for this on web UI
+                {"role", "router"},
+                {"max_instances", 4}, // dummy value for testing
+                // this is a dummy response to make sure webui doesn't break
+                {"model_alias", "llama-server"},
+                {"model_path", "none"},
+                {"default_generation_settings", {
+                    {"params", json{}},
+                    {"n_ctx", 0},
+                }},
+            });
+            return std::unique_ptr(std::move(res));
+        }
+        return proxy_get(req);
+    };
+
+    server_http_context::handler_t proxy_get = [this](const server_http_req & req) {
+        std::string method = "GET";
+        std::string name = req.get_param("model");
+        models->ensure_model_loaded(name);
+        return models->proxy_request(req, method, name);
+    };
+
+    server_http_context::handler_t proxy_post = [this](const server_http_req & req) {
+        std::string method = "POST";
+        json body = json::parse(req.body);
+        std::string name = json_value(body, "model", std::string());
+        models->ensure_model_loaded(name);
+        return models->proxy_request(req, method, name);
+    };
+
+    server_http_context::handler_t post_router_models_load = [this](const server_http_req & req) {
+        auto res = std::make_unique(ctx_server);
+        json body = json::parse(req.body);
+        std::string name = json_value(body, "model", std::string());
+        auto model = models->get_meta(name);
+        if (!model.has_value()) {
+            res->error(format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        if (model->status == SERVER_MODEL_STATUS_LOADED) {
+            res->error(format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        models->load(name);
+        res->ok({{"success", true}});
+        return res;
+    };
+
+    // used by child process to notify the router about status change
+    server_http_context::handler_t post_router_models_status = [this](const server_http_req & req) {
+        auto res = std::make_unique(ctx_server);
+        json body = json::parse(req.body);
+        std::string model = json_value(body, "model", std::string());
+        std::string value = json_value(body, "value", std::string());
+        models->update_status(model, server_model_status_from_string(value));
+        res->ok({{"success", true}});
+        return res;
+    };
+
+    server_http_context::handler_t get_router_models = [this](const server_http_req &) {
+        auto res = std::make_unique(ctx_server);
+        json models_json = json::array();
+        auto all_models = models->get_all_meta();
+        for (const auto & model : all_models) {
+            models_json.push_back(json {
+                {"model", model.name},
+                {"name", model.name},
+                {"id", model.name},
+                // TODO: other fields...
+                {"status", {
+                    {"value", server_model_status_to_string(model.status)}
+                }},
+            });
+        }
+        res->ok({{"data", models_json}});
+        return res;
+    };
+
+    server_http_context::handler_t post_router_models_unload = [this](const server_http_req & req) {
+        auto res = std::make_unique(ctx_server);
+        json body = json::parse(req.body);
+        std::string name = json_value(body, "model", std::string());
+        auto model = models->get_meta(name);
+        if (!model.has_value()) {
+            res->error(format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        if (model->status != SERVER_MODEL_STATUS_LOADED) {
+            res->error(format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        models->unload(name);
+        res->ok({{"success", true}});
+        return res;
+    };
+
+
 private:
     std::unique_ptr handle_completions_impl(
             server_task_type type,
@@ -5502,7 +5614,7 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
     };
 }
 
-int main(int argc, char ** argv) {
+int main(int argc, char ** argv, char ** envp) {
     // own arguments required by this example
     common_params params;
@@ -5550,6 +5662,36 @@ int main(int argc, char ** argv) {
     // register API routes
     server_routes routes(params, ctx_server, ctx_http);
 
+    bool is_router_server = params.model.path == DEFAULT_MODEL_PATH;
+    if (is_router_server) {
+        // setup server instances manager
+        routes.models.reset(new server_models(params, argc, argv, envp));
+
+        // proxy handlers
+        routes.post_props = routes.proxy_post;
+        routes.post_completions = routes.proxy_post;
+        routes.post_completions_oai = routes.proxy_post;
+        routes.post_chat_completions = routes.proxy_post;
+        routes.post_infill = routes.proxy_post;
+        routes.post_embeddings = routes.proxy_post;
+        routes.post_embeddings_oai = routes.proxy_post;
+        routes.post_rerank = routes.proxy_post;
+        routes.post_tokenize = routes.proxy_post;
+        routes.post_detokenize = routes.proxy_post;
+        routes.post_apply_template = routes.proxy_post;
+        routes.get_lora_adapters = routes.proxy_get;
+        routes.post_lora_adapters = routes.proxy_post;
+        routes.get_slots = routes.proxy_get;
+        routes.post_slots = routes.proxy_post;
+
+        // custom routes for router
+        routes.get_props = routes.get_router_props;
+        routes.get_models = routes.get_router_models;
+        ctx_http.post("/models/load", ex_wrapper(routes.post_router_models_load));
+        ctx_http.post("/models/unload", ex_wrapper(routes.post_router_models_unload));
+        ctx_http.post("/models/status", ex_wrapper(routes.post_router_models_status));
+    }
+
     ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
     ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
     ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
@@ -5587,51 +5729,74 @@ int main(int argc, char ** argv) {
     // Start the server
     //
 
-    // setup clean up function, to be called before exit
-    auto clean_up = [&ctx_http, &ctx_server]() {
-        SRV_INF("%s: cleaning up before exit...\n", __func__);
exit...\n", __func__); - ctx_http.stop(); - ctx_server.queue_results.terminate(); - llama_backend_free(); - }; + std::function clean_up; - // start the HTTP server before loading the model to be able to serve /health requests - if (!ctx_http.start()) { - clean_up(); - LOG_ERR("%s: exiting due to HTTP server error\n", __func__); - return 1; - } + if (is_router_server) { + LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__); + ctx_http.is_ready.store(true); - // load the model - LOG_INF("%s: loading model\n", __func__); + clean_up = []() { + SRV_INF("%s: cleaning up before exit...\n", __func__); + llama_backend_free(); + }; - if (!ctx_server.load_model(params)) { - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); + if (!ctx_http.start()) { + clean_up(); + LOG_ERR("%s: exiting due to HTTP server error\n", __func__); + return 1; } - LOG_ERR("%s: exiting due to model loading error\n", __func__); - return 1; + + shutdown_handler = [&](int) { + ctx_http.stop(); + }; + + } else { + // setup clean up function, to be called before exit + clean_up = [&ctx_http, &ctx_server]() { + SRV_INF("%s: cleaning up before exit...\n", __func__); + ctx_http.stop(); + ctx_server.queue_results.terminate(); + llama_backend_free(); + }; + + // start the HTTP server before loading the model to be able to serve /health requests + if (!ctx_http.start()) { + clean_up(); + LOG_ERR("%s: exiting due to HTTP server error\n", __func__); + return 1; + } + + // load the model + LOG_INF("%s: loading model\n", __func__); + + if (!ctx_server.load_model(params)) { + clean_up(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } + LOG_ERR("%s: exiting due to model loading error\n", __func__); + return 1; + } + + ctx_server.init(); + ctx_http.is_ready.store(true); + + LOG_INF("%s: model loaded\n", __func__); + + ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { + ctx_server.process_single_task(std::move(task)); + }); + + ctx_server.queue_tasks.on_update_slots([&ctx_server]() { + ctx_server.update_slots(); + }); + + shutdown_handler = [&](int) { + // this will unblock start_loop() + ctx_server.queue_tasks.terminate(); + }; } - ctx_server.init(); - ctx_http.is_ready.store(true); - - LOG_INF("%s: model loaded\n", __func__); - - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { - ctx_server.process_single_task(std::move(task)); - }); - - ctx_server.queue_tasks.on_update_slots([&ctx_server]() { - ctx_server.update_slots(); - }); - - shutdown_handler = [&](int) { - // this will unblock start_loop() - ctx_server.queue_tasks.terminate(); - }; - #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; sigint_action.sa_handler = signal_handler; @@ -5646,16 +5811,32 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); - LOG_INF("%s: starting the main loop...\n", __func__); - // this call blocks the main thread until queue_tasks.terminate() is called - ctx_server.queue_tasks.start_loop(); + if (is_router_server) { + LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); + ctx_http.is_ready.store(true); + ctx_http.thread.join(); // keep the main thread alive - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); + // when the HTTP server stops, clean up and exit + 
+        clean_up();
+
+        // TODO @ngxson : why are the models already unloaded without this line?
+        // routes.models->unload_all();
+    } else {
+        LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
+        LOG_INF("%s: starting the main loop...\n", __func__);
+
+        // optionally, notify router server that this instance is ready
+        server_models::notify_router_server_ready(params.model_alias);
+
+        // this call blocks the main thread until queue_tasks.terminate() is called
+        ctx_server.queue_tasks.start_loop();
+
+        clean_up();
+        if (ctx_http.thread.joinable()) {
+            ctx_http.thread.join();
+        }
+        llama_memory_breakdown_print(ctx_server.ctx);
     }
-    llama_memory_breakdown_print(ctx_server.ctx);
 
     return 0;
 }
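
Not part of the patch: a minimal client-side sketch of how the router endpoints registered above (`POST /models/load`, `POST /models/unload`, plus the model listing served through the existing models route) could be exercised once the router is running. It reuses the same cpp-httplib API the server code already relies on; the address `localhost:8080` and the model name `ggml-org/some-model` are placeholder assumptions, not values taken from the patch.

```cpp
// Hypothetical driver for the new router endpoints; adjust host, port and
// model name to your own setup before running.
#include <httplib.h>
#include <iostream>
#include <string>

int main() {
    httplib::Client cli("localhost", 8080); // assumed router address

    const std::string body = R"({"model": "ggml-org/some-model"})"; // placeholder model name

    // ask the router to spawn a llama-server child process for this model
    if (auto res = cli.Post("/models/load", body, "application/json")) {
        std::cout << "load: " << res->status << " " << res->body << "\n";
    }

    // list models; each entry carries a status value:
    // "unloaded", "loading", "loaded" or "failed"
    if (auto res = cli.Get("/v1/models")) {
        std::cout << "models: " << res->body << "\n";
    }

    // stop the spawned instance again (the router sends SIGTERM to the child)
    if (auto res = cli.Post("/models/unload", body, "application/json")) {
        std::cout << "unload: " << res->status << " " << res->body << "\n";
    }
    return 0;
}
```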