diff --git a/common/arg.cpp b/common/arg.cpp index ec0a2f015e..f2675f842a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1295,7 +1295,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.kv_unified = true; } - ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY})); + ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED})); add_opt(common_arg( {"--context-shift"}, {"--no-context-shift"}, diff --git a/common/common.h b/common/common.h index 7794c0268b..b3ac04c4ae 100644 --- a/common/common.h +++ b/common/common.h @@ -80,6 +80,7 @@ int32_t cpu_get_num_math(); // enum llama_example { + LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_DEBUG, LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_SPECULATIVE, diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 6b134b4f6f..6875038770 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -21,7 +21,7 @@ int main(int argc, char ** argv) { params.prompt = "Hello my name is"; params.n_predict = 32; - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BATCHED, print_usage)) { return 1; }