diff --git a/common/arg.cpp b/common/arg.cpp
index 430ab45dfe..0cc284e9c6 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2474,6 +2474,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--models-dir"}, "PATH",
+        "directory containing models for the router server (default: disabled)",
+        [](common_params & params, const std::string & value) {
+            params.models_dir = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--max-models"}, "N",
+        string_format("for router server, maximum number of models to load simultaneously (default: %d, 0 = unlimited)", params.max_models),
+        [](common_params & params, int value) {
+            params.max_models = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--jinja"},
         "use jinja template for chat (default: disabled)",
diff --git a/common/common.h b/common/common.h
index de5b404dd8..d7634589e7 100644
--- a/common/common.h
+++ b/common/common.h
@@ -460,6 +460,10 @@ struct common_params {
     bool endpoint_props = false; // only control POST requests, not GET
     bool endpoint_metrics = false;
 
+    // router server configs
+    std::string models_dir = ""; // directory containing models for the router server
+    int max_models = 4; // maximum number of models to load simultaneously
+
     bool log_json = false;
 
     std::string slot_save_path;
diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
index d615adfc85..16f91e65c0 100644
--- a/tools/server/server-models.cpp
+++ b/tools/server/server-models.cpp
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 
 #ifdef _WIN32
 #include
@@ -69,6 +70,46 @@ static std::filesystem::path get_server_exec_path() {
 #endif
 }
 
+struct local_model {
+    std::string name;
+    std::string path;
+    std::string path_mmproj;
+};
+
+static std::vector<local_model> list_local_models(const std::string & dir) {
+    if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
+        throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str()));
+    }
+    auto files = fs_list_files(dir);
+    std::unordered_set<std::string> files_model;
+    std::unordered_set<std::string> files_mmproj;
+    for (const auto & file : files) {
+        // TODO: also handle multiple shards
+        if (string_ends_with(file.name, ".gguf")) {
+            if (string_starts_with(file.name, "mmproj-")) {
+                files_mmproj.insert(file.name);
+            } else {
+                files_model.insert(file.name);
+            }
+        }
+    }
+    std::vector<local_model> models;
+    for (const auto & model_file : files_model) {
+        bool has_mmproj = false;
+        std::string mmproj_file = "mmproj-" + model_file;
+        if (files_mmproj.find(mmproj_file) != files_mmproj.end()) {
+            has_mmproj = true;
+        }
+        local_model model{
+            /* name        */ model_file,
+            /* path        */ dir + DIRECTORY_SEPARATOR + model_file,
+            /* path_mmproj */ has_mmproj ? (dir + DIRECTORY_SEPARATOR + mmproj_file) : ""
+        };
+        models.push_back(model);
+    }
+    return models;
+}
+
 //
 // server_models
 //
@@ -85,12 +126,13 @@ server_models::server_models(
         base_env.push_back(std::string(*env));
     }
     // TODO: allow refreshing cached model list
+    // add cached models
     auto cached_models = common_list_cached_models();
     for (const auto & model : cached_models) {
         server_model_meta meta{
             /* name        */ model.to_string(),
             /* path        */ model.manifest_path,
-            /* path_mmproj */ "",
+            /* path_mmproj */ "", // auto-detected when loading
             /* in_cache    */ true,
             /* port        */ 0,
             /* status      */ SERVER_MODEL_STATUS_UNLOADED
@@ -101,6 +143,29 @@
             /* meta    */ meta
         };
     }
+    // add local models specified via --models-dir
+    if (!params.models_dir.empty()) {
+        auto local_models = list_local_models(params.models_dir);
+        for (const auto & model : local_models) {
+            if (mapping.find(model.name) != mapping.end()) {
+                // already exists in cached models, skip
+                continue;
+            }
+            server_model_meta meta{
+                /* name        */ model.name,
+                /* path        */ model.path,
+                /* path_mmproj */ model.path_mmproj,
+                /* in_cache    */ false,
+                /* port        */ 0,
+                /* status      */ SERVER_MODEL_STATUS_UNLOADED
+            };
+            mapping[meta.name] = instance_t{
+                /* subproc */ std::make_shared(),
+                /* th      */ std::thread(),
+                /* meta    */ meta
+            };
+        }
+    }
 }
 
 void server_models::update_meta(const std::string & name, const server_model_meta & meta) {