support extra_args on loading model

Author: Xuan Son Nguyen
Date: 2025-11-23 15:39:03 +01:00
parent 7ef6312f85
commit f927e21ffc
4 changed files with 28 additions and 12 deletions


@@ -1436,7 +1436,8 @@ Listing all models in cache. The model metadata will also include a field to ind
     "in_cache": true,
     "path": "/Users/REDACTED/Library/Caches/llama.cpp/ggml-org_gemma-3-4b-it-GGUF_gemma-3-4b-it-Q4_K_M.gguf",
     "status": {
-      "value": "loaded"
+      "value": "loaded",
+      "args": ["llama-server", "-ctx", "4096"]
     },
     ...
   }]
@@ -1477,14 +1478,16 @@ The `status` object can be:
 ### POST `/models/load`: Load a model
 
 Load a model
 
 Payload:
+- `model`: name of the model to be loaded
+- `extra_args`: (optional) an array of additional arguments to be passed to the model instance
 
 ```json
 {
-  "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
+  "model": "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M",
+  "extra_args": ["-n", "128", "--top-k", "4"]
 }
 ```
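For illustration, a minimal client-side sketch of the documented endpoint, assuming the router listens on localhost:8080 and using cpp-httplib and nlohmann/json (both already dependencies of this server); per the handler further down, a successful load replies with `{"success": true}`:

```cpp
// Hypothetical client: load a model through the router with extra_args.
// Host, port, and model name are assumptions for the example.
#include <httplib.h>
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::json;

int main() {
    httplib::Client cli("localhost", 8080);

    json body;
    body["model"]      = "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M";
    body["extra_args"] = {"-n", "128", "--top-k", "4"};

    auto res = cli.Post("/models/load", body.dump(), "application/json");
    if (res && res->status == 200) {
        std::cout << res->body << "\n"; // {"success":true}
    } else {
        std::cerr << "load failed\n";
    }
    return 0;
}
```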
@@ -1498,7 +1501,6 @@ Response:
 ### POST `/models/unload`: Unload a model
 
 Unload a model
 
 Payload:


@@ -322,7 +322,7 @@ void server_models::unload_lru() {
     }
 }
 
-void server_models::load(const std::string & name) {
+void server_models::load(const std::string & name, const std::vector<std::string> & extra_args) {
     if (!has_model(name)) {
         throw std::runtime_error("model name=" + name + " is not found");
     }
@@ -369,6 +369,11 @@ void server_models::load(const std::string & name) {
     child_args.push_back("--port");
     child_args.push_back(std::to_string(inst.meta.port));
 
+    // append extra args
+    for (const auto & arg : extra_args) {
+        child_args.push_back(arg);
+    }
+
     std::vector<std::string> child_env = base_env; // copy
     child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
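Note the ordering here: user-supplied `extra_args` land after the router-controlled arguments such as `--port`. A standalone sketch of the resulting argv (simplified names, not the actual server_models code):

```cpp
// Sketch of the child argv assembly order implied by the hunk above.
#include <iostream>
#include <string>
#include <vector>

int main() {
    // router-controlled arguments come first
    std::vector<std::string> child_args = {"llama-server", "--port", "8081"};

    // user-supplied extra_args from POST /models/load go last
    const std::vector<std::string> extra_args = {"-n", "128", "--top-k", "4"};
    for (const auto & arg : extra_args) {
        child_args.push_back(arg);
    }

    for (const auto & arg : child_args) {
        std::cout << arg << ' ';
    }
    std::cout << '\n'; // llama-server --port 8081 -n 128 --top-k 4
    return 0;
}
```

Whether a conflicting user flag (say, a second `--port`) overrides the router's value depends on how the child's argument parser resolves duplicates; this diff does not change that behavior.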
@@ -465,6 +470,10 @@ void server_models::unload_all() {
 }
 
 void server_models::update_status(const std::string & name, server_model_status status) {
+    // for now, we only allow updating to LOADED status
+    if (status != SERVER_MODEL_STATUS_LOADED) {
+        throw std::runtime_error("invalid status value");
+    }
     auto meta = get_meta(name);
     if (meta.has_value()) {
         meta->status = status;
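The added guard means a caller of this path can only ever report the LOADED state; other transitions stay internal to the router. A minimal sketch of the effect (the enum values are illustrative, named to match the diff):

```cpp
#include <iostream>
#include <stdexcept>

// illustrative subset of the status enum implied by the diff
enum server_model_status {
    SERVER_MODEL_STATUS_LOADING,
    SERVER_MODEL_STATUS_LOADED,
    SERVER_MODEL_STATUS_FAILED,
};

// mirrors the added guard: reject everything except LOADED
void update_status(server_model_status status) {
    if (status != SERVER_MODEL_STATUS_LOADED) {
        throw std::runtime_error("invalid status value");
    }
    std::cout << "status set to LOADED\n";
}

int main() {
    update_status(SERVER_MODEL_STATUS_LOADED); // ok
    try {
        update_status(SERVER_MODEL_STATUS_FAILED); // throws
    } catch (const std::exception & e) {
        std::cout << "rejected: " << e.what() << "\n";
    }
    return 0;
}
```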
@@ -493,7 +502,7 @@ bool server_models::ensure_model_loaded(const std::string & name) {
         return false; // already loaded
     }
     SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
-    load(name);
+    load(name, {});
     wait_until_loaded(name);
     {
         // check final status
@@ -529,15 +538,18 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
     return proxy;
 }
 
-void server_models::setup_child_server(const std::string & host, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
+void server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
     // send a notification to the router server that a model instance is ready
-    httplib::Client cli(host, router_port);
+    httplib::Client cli(base_params.hostname, router_port);
     cli.set_connection_timeout(0, 200000); // 200 milliseconds
 
     httplib::Request req;
     req.method = "POST";
     req.path = "/models/status";
     req.set_header("Content-Type", "application/json");
+    if (!base_params.api_keys.empty()) {
+        req.set_header("Authorization", "Bearer " + base_params.api_keys[0]);
+    }
 
     json body;
     body["model"] = name;


@@ -100,7 +100,7 @@ public:
     // return a copy of all model metadata
     std::vector<server_model_meta> get_all_meta();
 
-    void load(const std::string & name);
+    void load(const std::string & name, const std::vector<std::string> & extra_args);
     void unload(const std::string & name);
     void unload_all();
@@ -119,7 +119,7 @@ public:
     server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
 
     // notify the router server that a model instance is ready
-    static void setup_child_server(const std::string & host, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler);
+    static void setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler);
 };
 
 /**


@@ -5158,6 +5158,7 @@ public:
         auto res = std::make_unique<server_res_generator>(ctx_server);
         json body = json::parse(req.body);
         std::string name = json_value(body, "model", std::string());
+        std::vector<std::string> extra_args = json_value(body, "extra_args", std::vector<std::string>());
         auto model = models->get_meta(name);
         if (!model.has_value()) {
             res->error(format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
@@ -5167,12 +5168,13 @@ public:
             res->error(format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
             return res;
         }
-        models->load(name);
+        models->load(name, extra_args);
         res->ok({{"success", true}});
        return res;
     };
 
     // used by child process to notify the router about status change
+    // TODO @ngxson : maybe implement authentication for this endpoint in the future
     server_http_context::handler_t post_router_models_status = [this](const server_http_req & req) {
         auto res = std::make_unique<server_res_generator>(ctx_server);
         json body = json::parse(req.body);
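One behavioral detail of the load handler above: `extra_args` is optional, and `json_value` falls back to the supplied default (an empty vector) when the key is absent. A hedged stand-in showing the intended semantics (`json_value`'s actual implementation is not part of this diff):

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>
#include <vector>

using json = nlohmann::json;

// stand-in for llama.cpp's json_value helper: default when the key is absent
template <typename T>
static T json_value_sketch(const json & body, const std::string & key, const T & def) {
    return body.contains(key) ? body.at(key).get<T>() : def;
}

int main() {
    const json with    = json::parse(R"({"model": "m", "extra_args": ["-n", "128"]})");
    const json without = json::parse(R"({"model": "m"})");

    const auto a = json_value_sketch(with,    "extra_args", std::vector<std::string>());
    const auto b = json_value_sketch(without, "extra_args", std::vector<std::string>());

    std::cout << a.size() << " " << b.size() << "\n"; // prints: 2 0
    return 0;
}
```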
@@ -5836,7 +5838,7 @@ int main(int argc, char ** argv, char ** envp) {
     // optionally, notify router server that this instance is ready
     const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
     if (router_port != nullptr) {
-        server_models::setup_child_server(params.hostname, std::atoi(router_port), params.model_alias, shutdown_handler);
+        server_models::setup_child_server(params, std::atoi(router_port), params.model_alias, shutdown_handler);
     }
 
     // this call blocks the main thread until queue_tasks.terminate() is called
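For context, the handshake spans both sides of the spawn: the router injects `LLAMA_SERVER_ROUTER_PORT` into the child's environment (see the `child_env` hunk above), and the child checks it at startup to decide whether to call `setup_child_server`. A child-side sketch of that check (standalone illustration, not the actual server.cpp):

```cpp
#include <cstdio>
#include <cstdlib>

int main() {
    const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
    if (router_port != nullptr) {
        // spawned by the router: report readiness back to it
        std::printf("child of router on port %d\n", std::atoi(router_port));
    } else {
        // launched directly by a user: run standalone
        std::printf("standalone mode\n");
    }
    return 0;
}
```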