diff --git a/common/arg.cpp b/common/arg.cpp
index 649216b7f0..cac8819956 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3044,6 +3044,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.models_max = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MAX"));
+    add_opt(common_arg(
+        {"--models-memory-max"}, "N",
+        string_format("for router server, maximum memory usage in MB (default: %d, 0 = unlimited)", params.models_memory_max),
+        [](common_params & params, int value) {
+            params.models_memory_max = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_MEMORY_MAX"));
     add_opt(common_arg(
         {"--models-autoload"},
         {"--no-models-autoload"},
diff --git a/common/common.h b/common/common.h
index 31a337daa6..573a9bf4ef 100644
--- a/common/common.h
+++ b/common/common.h
@@ -621,6 +621,7 @@ struct common_params {
     std::string models_dir = ""; // directory containing models for the router server
     std::string models_preset = ""; // directory containing model presets for the router server
     int models_max = 4; // maximum number of models to load simultaneously
+    int models_memory_max = 0; // maximum memory usage in MB (0 = unlimited, estimated from model files)
     bool models_autoload = true; // automatically load models when requested via the router server
 
     bool log_json = false;
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 6f737d94d0..bfa032a814 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -3495,6 +3495,7 @@ void server_routes::init_routes() {
             { "total_slots", params.n_parallel },
             { "model_alias", meta->model_name },
             { "model_path", meta->model_path },
+            { "memory_mb", meta->model_size / (1024 * 1024) },
             { "modalities", json {
                 {"vision", meta->has_inp_image},
                 {"audio", meta->has_inp_audio},
diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
index c83709272f..f86e267919 100644
--- a/tools/server/server-models.cpp
+++ b/tools/server/server-models.cpp
@@ -300,6 +300,7 @@ void server_models::load_models() {
             /* port */ 0,
             /* status */ SERVER_MODEL_STATUS_UNLOADED,
             /* last_used */ 0,
+            /* memory_mb */ 0,
             /* args */ std::vector<std::string>(),
             /* exit_code */ 0,
             /* stop_timeout */ DEFAULT_STOP_TIMEOUT,
@@ -494,34 +495,45 @@ std::vector<server_model_meta> server_models::get_all_meta() {
 }
 
 void server_models::unload_lru() {
-    if (base_params.models_max <= 0) {
+    if (base_params.models_max <= 0 && base_params.models_memory_max <= 0) {
         return; // no limit
     }
-    // remove one of the servers if we passed the models_max (least recently used - LRU)
-    std::string lru_model_name = "";
-    int64_t lru_last_used = ggml_time_ms();
-    size_t count_active = 0;
-    {
-        std::unique_lock<std::mutex> lk(mutex);
-        for (const auto & m : mapping) {
-            if (m.second.meta.is_running()) {
-                count_active++;
-                if (m.second.meta.last_used < lru_last_used) {
-                    lru_model_name = m.first;
-                    lru_last_used = m.second.meta.last_used;
-                }
-            }
-        }
-    }
-    if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
-        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
-        unload(lru_model_name);
-        // wait for unload to complete
-        {
-            std::unique_lock<std::mutex> lk(mutex);
-            cv.wait(lk, [this, &lru_model_name]() {
-                return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED;
-            });
-        }
-    }
+    // Keep unloading LRU models until limits are satisfied
+    while (true) {
+        std::string lru_model_name = "";
+        int64_t lru_last_used = ggml_time_ms();
+        size_t count_active = 0;
+        uint64_t total_memory_mb = 0;
+        {
+            std::unique_lock<std::mutex> lk(mutex);
+            for (const auto & m : mapping) {
+                if (m.second.meta.is_running()) {
+                    count_active++;
+                    total_memory_mb += m.second.meta.memory_mb;
+                    if (m.second.meta.last_used < lru_last_used) {
+                        lru_model_name = m.first;
+                        lru_last_used = m.second.meta.last_used;
+                    }
+                }
+            }
+        }
+        // Check if limits exceeded
+        bool count_exceeded = base_params.models_max > 0 && count_active >= (size_t)base_params.models_max;
+        bool memory_exceeded = base_params.models_memory_max > 0 && total_memory_mb >= (uint64_t)base_params.models_memory_max;
+        if (!lru_model_name.empty() && (count_exceeded || memory_exceeded)) {
+            SRV_INF("limits reached (count=%zu, memory=%lu MB), removing LRU name=%s\n",
+                    count_active, (unsigned long)total_memory_mb, lru_model_name.c_str());
+            unload(lru_model_name);
+            // wait for unload to complete
+            {
+                std::unique_lock<std::mutex> lk(mutex);
+                cv.wait(lk, [this, &lru_model_name]() {
+                    return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED;
+                });
+            }
+            // Loop continues to check if more unloading is needed
+        } else {
+            break; // limits satisfied
+        }
+    }
 }
@@ -544,14 +556,18 @@ void server_models::load(const std::string & name) {
     // exceeding models_max. Without this, the window between unload_lru()
     // releasing its lock and this lock_guard acquiring allows multiple
     // threads to each observe capacity and all proceed to load.
-    if (base_params.models_max > 0) {
+    if (base_params.models_max > 0 || base_params.models_memory_max > 0) {
         size_t count_active = 0;
+        uint64_t total_memory_mb = 0;
         for (const auto & m : mapping) {
             if (m.second.meta.is_running()) {
                 count_active++;
+                total_memory_mb += m.second.meta.memory_mb;
             }
         }
-        if (count_active >= (size_t)base_params.models_max) {
+        bool count_exceeded = base_params.models_max > 0 && count_active >= (size_t)base_params.models_max;
+        bool memory_exceeded = base_params.models_memory_max > 0 && total_memory_mb >= (uint64_t)base_params.models_memory_max;
+        if (count_exceeded || memory_exceeded) {
             throw std::runtime_error("model limit reached, try again later");
         }
     }
@@ -608,10 +624,35 @@ void server_models::load(const std::string & name) {
         // also handle status report from child process
         if (stdout_file) {
             char buffer[4096];
+            bool ready_received = false;
             while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) {
                 LOG("[%5d] %s", port, buffer);
                 std::string str(buffer);
                 if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) {
+                    // Query memory usage from the child's /props endpoint
+                    if (!ready_received) {
+                        ready_received = true;
+                        try {
+                            httplib::Client cli("127.0.0.1", port); // child server listens on loopback at this port
+                            cli.set_connection_timeout(5, 0);
+                            if (auto res = cli.Get("/props")) {
+                                if (res->status == 200) {
+                                    json props = json::parse(res->body);
+                                    if (props.contains("memory_mb")) {
+                                        uint64_t memory_mb = props["memory_mb"].get<uint64_t>();
+                                        SRV_INF("model %s loaded, memory usage: %lu MB\n", name.c_str(), (unsigned long)memory_mb);
+                                        // Update memory_mb in meta
+                                        std::lock_guard<std::mutex> lk(this->mutex);
+                                        if (mapping.find(name) != mapping.end()) {
+                                            mapping[name].meta.memory_mb = memory_mb;
+                                        }
+                                    }
+                                }
+                            }
+                        } catch (const std::exception & e) {
+                            SRV_WRN("failed to query memory for model %s: %s\n", name.c_str(), e.what());
+                        }
+                    }
                     this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
                 } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_SLEEP)) {
                     this->update_status(name, SERVER_MODEL_STATUS_SLEEPING, 0);
diff --git a/tools/server/server-models.h b/tools/server/server-models.h
index 1db34b6c4d..c195dbeb26 100644
--- a/tools/server/server-models.h
+++ b/tools/server/server-models.h
@@ -62,6 +62,7 @@ struct server_model_meta {
     int port = 0;
     server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
     int64_t last_used = 0; // for LRU unloading
+    uint64_t memory_mb = 0; // estimated memory usage in MB
     std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
     int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
    int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown