diff --git a/common/arg.cpp b/common/arg.cpp
index eab26b67f2..062046c0d0 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2482,12 +2482,19 @@ common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_DIR"));
     add_opt(common_arg(
-        {"--max-models"}, "N",
-        string_format("for router server, maximum number of models to load simultaneously (default: %d, 0 = unlimited)", params.max_models),
+        {"--models-max"}, "N",
+        string_format("for router server, maximum number of models to load simultaneously (default: %d, 0 = unlimited)", params.models_max),
         [](common_params & params, int value) {
-            params.max_models = value;
+            params.models_max = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MAX_MODELS"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MAX"));
+    add_opt(common_arg(
+        {"--no-models-autoload"},
+        "disables automatic loading of models (default: enabled)",
+        [](common_params & params) {
+            params.models_autoload = false;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_MODELS_AUTOLOAD"));
     add_opt(common_arg(
         {"--jinja"},
         "use jinja template for chat (default: disabled)",
diff --git a/common/common.h b/common/common.h
index 20ba209ce4..4ac9700d7b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -460,7 +460,8 @@ struct common_params {
 
     // router server configs
     std::string models_dir = ""; // directory containing models for the router server
-    int max_models = 4; // maximum number of models to load simultaneously
+    int models_max = 4; // maximum number of models to load simultaneously
+    bool models_autoload = true; // automatically load models when requested via the router server
 
     bool log_json = false;
diff --git a/tools/server/README.md b/tools/server/README.md
index 62ae83a1f0..bc1a4f8f7a 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -1463,8 +1463,9 @@ The `status` object can be:
 
 ```json
 "status": {
-  "value": "failed",
+  "value": "unloaded",
   "args": ["llama-server", "-ctx", "4096"],
+  "failed": true,
   "exit_code": 1
 }
 ```
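With the `failed` enum value gone (see the README hunk above), `GET /models` now reports a crashed instance as `"value": "unloaded"` plus `"failed": true` and an `exit_code`, so clients can no longer switch on `value` alone. A minimal client-side sketch, assuming the nlohmann/json library the server already uses; the `classify_model` helper is hypothetical and not part of this patch:

```cpp
#include <nlohmann/json.hpp>
#include <string>

// Hypothetical client-side helper (not part of this patch): classify a model
// from the "status" object returned by GET /models under the new schema.
static std::string classify_model(const nlohmann::json & status) {
    const std::string value = status.at("value").get<std::string>();
    if (value == "unloaded" && status.value("failed", false)) {
        // abnormal exit: "failed" and "exit_code" are only emitted in this case
        return "crashed (exit_code=" + std::to_string(status.value("exit_code", 0)) + ")";
    }
    return value; // "unloaded", "loading" or "loaded"
}
```

Treating `exit_code` as meaningful only when `failed` is present mirrors `is_failed()` on the server side, where a clean unload keeps `exit_code == 0`.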
diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
index cf81540f5a..67f84a508f 100644
--- a/tools/server/server-models.cpp
+++ b/tools/server/server-models.cpp
@@ -29,6 +29,8 @@
 #include <unistd.h>
 #endif
 
+#define CMD_EXIT "exit"
+
 static std::filesystem::path get_server_exec_path() {
 #if defined(_WIN32)
     wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
@@ -297,10 +299,10 @@ std::vector<server_model_meta> server_models::get_all_meta() {
 }
 
 void server_models::unload_lru() {
-    if (base_params.max_models <= 0) {
+    if (base_params.models_max <= 0) {
         return; // no limit
     }
-    // remove one of the servers if we passed the max_models (least recently used - LRU)
+    // remove one of the servers if we passed the models_max (least recently used - LRU)
     std::string lru_model_name = "";
     int64_t lru_last_used = ggml_time_ms();
     size_t count_active = 0;
@@ -316,8 +318,8 @@ void server_models::unload_lru() {
             }
         }
     }
-    if (!lru_model_name.empty() && count_active >= (size_t)base_params.max_models) {
-        SRV_INF("max_models limit reached, removing LRU name=%s\n", lru_model_name.c_str());
+    if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
+        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
         unload(lru_model_name);
     }
 }
@@ -331,7 +333,7 @@ void server_models::load(const std::string & name, const std::vector<std::string> & extra_args, bool auto_load) {
     std::unique_lock<std::mutex> lk(mutex);
     auto meta = mapping[name].meta;
-    if (meta.status != SERVER_MODEL_STATUS_FAILED && meta.status != SERVER_MODEL_STATUS_UNLOADED) {
+    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
         SRV_INF("model %s is not ready\n", name.c_str());
         return;
     }
@@ -428,9 +430,7 @@ void server_models::load(const std::string & name, const std::vector<std::string> & extra_args, bool auto_load) {
             auto & meta = it->second.meta;
             meta.exit_code = exit_code;
-            meta.status = exit_code == 0
-                ? SERVER_MODEL_STATUS_UNLOADED
-                : SERVER_MODEL_STATUS_FAILED;
+            meta.status = SERVER_MODEL_STATUS_UNLOADED;
         }
         cv.notify_all();
     }
@@ -446,13 +446,23 @@ void server_models::load(const std::string & name, const std::vector<std::string> & extra_args, bool auto_load) {
     }
 }
 
+// ask a child server to exit gracefully by sending the exit command over its stdin;
+// the managing thread will observe the process exit and update the status
+static void interrupt_subprocess(struct subprocess_s * subproc) {
+    FILE * child_stdin = subprocess_stdin(subproc);
+    if (child_stdin != nullptr) {
+        fputs(CMD_EXIT "\n", child_stdin);
+        fflush(child_stdin);
+    }
+}
+
 void server_models::unload(const std::string & name) {
     std::unique_lock<std::mutex> lk(mutex);
     auto it = mapping.find(name);
     if (it != mapping.end()) {
         if (it->second.meta.is_active()) {
             SRV_INF("unloading model instance name=%s\n", name.c_str());
-            subprocess_terminate(it->second.subproc.get());
+            interrupt_subprocess(it->second.subproc.get());
             // status change will be handled by the managing thread
         } else {
             SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
@@ -467,7 +477,7 @@ void server_models::unload_all() {
     for (auto & [name, inst] : mapping) {
         if (inst.meta.is_active()) {
             SRV_INF("unloading model instance name=%s\n", name.c_str());
-            subprocess_terminate(inst.subproc.get());
+            interrupt_subprocess(inst.subproc.get());
             // status change will be handled by the managing thread
         }
         // moving the thread to join list to avoid deadlock
@@ -498,8 +508,7 @@ void server_models::wait_until_loaded(const std::string & name) {
     cv.wait(lk, [this, &name]() {
         auto it = mapping.find(name);
         if (it != mapping.end()) {
-            return it->second.meta.status == SERVER_MODEL_STATUS_LOADED ||
-                   it->second.meta.status == SERVER_MODEL_STATUS_FAILED;
+            return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
         }
         return false;
     });
@@ -510,19 +519,23 @@ bool server_models::ensure_model_loaded(const std::string & name) {
     if (!meta.has_value()) {
         throw std::runtime_error("model name=" + name + " is not found");
     }
-    if (meta->is_active()) {
+    if (meta->status == SERVER_MODEL_STATUS_LOADED) {
         return false; // already loaded
     }
-    SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
-    load(name, {}, true);
-    wait_until_loaded(name);
-    {
-        // check final status
-        meta = get_meta(name);
-        if (!meta.has_value() || meta->status == SERVER_MODEL_STATUS_FAILED) {
-            throw std::runtime_error("model name=" + name + " failed to load");
-        }
+    if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
+        SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
+        load(name, {}, true);
     }
+
+    SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
+    wait_until_loaded(name);
+
+    // check final status
+    meta = get_meta(name);
+    if (!meta.has_value() || meta->is_failed()) {
+        throw std::runtime_error("model name=" + name + " failed to load");
+    }
+
     return true;
 }
@@ -582,13 +595,18 @@ void server_models::setup_child_server(const common_params & base_params, int router_port) {
         // wait for EOF on stdin
         SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
         while (true) {
-            int c = getchar();
-            if (c == EOF) {
-                break;
+            std::string line;
+            if (!std::getline(std::cin, line)) {
+                break; // EOF detected
+            }
+            if (line.find(CMD_EXIT) != std::string::npos) {
+                SRV_INF("%s", "exit command received, exiting...\n");
+                shutdown_handler(0);
             }
         }
-        SRV_INF("%s", "EOF on stdin detected, invoking shutdown handler...\n");
-        shutdown_handler(0); // invoke shutdown handler
+        // EOF means the router server exited unexpectedly or was killed
+        SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
+        exit(1);
     }).detach();
 }
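The eviction policy in `unload_lru` above is easy to miss in the rename noise: it counts active instances (loading or loaded), tracks the one with the oldest `last_used` timestamp, and evicts only once the active count has reached `models_max`. A standalone sketch of the same policy with simplified types; `pick_lru` and `instance_info` are hypothetical names, and the real code seeds the LRU timestamp with `ggml_time_ms()` and calls `unload()` itself:

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

struct instance_info {
    std::string name;
    int64_t     last_used; // ms timestamp, as returned by ggml_time_ms()
    bool        active;    // LOADED or LOADING
};

// Hypothetical standalone version of the unload_lru policy: return the name of
// the least-recently-used active instance, but only when the active count has
// already reached the limit (models_max <= 0 means unlimited).
static std::optional<std::string> pick_lru(const std::vector<instance_info> & instances, int models_max) {
    if (models_max <= 0) {
        return std::nullopt; // no limit
    }
    std::optional<std::string> lru;
    int64_t lru_last_used = INT64_MAX;
    size_t  count_active  = 0;
    for (const auto & inst : instances) {
        if (!inst.active) {
            continue;
        }
        count_active++;
        if (inst.last_used < lru_last_used) {
            lru_last_used = inst.last_used;
            lru           = inst.name;
        }
    }
    if (count_active >= (size_t) models_max) {
        return lru;
    }
    return std::nullopt;
}
```

Returning the victim's name instead of unloading in place keeps the policy testable in isolation.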
diff --git a/tools/server/server-models.h b/tools/server/server-models.h
index ed08c5023e..c49cb7c62c 100644
--- a/tools/server/server-models.h
+++ b/tools/server/server-models.h
@@ -15,16 +15,16 @@
  * state diagram:
  *
  * UNLOADED ──► LOADING ──► LOADED
- *    ▲            │
- *    │            │
- * FAILED ◄────────┘
+ *    ▲            │          │
+ *    └───failed───┘          │
+ *    ▲                       │
+ *    └───────unloaded────────┘
  */
 
 enum server_model_status {
-    // TODO: also add downloading state
+    // TODO: also add downloading state when the logic is added
     SERVER_MODEL_STATUS_UNLOADED,
     SERVER_MODEL_STATUS_LOADING,
-    SERVER_MODEL_STATUS_LOADED,
-    SERVER_MODEL_STATUS_FAILED
+    SERVER_MODEL_STATUS_LOADED
 };
 
@@ -34,8 +34,6 @@ static server_model_status server_model_status_from_string(const std::string & status_str) {
         return SERVER_MODEL_STATUS_LOADING;
     } else if (status_str == "loaded") {
         return SERVER_MODEL_STATUS_LOADED;
-    } else if (status_str == "failed") {
-        return SERVER_MODEL_STATUS_FAILED;
     } else {
         throw std::runtime_error("invalid server model status");
     }
@@ -46,7 +44,6 @@ static std::string server_model_status_to_string(server_model_status status) {
         case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
         case SERVER_MODEL_STATUS_LOADING: return "loading";
         case SERVER_MODEL_STATUS_LOADED: return "loaded";
-        case SERVER_MODEL_STATUS_FAILED: return "failed";
         default: return "unknown";
     }
 }
@@ -65,6 +62,10 @@ struct server_model_meta {
     bool is_active() const {
         return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;
     }
+
+    bool is_failed() const {
+        return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
+    }
 };
 
 struct server_models {
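After this change the lifecycle has no terminal failure state: a failed load and a clean exit both land back in `UNLOADED`, and `is_failed()` reconstructs the distinction from a non-zero `exit_code`. The diagram above can be read as the following transition check, a hypothetical illustration rather than code from the patch:

```cpp
#include "server-models.h" // for server_model_status (assumed in scope)

// Hypothetical reading of the state diagram as a transition check.
static bool is_valid_transition(server_model_status from, server_model_status to) {
    switch (from) {
        case SERVER_MODEL_STATUS_UNLOADED: // load() may only start from here
            return to == SERVER_MODEL_STATUS_LOADING;
        case SERVER_MODEL_STATUS_LOADING:  // startup succeeded, or failed (exit_code != 0)
            return to == SERVER_MODEL_STATUS_LOADED || to == SERVER_MODEL_STATUS_UNLOADED;
        case SERVER_MODEL_STATUS_LOADED:   // unloaded explicitly, evicted by LRU, or crashed
            return to == SERVER_MODEL_STATUS_UNLOADED;
    }
    return false;
}
```

Collapsing `FAILED` into `UNLOADED` also means a failed instance can simply be retried, since `load()` now only refuses instances that are not `UNLOADED`.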
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 6a499d9577..efd9aee83f 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -5142,10 +5142,9 @@ public:
     server_http_context::handler_t proxy_get = [this](const server_http_req & req) {
         std::string method = "GET";
         std::string name = req.get_param("model");
-        if (name.empty()) {
-            auto res = std::make_unique<server_res_generator>(ctx_server);
-            res->error(format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
-            return std::unique_ptr<server_http_res>(std::move(res));
+        auto error_res = std::make_unique<server_res_generator>(ctx_server);
+        if (!router_validate_model(name, error_res)) {
+            return std::unique_ptr<server_http_res>(std::move(error_res));
         }
         models->ensure_model_loaded(name);
         return models->proxy_request(req, method, name, false);
@@ -5155,10 +5154,9 @@ public:
         std::string method = "POST";
         json body = json::parse(req.body);
         std::string name = json_value(body, "model", std::string());
-        if (name.empty()) {
-            auto res = std::make_unique<server_res_generator>(ctx_server);
-            res->error(format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
-            return std::unique_ptr<server_http_res>(std::move(res));
+        auto error_res = std::make_unique<server_res_generator>(ctx_server);
+        if (!router_validate_model(name, error_res)) {
+            return std::unique_ptr<server_http_res>(std::move(error_res));
         }
         models->ensure_model_loaded(name);
         return models->proxy_request(req, method, name, true); // update last usage for POST request only
@@ -5200,22 +5198,23 @@ public:
         json models_json = json::array();
         auto all_models = models->get_all_meta();
         std::time_t t = std::time(0);
-        for (const auto & model : all_models) {
+        for (const auto & meta : all_models) {
             json status {
-                {"value", server_model_status_to_string(model.status)},
-                {"args", model.args},
+                {"value", server_model_status_to_string(meta.status)},
+                {"args", meta.args},
             };
-            if (model.status == SERVER_MODEL_STATUS_FAILED) {
-                status["exit_code"] = model.exit_code;
+            if (meta.is_failed()) {
+                status["exit_code"] = meta.exit_code;
+                status["failed"] = true;
             }
             models_json.push_back(json {
-                {"id", model.name},
-                {"name", model.name},
+                {"id", meta.name},
+                {"name", meta.name},
                 {"object", "model"}, // for OAI-compat
                 {"owned_by", "llamacpp"}, // for OAI-compat
                 {"created", t}, // for OAI-compat
-                {"in_cache", model.in_cache},
-                {"path", model.path},
+                {"in_cache", meta.in_cache},
+                {"path", meta.path},
                 {"status", status},
                 // TODO: add other fields, may require reading GGUF metadata
             });
@@ -5595,6 +5594,27 @@ private:
         res->ok(root);
         return res;
     }
+
+    bool router_validate_model(const std::string & name, std::unique_ptr<server_res_generator> & res) {
+        if (name.empty()) {
+            res->error(format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
+            return false;
+        }
+        auto meta = models->get_meta(name);
+        if (!meta.has_value()) {
+            res->error(format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
+            return false;
+        }
+        if (params.models_autoload) {
+            models->ensure_model_loaded(name);
+        } else {
+            if (meta->status != SERVER_MODEL_STATUS_LOADED) {
+                res->error(format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
+                return false;
+            }
+        }
+        return true;
+    }
 };
 
 std::function<void(int)> shutdown_handler;
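`router_validate_model` centralizes the checks both proxy handlers used to duplicate: a missing name and an unknown name are rejected outright, and with `--no-models-autoload` a request for a model that is not currently loaded is rejected instead of triggering a spawn. A standalone mirror of that decision logic with hypothetical simplified types (the real helper also calls `ensure_model_loaded` when autoload is enabled):

```cpp
#include <optional>
#include <string>

enum class model_status { unloaded, loading, loaded };

// Hypothetical mirror of router_validate_model: returns an error message,
// or std::nullopt when the request may proceed.
// `meta` is std::nullopt when the model name is unknown to the router.
static std::optional<std::string> validate(const std::string & name,
                                           const std::optional<model_status> & meta,
                                           bool models_autoload) {
    if (name.empty()) {
        return "model name is missing from the request";
    }
    if (!meta.has_value()) {
        return "model not found";
    }
    if (!models_autoload && *meta != model_status::loaded) {
        // with autoload enabled, the router would call ensure_model_loaded() here instead
        return "model is not loaded";
    }
    return std::nullopt; // proceed: proxy the request to the child instance
}
```

Note that when autoload is enabled the helper already calls `ensure_model_loaded`, so the handlers' own `ensure_model_loaded` call is effectively a cheap re-check by the time it runs.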