diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
index 16f91e65c0..1142cff217 100644
--- a/tools/server/server-models.cpp
+++ b/tools/server/server-models.cpp
@@ -135,7 +135,8 @@ server_models::server_models(
             /* path_mmproj */ "", // auto-detected when loading
             /* in_cache    */ true,
             /* port        */ 0,
-            /* status      */ SERVER_MODEL_STATUS_UNLOADED
+            /* status      */ SERVER_MODEL_STATUS_UNLOADED,
+            /* last_used   */ 0
         };
         mapping[meta.name] = instance_t{
             /* subproc */ std::make_shared<subprocess_s>(),
@@ -157,7 +158,8 @@ server_models::server_models(
             /* path_mmproj */ model.path_mmproj,
             /* in_cache    */ false,
             /* port        */ 0,
-            /* status      */ SERVER_MODEL_STATUS_UNLOADED
+            /* status      */ SERVER_MODEL_STATUS_UNLOADED,
+            /* last_used   */ 0
         };
         mapping[meta.name] = instance_t{
             /* subproc */ std::make_shared<subprocess_s>(),
@@ -272,11 +274,39 @@ std::vector<server_model_meta> server_models::get_all_meta() {
     return result;
 }
 
+void server_models::unload_lru() {
+    if (base_params.max_models <= 0) {
+        return; // no limit
+    }
+    // remove one of the servers if we passed the max_models (least recently used - LRU)
+    std::string lru_model_name = "";
+    int64_t lru_last_used = ggml_time_ms();
+    size_t count_active = 0;
+    {
+        std::lock_guard<std::mutex> lk(mutex);
+        for (const auto & m : mapping) {
+            if (m.second.meta.is_active()) {
+                count_active++;
+                if (m.second.meta.last_used < lru_last_used) {
+                    lru_model_name = m.first;
+                    lru_last_used = m.second.meta.last_used;
+                }
+            }
+        }
+    }
+    if (!lru_model_name.empty() && count_active >= (size_t)base_params.max_models) {
+        SRV_INF("max_models limit reached, removing LRU name=%s\n", lru_model_name.c_str());
+        unload(lru_model_name);
+    }
+}
+
 void server_models::load(const std::string & name) {
-    std::lock_guard<std::mutex> lk(mutex);
-    if (mapping.find(name) == mapping.end()) {
+    if (!has_model(name)) {
         throw std::runtime_error("model name=" + name + " is not found");
     }
+    unload_lru();
+
+    std::lock_guard<std::mutex> lk(mutex);
     auto meta = mapping[name].meta;
     if (meta.status != SERVER_MODEL_STATUS_FAILED
             && meta.status != SERVER_MODEL_STATUS_UNLOADED) {
@@ -286,9 +316,10 @@ void server_models::load(const std::string & name) {
 
     // prepare new instance info
     instance_t inst;
-    inst.meta        = meta;
-    inst.meta.port   = get_free_port();
-    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
+    inst.meta           = meta;
+    inst.meta.port      = get_free_port();
+    inst.meta.status    = SERVER_MODEL_STATUS_LOADING;
+    inst.meta.last_used = ggml_time_ms();
 
     if (inst.meta.port <= 0) {
         throw std::runtime_error("failed to get a port number");
@@ -450,7 +481,7 @@ bool server_models::ensure_model_loaded(const std::string & name) {
     return true;
 }
 
-server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name) {
+server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
     auto meta = get_meta(name);
     if (!meta.has_value()) {
         throw std::runtime_error("model name=" + name + " is not found");
@@ -458,6 +489,10 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name) {
     if (ensure_model_loaded(name)) {
         meta = get_meta(name); // refresh meta
     }
+    if (update_last_used) {
+        std::unique_lock<std::mutex> lk(mutex);
+        mapping[name].meta.last_used = ggml_time_ms();
+    }
     SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
     auto proxy = std::make_unique<server_http_proxy>(
         method,
diff --git a/tools/server/server-models.h b/tools/server/server-models.h
index f8ae757fa4..3cb3b39fe7 100644
--- a/tools/server/server-models.h
+++ b/tools/server/server-models.h
@@ -58,6 +58,8 @@ struct server_model_meta {
     bool in_cache = false; // if true, use -hf; use -m otherwise
     int port = 0;
     server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
+    int64_t last_used = 0;
+
     bool is_active() const {
         return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;
     }
@@ -81,6 +83,9 @@ private:
 
     void update_meta(const std::string & name, const server_model_meta & meta);
 
+    // unload least recently used models if the limit is reached
+    void unload_lru();
+
 public:
     server_models(const common_params & params, int argc, char ** argv, char ** envp);
 
@@ -109,7 +114,7 @@ public:
     bool ensure_model_loaded(const std::string & name);
 
     // proxy an HTTP request to the model instance
-    server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name);
+    server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
 
     // notify the router server that a model instance is ready
     static void setup_child_server(const std::string & host, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler);
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 37ef9d96ea..0ce5c14265 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -5143,7 +5143,7 @@ public:
         std::string method = "GET";
         std::string name = req.get_param("model");
         models->ensure_model_loaded(name);
-        return models->proxy_request(req, method, name);
+        return models->proxy_request(req, method, name, false);
     };
 
     server_http_context::handler_t proxy_post = [this](const server_http_req & req) {
@@ -5151,7 +5151,7 @@
         std::string method = "POST";
         json body = json::parse(req.body);
         std::string name = json_value(body, "model", std::string());
         models->ensure_model_loaded(name);
-        return models->proxy_request(req, method, name);
+        return models->proxy_request(req, method, name, true); // update last usage for POST request only
     };
 
     server_http_context::handler_t post_router_models_load = [this](const server_http_req & req) {