diff --git a/tools/server/README.md b/tools/server/README.md
index f22b57fee2..13ac89617f 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -1436,7 +1436,8 @@ Listing all models in cache. The model metadata will also include a field to ind
     "in_cache": true,
     "path": "/Users/REDACTED/Library/Caches/llama.cpp/ggml-org_gemma-3-4b-it-GGUF_gemma-3-4b-it-Q4_K_M.gguf",
     "status": {
-      "value": "loaded"
+      "value": "loaded",
+      "args": ["llama-server", "-ctx", "4096"]
     },
     ...
   }]
@@ -1477,14 +1478,16 @@ The `status` object can be:
 
 ### POST `/models/load`: Load a model
 
-Load a model
-
 Payload:
 
+- `model`: name of the model to be loaded
+- `extra_args`: (optional) an array of additional arguments to be passed to the model instance
+
 ```json
 {
-  "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
+  "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M",
+  "extra_args": ["-n", "128", "--top-k", "4"]
 }
 ```
@@ -1498,7 +1501,6 @@ Response:
 
 ### POST `/models/unload`: Unload a model
 
-Unload a model
 
 Payload:
diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
index 691ce746e6..285e1e7f7c 100644
--- a/tools/server/server-models.cpp
+++ b/tools/server/server-models.cpp
@@ -322,7 +322,7 @@ void server_models::unload_lru() {
     }
 }
 
-void server_models::load(const std::string & name) {
+void server_models::load(const std::string & name, const std::vector<std::string> & extra_args) {
     if (!has_model(name)) {
         throw std::runtime_error("model name=" + name + " is not found");
     }
@@ -369,6 +369,11 @@ void server_models::load(const std::string & name) {
     child_args.push_back("--port");
     child_args.push_back(std::to_string(inst.meta.port));
 
+    // append extra args
+    for (const auto & arg : extra_args) {
+        child_args.push_back(arg);
+    }
+
     std::vector<std::string> child_env = base_env; // copy
     child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
@@ -465,6 +470,10 @@ void server_models::unload_all() {
     }
 }
 void server_models::update_status(const std::string & name, server_model_status status) {
+    // for now, we only allow updating to LOADED status
+    if (status != SERVER_MODEL_STATUS_LOADED) {
+        throw std::runtime_error("invalid status value");
+    }
     auto meta = get_meta(name);
     if (meta.has_value()) {
         meta->status = status;
@@ -493,7 +502,7 @@ bool server_models::ensure_model_loaded(const std::string & name) {
         return false; // already loaded
     }
     SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
-    load(name);
+    load(name, {});
    wait_until_loaded(name);
    {
        // check final status
@@ -529,15 +538,18 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
     return proxy;
 }
 
-void server_models::setup_child_server(const std::string & host, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
+void server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
     // send a notification to the router server that a model instance is ready
-    httplib::Client cli(host, router_port);
+    httplib::Client cli(base_params.hostname, router_port);
     cli.set_connection_timeout(0, 200000); // 200 milliseconds
 
     httplib::Request req;
     req.method = "POST";
     req.path = "/models/status";
     req.set_header("Content-Type", "application/json");
+    if (!base_params.api_keys.empty()) {
+        req.set_header("Authorization", "Bearer " + base_params.api_keys[0]);
+    }
 
     json body;
     body["model"] = name;
diff --git a/tools/server/server-models.h b/tools/server/server-models.h
index 222f31645e..e192d3dd6e 100644
--- a/tools/server/server-models.h
+++ b/tools/server/server-models.h
@@ -100,7 +100,7 @@ public:
     // return a copy of all model metadata
     std::vector<server_model_meta> get_all_meta();
 
-    void load(const std::string & name);
+    void load(const std::string & name, const std::vector<std::string> & extra_args);
     void unload(const std::string & name);
     void unload_all();
 
@@ -119,7 +119,7 @@ public:
     server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
 
     // notify the router server that a model instance is ready
-    static void setup_child_server(const std::string & host, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler);
+    static void setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler);
 };
 
 /**
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 4ec8aa879c..ab825e24ba 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -5158,6 +5158,7 @@ public:
         auto res = std::make_unique<server_res_generator>(ctx_server);
         json body = json::parse(req.body);
         std::string name = json_value(body, "model", std::string());
+        std::vector<std::string> extra_args = json_value(body, "extra_args", std::vector<std::string>());
         auto model = models->get_meta(name);
         if (!model.has_value()) {
             res->error(format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
@@ -5167,12 +5168,13 @@ public:
             res->error(format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
             return res;
         }
-        models->load(name);
+        models->load(name, extra_args);
         res->ok({{"success", true}});
         return res;
     };
 
     // used by child process to notify the router about status change
+    // TODO @ngxson : maybe implement authentication for this endpoint in the future
     server_http_context::handler_t post_router_models_status = [this](const server_http_req & req) {
         auto res = std::make_unique<server_res_generator>(ctx_server);
         json body = json::parse(req.body);
@@ -5836,7 +5838,7 @@ int main(int argc, char ** argv, char ** envp) {
     // optionally, notify router server that this instance is ready
     const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
     if (router_port != nullptr) {
-        server_models::setup_child_server(params.hostname, std::atoi(router_port), params.model_alias, shutdown_handler);
+        server_models::setup_child_server(params, std::atoi(router_port), params.model_alias, shutdown_handler);
     }
 
     // this call blocks the main thread until queue_tasks.terminate() is called
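
For reference, the README hunk above documents the new `extra_args` field accepted by `POST /models/load`. A minimal usage sketch of that payload with `curl`, assuming a router instance listening on `http://localhost:8080` and no API key configured (both are assumptions, not part of this change):

```sh
# Ask the router to spawn a llama-server child for this model;
# the strings in "extra_args" are appended to the child's command line.
curl -X POST http://localhost:8080/models/load \
    -H "Content-Type: application/json" \
    -d '{
        "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M",
        "extra_args": ["-n", "128", "--top-k", "4"]
    }'
```

If the router was started with an API key, the request would also need the same `Authorization: Bearer` header that the child instance now sends in `setup_child_server`.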