diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp
index 7f62dc4edb..3a5bd2b215 100644
--- a/tools/server/server-models.cpp
+++ b/tools/server/server-models.cpp
@@ -324,6 +324,18 @@ void server_models::unload_lru() {
     }
 }
 
+static void add_or_replace_arg(std::vector<std::string> & args, const std::string & key, const std::string & value) {
+    for (size_t i = 0; i < args.size(); i++) {
+        if (args[i] == key && i + 1 < args.size()) {
+            args[i + 1] = value;
+            return;
+        }
+    }
+    // not found, append
+    args.push_back(key);
+    args.push_back(value);
+}
+
 void server_models::load(const std::string & name, bool auto_load) {
     if (!has_model(name)) {
         throw std::runtime_error("model name=" + name + " is not found");
@@ -356,33 +368,23 @@ void server_models::load(const std::string & name, bool auto_load) {
 
     std::vector<std::string> child_args;
     if (auto_load && !meta.args.empty()) {
-        child_args = meta.args; // reuse previous args
-        // update port arg
-        for (size_t i = 0; i < child_args.size(); i++) {
-            if (child_args[i] == "--port" && i + 1 < child_args.size()) {
-                child_args[i + 1] = std::to_string(inst.meta.port);
-                break;
-            }
-        }
+        child_args = meta.args; // copy previous args
     } else {
         child_args = base_args; // copy
         if (inst.meta.in_cache) {
-            child_args.push_back("-hf");
-            child_args.push_back(inst.meta.name);
+            add_or_replace_arg(child_args, "-hf", inst.meta.name);
         } else {
-            child_args.push_back("-m");
-            child_args.push_back(inst.meta.path);
+            add_or_replace_arg(child_args, "-m", inst.meta.path);
             if (!inst.meta.path_mmproj.empty()) {
-                child_args.push_back("--mmproj");
-                child_args.push_back(inst.meta.path_mmproj);
+                add_or_replace_arg(child_args, "--mmproj", inst.meta.path_mmproj);
             }
         }
-        child_args.push_back("--alias");
-        child_args.push_back(inst.meta.name);
-        child_args.push_back("--port");
-        child_args.push_back(std::to_string(inst.meta.port));
     }
 
+    // set model args
+    add_or_replace_arg(child_args, "--port", std::to_string(inst.meta.port));
+    add_or_replace_arg(child_args, "--alias", inst.meta.name);
+
     std::vector<std::string> child_env = base_env; // copy
     child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
 