diff --git a/common/arg.cpp b/common/arg.cpp
index 0cc284e9c6..b864ca8c2b 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -212,7 +212,6 @@ struct handle_model_result {
 static handle_model_result common_params_handle_model(
         struct common_params_model & model,
         const std::string & bearer_token,
-        const std::string & model_path_default,
         bool offline) {
     handle_model_result result;
     // handle pre-fill default model path and url based on hf_repo and hf_file
@@ -257,8 +256,6 @@ static handle_model_result common_params_handle_model(
                 model.path = fs_get_cache_file(string_split(f, '/').back());
             }
 
-    } else if (model.path.empty()) {
-        model.path = model_path_default;
     }
 }
 
@@ -405,7 +402,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
 
     // handle model and download
     {
-        auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline);
+        auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
         if (params.no_mmproj) {
             params.mmproj = {};
         } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -415,12 +412,18 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         // only download mmproj if the current example is using it
         for (auto & ex : mmproj_examples) {
             if (ctx_arg.ex == ex) {
-                common_params_handle_model(params.mmproj, params.hf_token, "", params.offline);
+                common_params_handle_model(params.mmproj, params.hf_token, params.offline);
                 break;
             }
         }
-        common_params_handle_model(params.speculative.model, params.hf_token, "", params.offline);
-        common_params_handle_model(params.vocoder.model, params.hf_token, "", params.offline);
+        common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
+        common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
+    }
+
+    // model is required (except for server)
+    // TODO @ngxson : maybe show a list of available models in CLI in this case
+    if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER) {
+        throw std::invalid_argument("error: --model is required\n");
     }
 
     if (params.escape) {
@@ -2072,11 +2075,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     add_opt(common_arg(
         {"-m", "--model"}, "FNAME",
         ex == LLAMA_EXAMPLE_EXPORT_LORA
-            ? std::string("model path from which to load base model")
-            : string_format(
-                "model path (default: `models/$filename` with filename from `--hf-file` "
-                "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH
-            ),
+            ? "model path from which to load base model"
+            : "model path to load",
         [](common_params & params, const std::string & value) {
             params.model.path = value;
         }
diff --git a/common/common.h b/common/common.h
index d7634589e7..197af5e6f2 100644
--- a/common/common.h
+++ b/common/common.h
@@ -26,8 +26,6 @@
         fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
     } while(0)
 
-#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
-
 struct common_time_meas {
     common_time_meas(int64_t & t_acc, bool disable = false);
     ~common_time_meas();
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0ce5c14265..e9388f208b 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -5662,8 +5662,7 @@ int main(int argc, char ** argv, char ** envp) {
     // register API routes
     server_routes routes(params, ctx_server, ctx_http);
 
-    // TODO: improve this by changing arg.cpp
-    bool is_router_server = params.model.path == DEFAULT_MODEL_PATH;
+    bool is_router_server = params.model.path.empty();
     if (is_router_server) {
         // setup server instances manager
         routes.models.reset(new server_models(params, argc, argv, envp));