diff --git a/common/arg.cpp b/common/arg.cpp index 10aa1b5e4f..e4bdc6aa3d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2671,8 +2671,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.out_file = value; } + ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE, - LLAMA_EXAMPLE_RESULTS, LLAMA_EXAMPLE_EXPORT_GRAPH_OPS})); + LLAMA_EXAMPLE_RESULTS, LLAMA_EXAMPLE_EXPORT_GRAPH_OPS, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( {"-ofreq", "--output-frequency"}, "N", string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq), diff --git a/tools/cli/README.md b/tools/cli/README.md index 22d3fc87e9..7681917bae 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -39,6 +39,7 @@ | `--perf, --no-perf` | whether to enable internal libllama performance timings (default: false)
(env: LLAMA_ARG_PERF) | | `-f, --file FNAME` | a file containing the prompt (default: none) | | `-bf, --binary-file FNAME` | binary file containing the prompt (default: none) | +| `-o, --output FNAME` | a file to which to save the output (default: none) | | `-e, --escape, --no-escape` | whether to process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | | `--rope-scaling {none,linear,yarn}` | RoPE frequency scaling method, defaults to linear unless specified by the model
(env: LLAMA_ARG_ROPE_SCALING_TYPE) | | `--rope-scale N` | RoPE context scaling factor, expands context by a factor of N
(env: LLAMA_ARG_ROPE_SCALE) | diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 7c4342d6bf..61c8aa59ef 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -59,6 +59,8 @@ struct cli_context { bool verbose_prompt; int reasoning_budget = -1; std::string reasoning_budget_message; + common_reasoning_format reasoning_format; + std::optional file_out = std::nullopt; // thread for showing "loading" animation std::atomic loading_show; @@ -69,6 +71,7 @@ struct cli_context { defaults.n_keep = params.n_keep; defaults.n_predict = params.n_predict; defaults.antiprompt = params.antiprompt; + defaults.special_characters = params.special; defaults.stream = true; // make sure we always use streaming mode defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way @@ -77,6 +80,7 @@ struct cli_context { verbose_prompt = params.verbose_prompt; reasoning_budget = params.reasoning_budget; reasoning_budget_message = params.reasoning_budget_message; + reasoning_format = params.reasoning_format; } std::string generate_completion(result_timings & out_timings) { @@ -94,7 +98,7 @@ struct cli_context { // chat template settings task.params.chat_parser_params = common_chat_parser_params(chat_params); - task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + task.params.chat_parser_params.reasoning_format = reasoning_format; if (!chat_params.parser.empty()) { task.params.chat_parser_params.parser.load(chat_params.parser); } @@ -126,6 +130,11 @@ struct cli_context { console::set_display(DISPLAY_TYPE_RESET); } + append_file_out( + "[Prompt]: " + messages.back()["content"].get() + "\n\n", + chat_params.prompt + ); + // wait for first result console::spinner::start(); server_task_result_ptr result = rd.next(should_stop); @@ -133,6 +142,7 @@ struct cli_context { console::spinner::stop(); std::string curr_content; bool is_thinking = false; + bool content_started = false; while (result) { if (should_stop()) { @@ -155,26 +165,40 @@ 
struct cli_context { if (is_thinking) { console::log("\n[End thinking]\n\n"); console::set_display(DISPLAY_TYPE_RESET); + append_file_out("\n\n", ""); + is_thinking = false; } curr_content += diff.content_delta; console::log("%s", diff.content_delta.c_str()); console::flush(); + if (!content_started) { + append_file_out("[Assistant]: ", ""); + content_started = true; + } + append_file_out(diff.content_delta); } if (!diff.reasoning_content_delta.empty()) { console::set_display(DISPLAY_TYPE_REASONING); + std::string reasoning_delta = diff.reasoning_content_delta; if (!is_thinking) { console::log("[Start thinking]\n"); + append_file_out("[Thinking]: ", ""); + if (reasoning_delta == "") { + reasoning_delta = ""; + } } is_thinking = true; - console::log("%s", diff.reasoning_content_delta.c_str()); + console::log("%s", reasoning_delta.c_str()); console::flush(); + append_file_out(reasoning_delta); } } } auto res_final = dynamic_cast(result.get()); if (res_final) { out_timings = std::move(res_final->timings); + append_file_out("\n\n",""); break; } result = rd.next(should_stop); @@ -201,6 +225,18 @@ struct cli_context { } } + void append_file_out(const std::string & content, const std::optional & special_characters_content = std::nullopt) { + if (!file_out.has_value()) { + return; + } + if (defaults.special_characters && special_characters_content.has_value()) { + *file_out.value() << special_characters_content.value(); + } else { + *file_out.value() << content; + } + file_out.value()->flush(); + } + common_chat_params format_chat() { auto meta = ctx_server.get_meta(); auto & chat_params = meta.chat_params; @@ -365,6 +401,17 @@ int main(int argc, char ** argv) { console::init(params.simple_io, params.use_color); atexit([]() { console::cleanup(); }); + // open output file early to fail fast + std::ofstream output_file; + if (!params.out_file.empty()) { + output_file.open(params.out_file, std::ios::binary); + if (!output_file || !output_file.is_open()) { + 
console::error("Failed to open output file '%s'\n", params.out_file.c_str()); return 1; } ctx_cli.file_out = &output_file; } console::set_display(DISPLAY_TYPE_RESET); console::set_completion_callback(auto_completion_callback); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 1e342531d8..0f95b0b274 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -51,7 +51,8 @@ struct task_params { bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt bool return_tokens = false; bool return_progress = false; - + bool special_characters = false; // whether to include special tokens in the output (e.g. <s>, </s>, <unk>, etc.) + int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half int32_t n_predict = -1; // new tokens to predict