From acb7c790698fa28a0fbfc0468804926815b94de3 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Wed, 11 Mar 2026 10:26:12 +0100 Subject: [PATCH] common/parser: handle reasoning budget (#20297) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * v1 * Finished! * Handle cli * Reasoning sampler * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret * Less explosive terminology :) * Add utf-8 case and tests * common : migrate reasoning budget sampler to common * cont : clean up * cont : expose state and allow passing as initial state * cont : remove unused imports * cont : update state machine doc string --------- Co-authored-by: Sigbjørn Skjæret Co-authored-by: Alde Rojas --- common/CMakeLists.txt | 2 + common/arg.cpp | 33 +++- common/chat-auto-parser-generator.cpp | 4 +- common/chat.cpp | 22 ++- common/chat.h | 2 + common/common.h | 10 ++ common/reasoning-budget.cpp | 219 ++++++++++++++++++++++++ common/reasoning-budget.h | 41 +++++ common/sampling.cpp | 12 ++ common/unicode.cpp | 18 +- common/unicode.h | 3 + tests/CMakeLists.txt | 1 + tests/test-reasoning-budget.cpp | 238 ++++++++++++++++++++++++++ tools/cli/cli.cpp | 22 +++ tools/server/server-common.cpp | 16 ++ tools/server/server-common.h | 2 + tools/server/server-context.cpp | 7 +- tools/server/server-task.cpp | 28 +++ 18 files changed, 670 insertions(+), 10 deletions(-) create mode 100644 common/reasoning-budget.cpp create mode 100644 common/reasoning-budget.h create mode 100644 tests/test-reasoning-budget.cpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 51bff1c44b..75c6366c7f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -81,6 +81,8 @@ add_library(${TARGET} STATIC preset.cpp preset.h regex-partial.cpp + reasoning-budget.cpp + reasoning-budget.h regex-partial.h sampling.cpp sampling.h diff --git a/common/arg.cpp b/common/arg.cpp index 0be6b28eb2..41da8563d6 100644 --- a/common/arg.cpp +++ b/common/arg.cpp
@@ -2913,6 +2913,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { auto parsed = json::parse(value); for (const auto & item : parsed.items()) { + if (item.key() == "enable_thinking") { + LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. " + "Use --reasoning on / --reasoning off instead.\n"); + } params.default_template_kwargs[item.key()] = item.value().dump(); } } @@ -3048,14 +3052,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.reasoning_format = common_reasoning_format_from_name(value); } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK")); + add_opt(common_arg( + {"-rea", "--reasoning"}, "[on|off|auto]", + "Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))", + [](common_params & params, const std::string & value) { + if (is_truthy(value)) { + params.enable_reasoning = 1; + params.default_template_kwargs["enable_thinking"] = "true"; + } else if (is_falsey(value)) { + params.enable_reasoning = 0; + params.default_template_kwargs["enable_thinking"] = "false"; + } else if (is_autoy(value)) { + params.enable_reasoning = -1; + } else { + throw std::invalid_argument( + string_format("error: unknown value for --reasoning: '%s'\n", value.c_str())); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING")); add_opt(common_arg( {"--reasoning-budget"}, "N", - "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)", + "token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)", [](common_params & params, int value) { - if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); } + if 
(value < -1) { throw std::invalid_argument("invalid value"); } params.reasoning_budget = value; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET")); + add_opt(common_arg( + {"--reasoning-budget-message"}, "MESSAGE", + "message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)", + [](common_params & params, const std::string & value) { + params.reasoning_budget_message = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 1c74ad30d9..b7cf513942 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co if (thinking_forced_open || thinking_forced_closed) { // Thinking is forced open OR forced closed with enable_thinking=true // In both cases, expect only the closing tag (opening was in template) - return p.reasoning(p.until(end)) + end; + // However, since we might have incorrectly detected the open/close pattern, + // we admit an optional starting marker + return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end; } if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) { // Standard tag-based reasoning OR tools-only mode (reasoning appears with tools) diff --git a/common/chat.cpp b/common/chat.cpp index 60fd64ff91..b799912ae4 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -857,7 +857,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto include_grammar = true; - data.supports_thinking = true; + 
data.supports_thinking = true; + data.thinking_start_tag = "[THINK]"; + data.thinking_end_tag = "[/THINK]"; data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.preserved_tokens = { @@ -1165,9 +1167,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp const autoparser::templates_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; + data.thinking_start_tag = ""; + data.thinking_end_tag = ""; data.preserved_tokens = { "<|tool_calls_section_begin|>", "<|tool_calls_section_end|>", @@ -1527,6 +1531,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ autoparser.analyze_template(tmpl); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; + if (auto_params.supports_thinking) { + auto_params.thinking_start_tag = autoparser.reasoning.start; + auto_params.thinking_end_tag = autoparser.reasoning.end; + // FORCED_OPEN and FORCED_CLOSED both put in the generation prompt + // (FORCED_CLOSED forces empty when thinking is disabled, + // but forces open when thinking is enabled) + auto_params.thinking_forced_open = + autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN || + autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED; + } return auto_params; } catch (const std::exception & e) { throw std::invalid_argument(std::string("Unable to generate parser for this template. 
Automatic parser generation failed: ") + e.what()); diff --git a/common/chat.h b/common/chat.h index 005cc5c8b3..930987cf77 100644 --- a/common/chat.h +++ b/common/chat.h @@ -213,6 +213,8 @@ struct common_chat_params { bool grammar_lazy = false; bool thinking_forced_open = false; bool supports_thinking = false; + std::string thinking_start_tag; // e.g., "" + std::string thinking_end_tag; // e.g., "" std::vector grammar_triggers; std::vector preserved_tokens; std::vector additional_stops; diff --git a/common/common.h b/common/common.h index 440eb97200..ffaeefd7c9 100644 --- a/common/common.h +++ b/common/common.h @@ -235,6 +235,14 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + // reasoning budget sampler parameters + // these are populated by the server/CLI based on chat template params + int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget + bool reasoning_budget_activate_immediately = false; + std::vector reasoning_budget_start; // start tag token sequence + std::vector reasoning_budget_end; // end tag token sequence + std::vector reasoning_budget_forced; // forced sequence (message + end tag) + bool backend_sampling = false; bool has_logit_bias() const { @@ -536,7 +544,9 @@ struct common_params { bool use_jinja = true; // NOLINT bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable int reasoning_budget = -1; + std::string reasoning_budget_message; // message injected before end tag when budget exhausted bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp new file mode 100644 index 
0000000000..a55e4f509d --- /dev/null +++ b/common/reasoning-budget.cpp @@ -0,0 +1,219 @@ +#include "reasoning-budget.h" +#include "common.h" +#include "unicode.h" + +#include "log.h" + +#include +#include +#include +#include + +struct token_matcher { + std::vector tokens; + size_t pos = 0; + + bool advance(llama_token token) { + if (tokens.empty()) { + return false; + } + + if (token == tokens[pos]) { + pos++; + if (pos >= tokens.size()) { + pos = 0; + return true; + } + } else { + pos = 0; + if (token == tokens[0]) { + pos = 1; + } + } + return false; + } + + void reset() { pos = 0; } +}; + +struct common_reasoning_budget_ctx { + const llama_vocab * vocab; + + token_matcher start_matcher; + token_matcher end_matcher; + std::vector forced_tokens; + + int32_t budget; // maximum tokens in reasoning block + int32_t remaining; // tokens remaining in budget + + common_reasoning_budget_state state; + + // for forcing + size_t force_pos; // next position in forced_tokens to force +}; + +static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) { + return "reasoning-budget"; +} + +static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + + switch (ctx->state) { + case REASONING_BUDGET_IDLE: + { + if (ctx->start_matcher.advance(token)) { + ctx->state = REASONING_BUDGET_COUNTING; + ctx->remaining = ctx->budget; + LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget); + + if (ctx->remaining <= 0) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + LOG_INF("reasoning-budget: budget=0, forcing immediately\n"); + } + } + break; + } + case REASONING_BUDGET_COUNTING: + case REASONING_BUDGET_WAITING_UTF8: + { + if (ctx->end_matcher.advance(token)) { + ctx->state = REASONING_BUDGET_DONE; + LOG_INF("reasoning-budget: deactivated (natural end)\n"); + break; + } + + bool utf8_complete = true; + if (ctx->vocab != nullptr) { + const 
std::string piece = common_token_to_piece(ctx->vocab, token, false); + utf8_complete = common_utf8_is_complete(piece); + } + + if (ctx->state == REASONING_BUDGET_WAITING_UTF8) { + if (utf8_complete) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n"); + } + } else if (ctx->state == REASONING_BUDGET_COUNTING) { + ctx->remaining--; + if (ctx->remaining <= 0) { + if (utf8_complete) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n"); + } else { + ctx->state = REASONING_BUDGET_WAITING_UTF8; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n"); + } + } + } + break; + } + case REASONING_BUDGET_FORCING: + // force_pos is advanced in apply(), not here. + // This ensures the first forced token isn't skipped when the sampler + // is initialized directly in FORCING state (e.g. 
COUNTING + budget=0) + break; + case REASONING_BUDGET_DONE: + break; + } +} + +static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + + if (ctx->state != REASONING_BUDGET_FORCING) { + // passthrough — don't modify logits + return; + } + + if (ctx->force_pos >= ctx->forced_tokens.size()) { + return; + } + + const llama_token forced = ctx->forced_tokens[ctx->force_pos]; + + // set all logits to -inf except the forced token + for (size_t i = 0; i < cur_p->size; i++) { + if (cur_p->data[i].id != forced) { + cur_p->data[i].logit = -INFINITY; + } + } + + // advance to next forced token (done here rather than in accept so that + // the first forced token isn't skipped when starting in FORCING state) + ctx->force_pos++; + if (ctx->force_pos >= ctx->forced_tokens.size()) { + ctx->state = REASONING_BUDGET_DONE; + LOG_INF("reasoning-budget: forced sequence complete, done\n"); + } +} + +static void common_reasoning_budget_reset(struct llama_sampler * smpl) { + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + ctx->state = REASONING_BUDGET_IDLE; + ctx->remaining = ctx->budget; + ctx->start_matcher.reset(); + ctx->end_matcher.reset(); + ctx->force_pos = 0; +} + +static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; + return common_reasoning_budget_init( + ctx->vocab, + ctx->start_matcher.tokens, + ctx->end_matcher.tokens, + ctx->forced_tokens, + ctx->budget, + ctx->state); +} + +static void common_reasoning_budget_free(struct llama_sampler * smpl) { + delete (common_reasoning_budget_ctx *) smpl->ctx; +} + +static struct llama_sampler_i common_reasoning_budget_i = { + /* .name = */ common_reasoning_budget_name, + /* .accept = */ common_reasoning_budget_accept, + /* .apply = */ common_reasoning_budget_apply, + /* .reset = */ 
common_reasoning_budget_reset, + /* .clone = */ common_reasoning_budget_clone, + /* .free = */ common_reasoning_budget_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, +}; + +struct llama_sampler * common_reasoning_budget_init( + const struct llama_vocab * vocab, + const std::vector & start_tokens, + const std::vector & end_tokens, + const std::vector & forced_tokens, + int32_t budget, + common_reasoning_budget_state initial_state) { + // promote COUNTING with budget <= 0 to FORCING + if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) { + initial_state = REASONING_BUDGET_FORCING; + } + + return llama_sampler_init( + /* .iface = */ &common_reasoning_budget_i, + /* .ctx = */ new common_reasoning_budget_ctx { + /* .vocab = */ vocab, + /* .start_matcher = */ { start_tokens, 0 }, + /* .end_matcher = */ { end_tokens, 0 }, + /* .forced_tokens = */ forced_tokens, + /* .budget = */ budget, + /* .remaining = */ budget, + /* .state = */ initial_state, + /* .force_pos = */ 0, + } + ); +} diff --git a/common/reasoning-budget.h b/common/reasoning-budget.h new file mode 100644 index 0000000000..08ad282481 --- /dev/null +++ b/common/reasoning-budget.h @@ -0,0 +1,41 @@ +#pragma once + +#include "llama.h" + +#include +#include + +enum common_reasoning_budget_state { + REASONING_BUDGET_IDLE, // waiting for start sequence + REASONING_BUDGET_COUNTING, // counting down tokens + REASONING_BUDGET_FORCING, // forcing budget message + end sequence + REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion + REASONING_BUDGET_DONE, // passthrough forever +}; + +// Creates a reasoning budget sampler that limits token generation inside a +// reasoning block (e.g. between and ). 
+// +// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE +// IDLE: passthrough, watching for start_tokens sequence +// COUNTING: counting down remaining tokens, watching for natural end_tokens +// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence +// FORCING: forces forced_tokens token-by-token (all other logits -> -inf) +// DONE: passthrough forever +// +// Parameters: +// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr) +// start_tokens - token sequence that activates counting +// end_tokens - token sequence for natural deactivation +// forced_tokens - token sequence forced when budget expires +// budget - max tokens allowed in the reasoning block +// initial_state - initial state of the sampler (e.g. IDLE or COUNTING) +// note: COUNTING with budget <= 0 is promoted to FORCING +// +struct llama_sampler * common_reasoning_budget_init( + const struct llama_vocab * vocab, + const std::vector & start_tokens, + const std::vector & end_tokens, + const std::vector & forced_tokens, + int32_t budget, + common_reasoning_budget_state initial_state); diff --git a/common/sampling.cpp b/common/sampling.cpp index 11a1d48398..f849d4f61a 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" +#include "reasoning-budget.h" #include #include @@ -250,6 +251,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st } } + // reasoning budget sampler — added first so it can force tokens before other samplers + if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) { + samplers.push_back(common_reasoning_budget_init( + vocab, + params.reasoning_budget_start, + params.reasoning_budget_end, + params.reasoning_budget_forced, + params.reasoning_budget_tokens, + params.reasoning_budget_activate_immediately ? 
REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE)); + } + if (params.has_logit_bias()) { samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); } diff --git a/common/unicode.cpp b/common/unicode.cpp index c0ef6d0292..f71fe56783 100644 --- a/common/unicode.cpp +++ b/common/unicode.cpp @@ -1,8 +1,10 @@ #include "unicode.h" + +#include #include #include -#include #include +#include // implementation adopted from src/unicode.cpp @@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off return utf8_parse_result(utf8_parse_result::INVALID); } +bool common_utf8_is_complete(const std::string & s) { + if (s.empty()) { + return true; + } + for (int i = 1; i <= std::min(4, (int)s.size()); i++) { + unsigned char c = s[s.size() - i]; + if ((c & 0xC0) != 0x80) { + int expected = (c >= 0xF0) ? 4 : (c >= 0xE0) ? 3 : (c >= 0xC0) ? 2 : 1; + return i >= expected; + } + } + return false; +} + std::string common_unicode_cpts_to_utf8(const std::vector & cps) { std::string result; for (size_t i = 0; i < cps.size(); ++i) { diff --git a/common/unicode.h b/common/unicode.h index 87bcc0ffca..9b32fa19d6 100644 --- a/common/unicode.h +++ b/common/unicode.h @@ -20,6 +20,9 @@ struct utf8_parse_result { // Returns 0 for invalid first bytes size_t common_utf8_sequence_length(unsigned char first_byte); +// Check if a string ends with a complete UTF-8 sequence. 
+bool common_utf8_is_complete(const std::string & s); + // Parse a single UTF-8 codepoint from input utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7fd895e2b6..bb0f0ef0ed 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -149,6 +149,7 @@ endif () if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries) llama_build_and_test(test-sampling.cpp) + llama_build_and_test(test-reasoning-budget.cpp) llama_build_and_test(test-grammar-parser.cpp) llama_build_and_test(test-grammar-integration.cpp) llama_build_and_test(test-llama-grammar.cpp) diff --git a/tests/test-reasoning-budget.cpp b/tests/test-reasoning-budget.cpp new file mode 100644 index 0000000000..ab540a8463 --- /dev/null +++ b/tests/test-reasoning-budget.cpp @@ -0,0 +1,238 @@ +#include "reasoning-budget.h" +#include "unicode.h" + +#include "llama.h" +#include "ggml.h" + +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include +#include +#include +#include +#include + +// Reasoning budget sampler test helper +// These tests use nullptr vocab which safely falls back to treating all tokens as complete +// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection) +static void test_reasoning_budget( + const char * test_name, + const std::vector & sequence, + const std::vector & start_tokens, + const std::vector & end_tokens, + const std::vector & forced_tokens, + int32_t budget, + common_reasoning_budget_state initial_state, + size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never) + size_t expected_force_end // token index where forcing should end (after this, no more forcing) +) { + // Find the maximum token ID to ensure our vocab covers all tokens + llama_token max_token = 0; + for (auto t : sequence) max_token = 
std::max(max_token, t); + for (auto t : start_tokens) max_token = std::max(max_token, t); + for (auto t : end_tokens) max_token = std::max(max_token, t); + for (auto t : forced_tokens) max_token = std::max(max_token, t); + + // Create a minimal sampler with mock vocabulary + // For this test, we use nullptr as vocab since we're testing state transitions + // The UTF-8 boundary check will treat all tokens as complete (safe fallback) + auto * sampler = common_reasoning_budget_init( + nullptr, // vocab - not used for basic state machine tests + start_tokens, + end_tokens, + forced_tokens, + budget, + initial_state + ); + + // Create a test token data array for checking forcing behavior + // Vocab size must be large enough to include all tokens (start, end, forced, sequence) + std::vector cur; + const size_t n_vocab = (size_t)max_token + 1; + for (size_t i = 0; i < n_vocab; i++) { + cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f}); + } + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + + size_t actual_force_start = SIZE_MAX; + size_t actual_force_end = SIZE_MAX; + + // Feed the sequence and track when forcing occurs + for (size_t i = 0; i < sequence.size(); i++) { + llama_sampler_accept(sampler, sequence[i]); + + // Check if we're in forcing state by applying and seeing if logits are modified + cur_p.selected = -1; + for (size_t j = 0; j < cur.size(); j++) { + cur[j].logit = logf((float)(j+1)); // reset logits + } + + llama_sampler_apply(sampler, &cur_p); + + // Check if forcing is active (all logits except one should be -INFINITY) + size_t finite_count = 0; + llama_token finite_token = -1; + for (size_t j = 0; j < cur.size(); j++) { + if (std::isfinite(cur[j].logit)) { + finite_count++; + finite_token = cur[j].id; + } + } + + fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token); + + if (finite_count == 1) { + if (actual_force_start == SIZE_MAX) 
{ + actual_force_start = i; + } + actual_force_end = i; + } else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) { + // Forcing stopped + break; + } + } + + llama_sampler_free(sampler); + + // Verify forcing occurred at expected positions + if (expected_force_start == SIZE_MAX) { + if (actual_force_start != SIZE_MAX) { + fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start); + GGML_ASSERT(false && "Expected no forcing, but forcing occurred"); + } + } else { + if (actual_force_start == SIZE_MAX) { + fprintf(stderr, "Test '%s' FAILED: Expected forcing but none occurred\n", test_name); + GGML_ASSERT(false && "Expected forcing but none occurred"); + } + if (actual_force_start != expected_force_start) { + fprintf(stderr, "Test '%s' FAILED: Forcing started at %zu, expected %zu\n", test_name, actual_force_start, expected_force_start); + GGML_ASSERT(false && "Forcing started at wrong position"); + } + } + + if (expected_force_end != SIZE_MAX) { + if (actual_force_end < expected_force_end) { + fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end); + GGML_ASSERT(false && "Forcing ended too early"); + } + } + + fprintf(stderr, " Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end); + (void)sequence; +} + +// UTF-8 boundary detection unit test +// Tests common_utf8_is_complete() from reasoning-budget.h +static void test_utf8_boundary_detection() { + // Complete sequences + GGML_ASSERT(common_utf8_is_complete("hello")); + GGML_ASSERT(common_utf8_is_complete("")); + GGML_ASSERT(common_utf8_is_complete("\xC2\xA0")); // complete 2-byte UTF-8 (U+00A0) + GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C")); // complete 3-byte UTF-8 (left double quote) + GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80")); // complete 4-byte UTF-8 (emoji) + 
GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9")); // ASCII + complete 2-byte + + // Incomplete sequences + GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1))); // 2-byte start, missing continuation + GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2))); // 3-byte start + 1 cont, missing 1 + GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1))); // 3-byte start, missing 2 + GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3))); // 4-byte start + 2 cont, missing 1 + GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2))); // 4-byte start + 1 cont, missing 2 + GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1))); // 4-byte start, missing 3 + GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1))); // orphan continuation byte + + // Mixed: ASCII followed by start of multi-byte + GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6))); // ASCII + incomplete 2-byte + GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7))); // ASCII + complete 2-byte +} + +int main(void) { + // Reasoning budget sampler tests + printf("Testing reasoning budget sampler... 
"); + + // Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted) + { + const std::vector start = {100}; // start token + const std::vector end = {101}; // end token + const std::vector forced = {102}; // forced token (not used in this test) + const std::vector sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more + + test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced, + 5, // budget of 5 tokens + REASONING_BUDGET_IDLE, + SIZE_MAX, SIZE_MAX); // no forcing expected (natural end) + } + + // Test 2: Budget exhausted, forcing should occur + // Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING + // Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state) + // At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE + { + const std::vector start = {100}; + const std::vector end = {101}; + const std::vector forced = {102, 101}; // forced message + end + const std::vector sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2) + + test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced, + 2, // budget of 2 tokens + REASONING_BUDGET_IDLE, + 2, // forcing starts at i=2 (after accept(51) depletes budget, apply() forces) + 3); // forcing continues through i=3 (at i=4 state becomes DONE) + } + + // Test 3: Activate immediately with budget=0, forcing should start right away + // Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough) + // This test needs start token to be in the sequence or use activate_immediately with start token present + { + const std::vector start = {100}; + const std::vector end = {101}; + const std::vector forced = {102, 101}; + const std::vector sequence = {100, 50, 51, 52}; // start token first, then 3 tokens + + test_reasoning_budget("activate immediately budget=0", sequence, start, 
end, forced, + 0, // budget of 0 tokens + REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0 + 0, // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING) + 1); // forcing continues through i=1 (at i=2 state becomes DONE) + } + + // Test 4: No start/end tokens configured - passthrough (no forcing) + { + const std::vector start = {}; + const std::vector end = {}; + const std::vector forced = {102}; + const std::vector sequence = {50, 51, 52, 53}; + + test_reasoning_budget("no start/end configured", sequence, start, end, forced, + 2, // budget + REASONING_BUDGET_IDLE, + SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured) + } + + // Test 5: Activate immediately with budget > 0, count down then force + // Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING + // So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0) + { + const std::vector start = {100}; + const std::vector end = {101}; + const std::vector forced = {102, 101}; + const std::vector sequence = {50, 51, 52, 53}; + + test_reasoning_budget("activate immediately with budget", sequence, start, end, forced, + 2, // budget of 2 tokens + REASONING_BUDGET_COUNTING, + 1, // forcing starts at i=1 (after 2 accepts deplete budget) + 2); // forcing continues through i=2 + } + + printf("OK (5 tests passed)\n"); + + printf("Testing UTF-8 boundary detection... 
"); + test_utf8_boundary_detection(); + printf("OK\n"); + + return 0; +} diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index d43d105490..4c2ae7a033 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -57,6 +57,8 @@ struct cli_context { std::vector input_files; task_params defaults; bool verbose_prompt; + int reasoning_budget = -1; + std::string reasoning_budget_message; // thread for showing "loading" animation std::atomic loading_show; @@ -73,6 +75,8 @@ struct cli_context { // defaults.return_progress = true; // TODO: show progress verbose_prompt = params.verbose_prompt; + reasoning_budget = params.reasoning_budget; + reasoning_budget_message = params.reasoning_budget_message; } std::string generate_completion(result_timings & out_timings) { @@ -95,6 +99,24 @@ struct cli_context { task.params.chat_parser_params.parser.load(chat_params.parser); } + // reasoning budget sampler + if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) { + const llama_vocab * vocab = llama_model_get_vocab( + llama_get_model(ctx_server.get_llama_context())); + + task.params.sampling.reasoning_budget_tokens = reasoning_budget; + task.params.sampling.reasoning_budget_activate_immediately = chat_params.thinking_forced_open; + + if (!chat_params.thinking_start_tag.empty()) { + task.params.sampling.reasoning_budget_start = + common_tokenize(vocab, chat_params.thinking_start_tag, false, true); + } + task.params.sampling.reasoning_budget_end = + common_tokenize(vocab, chat_params.thinking_end_tag, false, true); + task.params.sampling.reasoning_budget_forced = + common_tokenize(vocab, reasoning_budget_message + chat_params.thinking_end_tag, false, true); + } + rd.post_task({std::move(task)}); } diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 5b8895b341..bd203228cc 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1101,6 +1101,22 @@ json oaicompat_chat_params_parse( llama_params["chat_parser"] = 
chat_params.parser; } + // Reasoning budget: pass parameters through to sampling layer + { + int reasoning_budget = opt.reasoning_budget; + if (reasoning_budget == -1 && body.contains("thinking_budget_tokens")) { + reasoning_budget = json_value(body, "thinking_budget_tokens", -1); + } + + if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) { + llama_params["reasoning_budget_tokens"] = reasoning_budget; + llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag; + llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag; + llama_params["reasoning_budget_message"] = opt.reasoning_budget_message; + llama_params["reasoning_budget_activate_immediately"] = chat_params.thinking_forced_open; + } + } + // Handle "logprobs" field // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future if (json_value(body, "logprobs", false)) { diff --git a/tools/server/server-common.h b/tools/server/server-common.h index a234541e19..3e56b3d856 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -287,6 +287,8 @@ struct server_chat_params { bool allow_image; bool allow_audio; bool enable_thinking = true; + int reasoning_budget = -1; + std::string reasoning_budget_message; std::string media_path; }; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index b86e7e608e..b4373c101b 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -893,9 +893,10 @@ private: } // thinking is enabled if: - // 1. It's not explicitly disabled (reasoning_budget == 0) + // 1. It's not explicitly disabled via --reasoning off // 2. 
The chat template supports it - const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); + const bool template_supports_thinking = params_base.use_jinja && common_chat_templates_support_enable_thinking(chat_templates.get()); + const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking; SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking); chat_params = { @@ -907,6 +908,8 @@ private: /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, /* enable_thinking */ enable_thinking, + /* reasoning_budget */ params_base.reasoning_budget, + /* reasoning_budget_msg */ params_base.reasoning_budget_message, /* media_path */ params_base.media_path, }; } diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 9d6e422d62..b3d510977b 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -462,6 +462,34 @@ task_params server_task::params_from_json_cmpl( } } + // Parse reasoning budget sampler parameters + { + const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1); + if (budget >= 0) { + const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string()); + const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string()); + const auto message = json_value(data, "reasoning_budget_message", std::string()); + const bool activate_imm = json_value(data, "reasoning_budget_activate_immediately", false); + + params.sampling.reasoning_budget_tokens = budget; + params.sampling.reasoning_budget_activate_immediately = activate_imm; + + if (!start_tag.empty()) { + params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true); + } + if (!end_tag.empty()) { + params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true); + 
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true); + } + + SRV_DBG("reasoning budget: tokens=%d, activate_immediately=%s, start=%zu toks, end=%zu toks, forced=%zu toks\n", + budget, activate_imm ? "true" : "false", + params.sampling.reasoning_budget_start.size(), + params.sampling.reasoning_budget_end.size(), + params.sampling.reasoning_budget_forced.size()); + } + } + { params.sampling.logit_bias.clear();