From be5cd55750b28cfd637c14cc5257db72166c8c49 Mon Sep 17 00:00:00 2001 From: Zeel Date: Sat, 28 Mar 2026 23:00:38 -0400 Subject: [PATCH 1/2] common: add two-phase graceful reasoning budget termination ... --- common/arg.cpp | 14 ++++- common/common.h | 13 +++-- common/reasoning-budget.cpp | 100 ++++++++++++++++++++++++++++++------ common/reasoning-budget.h | 35 ++++++++----- common/sampling.cpp | 11 +++- 5 files changed, 136 insertions(+), 37 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 538d2a4b0a..7268947fc3 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3100,8 +3100,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)", [](common_params & params, const std::string & value) { params.reasoning_budget_message = value; + params.sampling.reasoning_budget_message = value; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE")); + add_opt(common_arg( + {"--reasoning-budget-conclusion"}, "N", + "token budget for conclusion phase after message injection (0 = disabled, default: 0).\n" + "When set, the model is given N tokens to conclude naturally after the budget message\n" + "is injected, instead of immediately forcing the end-of-thinking tag.", + [](common_params & params, int value) { + if (value < 0) { throw std::invalid_argument("invalid value for --reasoning-budget-conclusion: must be >= 0"); } + params.reasoning_budget_conclusion = value; + params.sampling.reasoning_budget_conclusion = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_CONCLUSION")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( @@ -3894,4 +3906,4 @@ void common_params_add_preset_options(std::vector & args) { // "in server router mode, do not unload this model if 
models_max is exceeded", // [](common_params &) { /* unused */ } // ).set_preset_only()); -} +} \ No newline at end of file diff --git a/common/common.h b/common/common.h index 17dc3fb232..4e6cd926b3 100644 --- a/common/common.h +++ b/common/common.h @@ -283,10 +283,12 @@ struct common_params_sampling { // reasoning budget sampler parameters // these are populated by the server/CLI based on chat template params - int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget - std::vector reasoning_budget_start; // start tag token sequence - std::vector reasoning_budget_end; // end tag token sequence - std::vector reasoning_budget_forced; // forced sequence (message + end tag) + int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget + int32_t reasoning_budget_conclusion = 0; // tokens reserved for conclusion phase (0 = disabled) + std::string reasoning_budget_message; // message injected at start of conclusion phase + std::vector reasoning_budget_start; // start tag token sequence + std::vector reasoning_budget_end; // end tag token sequence + std::vector reasoning_budget_forced; // forced sequence (end tag, hard-cutoff safety net) bool backend_sampling = false; @@ -593,6 +595,7 @@ struct common_params { common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable int reasoning_budget = -1; + int reasoning_budget_conclusion = 0; // tokens reserved for conclusion phase (0 = disabled) std::string reasoning_budget_message; // message injected before end tag when budget exhausted bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time @@ -995,4 +998,4 @@ inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() { ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector 
& tokens, int64_t stride); // "adamw" or "sgd" (case insensitive) -enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *); +enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *); \ No newline at end of file diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp index cc408a6869..5bf6bbd689 100644 --- a/common/reasoning-budget.cpp +++ b/common/reasoning-budget.cpp @@ -41,15 +41,18 @@ struct common_reasoning_budget_ctx { token_matcher start_matcher; token_matcher end_matcher; - std::vector forced_tokens; + std::vector forced_tokens; // end-of-thinking sequence (hard-cutoff safety net) + std::vector message_tokens; // message injected at conclusion phase start - int32_t budget; // maximum tokens in reasoning block - int32_t remaining; // tokens remaining in budget + int32_t budget; // maximum tokens in thinking phase + int32_t conclusion_budget; // tokens reserved for conclusion phase (0 = disabled) + int32_t remaining; // tokens remaining in current phase common_reasoning_budget_state state; // for forcing - size_t force_pos; // next position in forced_tokens to force + size_t force_pos; // next position in forced_tokens to force + size_t message_pos; // next position in message_tokens to force (during message injection) }; static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) { @@ -76,6 +79,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to break; } case REASONING_BUDGET_COUNTING: + case REASONING_BUDGET_CONCLUDING: case REASONING_BUDGET_WAITING_UTF8: { if (ctx->end_matcher.advance(token)) { @@ -97,10 +101,36 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to ctx->end_matcher.reset(); LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n"); } - } else if (ctx->state == REASONING_BUDGET_COUNTING) { + } else if (ctx->state == REASONING_BUDGET_CONCLUDING) { ctx->remaining--; if (ctx->remaining <= 0) { if 
(utf8_complete) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: conclusion budget exhausted, forcing end sequence\n"); + } else { + ctx->state = REASONING_BUDGET_WAITING_UTF8; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: conclusion budget exhausted, waiting for UTF-8 completion\n"); + } + } + } else if (ctx->state == REASONING_BUDGET_COUNTING) { + ctx->remaining--; + if (ctx->remaining <= 0) { + if (ctx->conclusion_budget > 0 && !ctx->message_tokens.empty()) { + // Two-phase: force message tokens first, then let model conclude + ctx->state = REASONING_BUDGET_INJECTING; + ctx->message_pos = 0; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: thinking budget exhausted, injecting conclusion message\n"); + } else if (ctx->conclusion_budget > 0) { + // No message, but conclusion budget set — go straight to free conclusion + ctx->state = REASONING_BUDGET_CONCLUDING; + ctx->remaining = ctx->conclusion_budget; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: thinking budget exhausted, entering conclusion phase (%d tokens)\n", ctx->conclusion_budget); + } else if (utf8_complete) { ctx->state = REASONING_BUDGET_FORCING; ctx->force_pos = 0; ctx->end_matcher.reset(); @@ -114,6 +144,16 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to } break; } + case REASONING_BUDGET_INJECTING: + ctx->message_pos++; + if (ctx->message_pos >= ctx->message_tokens.size()) { + // Message fully injected — enter free conclusion phase + ctx->state = REASONING_BUDGET_CONCLUDING; + ctx->remaining = ctx->conclusion_budget; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: message injected, entering conclusion phase (%d tokens)\n", ctx->conclusion_budget); + } + break; case REASONING_BUDGET_FORCING: ctx->force_pos++; if (ctx->force_pos >= ctx->forced_tokens.size()) { @@ -129,6 +169,19 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, 
llama_to static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + if (ctx->state == REASONING_BUDGET_INJECTING) { + if (ctx->message_pos >= ctx->message_tokens.size()) { + return; + } + const llama_token forced = ctx->message_tokens[ctx->message_pos]; + for (size_t i = 0; i < cur_p->size; i++) { + if (cur_p->data[i].id != forced) { + cur_p->data[i].logit = -INFINITY; + } + } + return; + } + if (ctx->state != REASONING_BUDGET_FORCING) { // passthrough — don't modify logits return; @@ -155,13 +208,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) { ctx->start_matcher.reset(); ctx->end_matcher.reset(); ctx->force_pos = 0; + ctx->message_pos = 0; } // forward declaration for use in clone static struct llama_sampler * common_reasoning_budget_init_state( const struct llama_vocab * vocab, const std::vector & start_tokens, const std::vector & end_tokens, const std::vector & forced_tokens, - int32_t budget, common_reasoning_budget_state initial_state); + const std::vector & message_tokens, + int32_t budget, int32_t conclusion_budget, common_reasoning_budget_state initial_state); static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; @@ -170,7 +225,9 @@ static struct llama_sampler * common_reasoning_budget_clone(const struct llama_s ctx->start_matcher.tokens, ctx->end_matcher.tokens, ctx->forced_tokens, + ctx->message_tokens, ctx->budget, + ctx->conclusion_budget, ctx->state); } @@ -196,7 +253,9 @@ static struct llama_sampler * common_reasoning_budget_init_state( const std::vector & start_tokens, const std::vector & end_tokens, const std::vector & forced_tokens, + const std::vector & message_tokens, int32_t budget, + int32_t conclusion_budget, common_reasoning_budget_state initial_state) { // promote COUNTING with budget <= 0 to FORCING 
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) { @@ -206,14 +265,17 @@ static struct llama_sampler * common_reasoning_budget_init_state( return llama_sampler_init( /* .iface = */ &common_reasoning_budget_i, /* .ctx = */ new common_reasoning_budget_ctx { - /* .vocab = */ vocab, - /* .start_matcher = */ { start_tokens, 0 }, - /* .end_matcher = */ { end_tokens, 0 }, - /* .forced_tokens = */ forced_tokens, - /* .budget = */ budget, - /* .remaining = */ budget, - /* .state = */ initial_state, - /* .force_pos = */ 0, + /* .vocab = */ vocab, + /* .start_matcher = */ { start_tokens, 0 }, + /* .end_matcher = */ { end_tokens, 0 }, + /* .forced_tokens = */ forced_tokens, + /* .message_tokens = */ message_tokens, + /* .budget = */ budget, + /* .conclusion_budget = */ conclusion_budget, + /* .remaining = */ budget, + /* .state = */ initial_state, + /* .force_pos = */ 0, + /* .message_pos = */ 0, } ); } @@ -223,7 +285,9 @@ struct llama_sampler * common_reasoning_budget_init( const std::vector & start_tokens, const std::vector & end_tokens, const std::vector & forced_tokens, + const std::vector & message_tokens, int32_t budget, + int32_t conclusion_budget, const std::vector & prefill_tokens) { // Determine initial state from prefill: COUNTING if the prefill begins with // the start sequence but does not also contain the end sequence after it. 
@@ -243,7 +307,7 @@ struct llama_sampler * common_reasoning_budget_init( } } } - return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); + return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, message_tokens, budget, conclusion_budget, initial_state); } struct llama_sampler * common_reasoning_budget_init( @@ -251,9 +315,11 @@ struct llama_sampler * common_reasoning_budget_init( const std::vector & start_tokens, const std::vector & end_tokens, const std::vector & forced_tokens, + const std::vector & message_tokens, int32_t budget, + int32_t conclusion_budget, common_reasoning_budget_state initial_state) { - return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state); + return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, message_tokens, budget, conclusion_budget, initial_state); } common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) { @@ -261,4 +327,4 @@ common_reasoning_budget_state common_reasoning_budget_get_state(const struct lla return REASONING_BUDGET_IDLE; } return ((const common_reasoning_budget_ctx *)smpl->ctx)->state; -} +} \ No newline at end of file diff --git a/common/reasoning-budget.h b/common/reasoning-budget.h index ee1a30ed3c..0d88518154 100644 --- a/common/reasoning-budget.h +++ b/common/reasoning-budget.h @@ -8,7 +8,9 @@ enum common_reasoning_budget_state { REASONING_BUDGET_IDLE, // waiting for start sequence REASONING_BUDGET_COUNTING, // counting down tokens - REASONING_BUDGET_FORCING, // forcing budget message + end sequence + REASONING_BUDGET_INJECTING, // forcing message tokens before conclusion phase + REASONING_BUDGET_CONCLUDING, // conclusion phase: model free to conclude naturally + REASONING_BUDGET_FORCING, // forcing end sequence (hard cutoff safety net) REASONING_BUDGET_WAITING_UTF8, // budget exhausted, 
waiting for UTF-8 completion REASONING_BUDGET_DONE, // passthrough forever }; @@ -16,31 +18,36 @@ enum common_reasoning_budget_state { // Creates a reasoning budget sampler that limits token generation inside a // reasoning block (e.g. between and ). // -// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE +// State machine: IDLE -> COUNTING -> INJECTING -> CONCLUDING -> WAITING_UTF8 -> FORCING -> DONE // IDLE: passthrough, watching for start_tokens sequence // COUNTING: counting down remaining tokens, watching for natural end_tokens +// CONCLUDING: conclusion phase after message injection; model generates freely +// until it produces end_tokens naturally or conclusion_budget runs out // WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence // FORCING: forces forced_tokens token-by-token (all other logits -> -inf) // DONE: passthrough forever // // Parameters: -// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr) -// start_tokens - token sequence that activates counting -// end_tokens - token sequence for natural deactivation -// forced_tokens - token sequence forced when budget expires -// budget - max tokens allowed in the reasoning block -// prefill_tokens - tokens already present in the prompt (generation prompt); -// used to determine the initial state: COUNTING if they begin -// with start_tokens (but don't also end with end_tokens), -// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING. 
+// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr) +// start_tokens - token sequence that activates counting +// end_tokens - token sequence for natural deactivation +// forced_tokens - token sequence forced when budget expires (hard-cutoff safety net) +// budget - max tokens allowed in the thinking phase +// conclusion_budget - tokens reserved for conclusion phase (0 = disabled, original behavior) +// prefill_tokens - tokens already present in the prompt (generation prompt); +// used to determine the initial state: COUNTING if they begin +// with start_tokens (but don't also end with end_tokens), +// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING. // struct llama_sampler * common_reasoning_budget_init( const struct llama_vocab * vocab, const std::vector & start_tokens, const std::vector & end_tokens, const std::vector & forced_tokens, + const std::vector & message_tokens, int32_t budget, - const std::vector & prefill_tokens = {}); + int32_t conclusion_budget = 0, + const std::vector & prefill_tokens = {}); // Variant that takes an explicit initial state (used by tests and clone). // COUNTING with budget <= 0 is promoted to FORCING. 
@@ -49,7 +56,9 @@ struct llama_sampler * common_reasoning_budget_init( const std::vector & start_tokens, const std::vector & end_tokens, const std::vector & forced_tokens, + const std::vector & message_tokens, int32_t budget, + int32_t conclusion_budget, common_reasoning_budget_state initial_state); -common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl); +common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl); \ No newline at end of file diff --git a/common/sampling.cpp b/common/sampling.cpp index 5259c5f3c6..fca826dbfc 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -289,12 +289,21 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st // reasoning budget sampler if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) { + // Tokenize the budget message separately so it can be injected before the + // conclusion phase (two-phase graceful termination). The forced_tokens + // sequence (message + end tag) is retained as the hard-cutoff safety net. + std::vector message_tokens; + if (!params.reasoning_budget_message.empty() && params.reasoning_budget_conclusion > 0) { + message_tokens = common_tokenize(vocab, params.reasoning_budget_message, false, true); + } rbudget = common_reasoning_budget_init( vocab, params.reasoning_budget_start, params.reasoning_budget_end, params.reasoning_budget_forced, + message_tokens, params.reasoning_budget_tokens < 0 ? 
INT_MAX : params.reasoning_budget_tokens, + params.reasoning_budget_conclusion, prefill_tokens); } @@ -823,4 +832,4 @@ std::vector common_sampler_types_from_chars(const std::stri } return samplers; -} +} \ No newline at end of file From 02d4c32517dd58c266477567938476653e4b746d Mon Sep 17 00:00:00 2001 From: Zeel Date: Sun, 29 Mar 2026 21:42:31 -0400 Subject: [PATCH 2/2] common: add two-phase graceful reasoning budget termination MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add --reasoning-budget-conclusion N flag that splits the reasoning budget into a thinking phase and a conclusion phase: - At end of thinking budget, inject --reasoning-budget-message and enter INJECTING state (forces message tokens token-by-token) - After message is injected, enter CONCLUDING state giving the model N free tokens to terminate naturally - If model does not self-terminate, fall through to FORCING (hard cutoff) as a safety net New states added to the sampler state machine: IDLE -> COUNTING -> INJECTING -> CONCLUDING -> FORCING -> DONE Setting --reasoning-budget-conclusion 0 (the default) preserves existing behavior exactly — fully backward compatible. Add 5 new tests to test-reasoning-budget.cpp covering: - natural end in conclusion window (no FORCING) - conclusion budget exhausted, safety net fires - no message tokens, conclusion budget only - backward compat with conclusion_budget=0 - multi-token message injection Implements Option B from issue #20632. 
--- common/arg.cpp | 1 - common/common.h | 1 - tests/test-reasoning-budget.cpp | 226 ++++++++++++++------------------ 3 files changed, 95 insertions(+), 133 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 7268947fc3..513a9d51be 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3110,7 +3110,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "is injected, instead of immediately forcing the end-of-thinking tag.", [](common_params & params, int value) { if (value < 0) { throw std::invalid_argument("invalid value for --reasoning-budget-conclusion: must be >= 0"); } - params.reasoning_budget_conclusion = value; params.sampling.reasoning_budget_conclusion = value; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_CONCLUSION")); diff --git a/common/common.h b/common/common.h index 4e6cd926b3..ffad248611 100644 --- a/common/common.h +++ b/common/common.h @@ -595,7 +595,6 @@ struct common_params { common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable int reasoning_budget = -1; - int reasoning_budget_conclusion = 0; // tokens reserved for conclusion phase (0 = disabled) std::string reasoning_budget_message; // message injected before end tag when budget exhausted bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time diff --git a/tests/test-reasoning-budget.cpp b/tests/test-reasoning-budget.cpp index 3028fb4d8f..d6a9a684f6 100644 --- a/tests/test-reasoning-budget.cpp +++ b/tests/test-reasoning-budget.cpp @@ -14,89 +14,69 @@ #include #include -// Reasoning budget sampler test helper -// These tests use nullptr vocab which safely falls back to treating all tokens as complete -// (The UTF-8 boundary detection logic is tested 
separately in test_utf8_boundary_detection) static void test_reasoning_budget( const char * test_name, const std::vector & sequence, const std::vector & start_tokens, const std::vector & end_tokens, const std::vector & forced_tokens, + const std::vector & message_tokens, int32_t budget, + int32_t conclusion_budget, common_reasoning_budget_state initial_state, - size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never) - size_t expected_force_end // token index where forcing should end (after this, no more forcing) + size_t expected_force_start, + size_t expected_force_end ) { - // Find the maximum token ID to ensure our vocab covers all tokens llama_token max_token = 0; - for (auto t : sequence) max_token = std::max(max_token, t); - for (auto t : start_tokens) max_token = std::max(max_token, t); - for (auto t : end_tokens) max_token = std::max(max_token, t); - for (auto t : forced_tokens) max_token = std::max(max_token, t); + for (size_t k = 0; k < sequence.size(); k++) { if (sequence[k] > max_token) max_token = sequence[k]; } + for (size_t k = 0; k < start_tokens.size(); k++) { if (start_tokens[k] > max_token) max_token = start_tokens[k]; } + for (size_t k = 0; k < end_tokens.size(); k++) { if (end_tokens[k] > max_token) max_token = end_tokens[k]; } + for (size_t k = 0; k < forced_tokens.size(); k++) { if (forced_tokens[k] > max_token) max_token = forced_tokens[k]; } + for (size_t k = 0; k < message_tokens.size();k++) { if (message_tokens[k]> max_token) max_token = message_tokens[k];} - // Create a minimal sampler with mock vocabulary - // For this test, we use nullptr as vocab since we're testing state transitions - // The UTF-8 boundary check will treat all tokens as complete (safe fallback) auto * sampler = common_reasoning_budget_init( - nullptr, // vocab - not used for basic state machine tests - start_tokens, - end_tokens, - forced_tokens, - budget, - initial_state + nullptr, + start_tokens, end_tokens, forced_tokens, 
message_tokens, + budget, conclusion_budget, initial_state ); - // Create a test token data array for checking forcing behavior - // Vocab size must be large enough to include all tokens (start, end, forced, sequence) std::vector cur; - const size_t n_vocab = (size_t)max_token + 1; + const size_t n_vocab = (size_t)(max_token + 1); for (size_t i = 0; i < n_vocab; i++) { - cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f}); + llama_token_data d; + d.id = (llama_token)i; d.logit = logf((float)(i+1)); d.p = 0.0f; + cur.push_back(d); } llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; size_t actual_force_start = SIZE_MAX; - size_t actual_force_end = SIZE_MAX; + size_t actual_force_end = SIZE_MAX; - // Feed the sequence and track when forcing occurs for (size_t i = 0; i < sequence.size(); i++) { - // Check if we're in forcing state by applying and seeing if logits are modified cur_p.selected = -1; - for (size_t j = 0; j < cur.size(); j++) { - cur[j].logit = logf((float)(j+1)); // reset logits - } + for (size_t j = 0; j < cur.size(); j++) { cur[j].logit = logf((float)(j+1)); } llama_sampler_apply(sampler, &cur_p); - // Check if forcing is active (all logits except one should be -INFINITY) size_t finite_count = 0; llama_token finite_token = -1; for (size_t j = 0; j < cur.size(); j++) { - if (std::isfinite(cur[j].logit)) { - finite_count++; - finite_token = cur[j].id; - } + if (std::isfinite(cur[j].logit)) { finite_count++; finite_token = cur[j].id; } } llama_sampler_accept(sampler, sequence[i]); - fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token); + fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", + i, (int)sequence[i], finite_count, (int)finite_token); if (finite_count == 1) { - if (actual_force_start == SIZE_MAX) { - actual_force_start = i; - } + if (actual_force_start == SIZE_MAX) { actual_force_start = i; } actual_force_end 
= i; - } else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) { - // Forcing stopped - break; } } llama_sampler_free(sampler); - // Verify forcing occurred at expected positions if (expected_force_start == SIZE_MAX) { if (actual_force_start != SIZE_MAX) { fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start); @@ -112,126 +92,110 @@ static void test_reasoning_budget( GGML_ASSERT(false && "Forcing started at wrong position"); } } - - if (expected_force_end != SIZE_MAX) { - if (actual_force_end < expected_force_end) { - fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end); - GGML_ASSERT(false && "Forcing ended too early"); - } + if (expected_force_end != SIZE_MAX && actual_force_end < expected_force_end) { + fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end); + GGML_ASSERT(false && "Forcing ended too early"); } fprintf(stderr, " Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end); - (void)sequence; } -// UTF-8 boundary detection unit test -// Tests common_utf8_is_complete() from reasoning-budget.h static void test_utf8_boundary_detection() { - // Complete sequences GGML_ASSERT(common_utf8_is_complete("hello")); GGML_ASSERT(common_utf8_is_complete("")); - GGML_ASSERT(common_utf8_is_complete("\xC2\xA0")); // complete 2-byte UTF-8 (U+00A0) - GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C")); // complete 3-byte UTF-8 (left double quote) - GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80")); // complete 4-byte UTF-8 (emoji) - GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9")); // ASCII + complete 2-byte - - // Incomplete sequences - GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1))); // 2-byte start, missing continuation - 
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2))); // 3-byte start + 1 cont, missing 1 - GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1))); // 3-byte start, missing 2 - GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3))); // 4-byte start + 2 cont, missing 1 - GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2))); // 4-byte start + 1 cont, missing 2 - GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1))); // 4-byte start, missing 3 - GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1))); // orphan continuation byte - - // Mixed: ASCII followed by start of multi-byte - GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6))); // ASCII + incomplete 2-byte - GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7))); // ASCII + complete 2-byte + GGML_ASSERT(common_utf8_is_complete("\xC2\xA0")); + GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C")); + GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80")); + GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9")); + GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1))); + GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2))); + GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1))); + GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3))); + GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2))); + GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1))); + GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1))); + GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6))); + GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7))); } int main(void) { - // Reasoning budget sampler tests printf("Testing reasoning budget sampler... 
"); - // Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted) + // Test 1: Natural end before budget exhausted { - const std::vector start = {100}; // start token - const std::vector end = {101}; // end token - const std::vector forced = {102}; // forced token (not used in this test) - const std::vector sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more - - test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced, - 5, // budget of 5 tokens - REASONING_BUDGET_IDLE, - SIZE_MAX, SIZE_MAX); // no forcing expected (natural end) + std::vector start = {100}, end = {101}, forced = {102}, msg = {}; + std::vector seq = {100, 50, 51, 101, 52}; + test_reasoning_budget("natural end before budget exhausted", seq, start, end, forced, msg, 5, 0, REASONING_BUDGET_IDLE, SIZE_MAX, SIZE_MAX); } - // Test 2: Budget exhausted, forcing should occur - // Flow: i=0 apply()->passthrough, accept(100)->COUNTING; i=1 accept(50)->remaining=1 - // i=2 accept(51)->remaining=0->FORCING; i=3 apply() forces token[0]; i=4 apply() forces token[1] - // At i=4, accept() advances force_pos to 2 which equals forced_tokens.size(), so state becomes DONE + // Test 2: Budget exhausted, forcing occurs { - const std::vector start = {100}; - const std::vector end = {101}; - const std::vector forced = {102, 101}; // forced message + end - const std::vector sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2) - - test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced, - 2, // budget of 2 tokens - REASONING_BUDGET_IDLE, - 3, // forcing starts at i=3 (accept at i=2 depletes budget, apply at i=3 forces) - 4); // forcing continues through i=4 (accept at i=4 transitions to DONE) + std::vector start = {100}, end = {101}, forced = {102, 101}, msg = {}; + std::vector seq = {100, 50, 51, 52, 53}; + test_reasoning_budget("budget exhausted forcing", seq, start, end, forced, msg, 2, 0, 
REASONING_BUDGET_IDLE, 3, 4); } - // Test 3: Activate immediately with budget=0, forcing should start right away - // Flow: init promotes COUNTING+budget=0 to FORCING, so apply() sees FORCING at i=0 + // Test 3: Budget=0 forces immediately { - const std::vector start = {100}; - const std::vector end = {101}; - const std::vector forced = {102, 101}; - const std::vector sequence = {100, 50, 51, 52}; // start token first, then 3 tokens - - test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced, - 0, // budget of 0 tokens - REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0 - 0, // forcing starts at i=0 (initialized in FORCING, apply forces immediately) - 1); // forcing continues through i=1 (accept at i=1 transitions to DONE) + std::vector start = {100}, end = {101}, forced = {102, 101}, msg = {}; + std::vector seq = {100, 50, 51, 52}; + test_reasoning_budget("activate immediately budget=0", seq, start, end, forced, msg, 0, 0, REASONING_BUDGET_COUNTING, 0, 1); } - // Test 4: No start/end tokens configured - passthrough (no forcing) + // Test 4: No start/end — passthrough { - const std::vector start = {}; - const std::vector end = {}; - const std::vector forced = {102}; - const std::vector sequence = {50, 51, 52, 53}; - - test_reasoning_budget("no start/end configured", sequence, start, end, forced, - 2, // budget - REASONING_BUDGET_IDLE, - SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured) + std::vector start = {}, end = {}, forced = {102}, msg = {}; + std::vector seq = {50, 51, 52, 53}; + test_reasoning_budget("no start/end configured", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_IDLE, SIZE_MAX, SIZE_MAX); } - // Test 5: Activate immediately with budget > 0, count down then force - // Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING - // Forcing starts at i=2 (apply sees FORCING after accept at i=1 transitioned) + // Test 5: Start in COUNTING state, count down then 
force { - const std::vector start = {100}; - const std::vector end = {101}; - const std::vector forced = {102, 101}; - const std::vector sequence = {50, 51, 52, 53}; - - test_reasoning_budget("activate immediately with budget", sequence, start, end, forced, - 2, // budget of 2 tokens - REASONING_BUDGET_COUNTING, - 2, // forcing starts at i=2 (after 2 accepts deplete budget, apply at i=2 forces) - 3); // forcing continues through i=3 + std::vector start = {100}, end = {101}, forced = {102, 101}, msg = {}; + std::vector seq = {50, 51, 52, 53}; + test_reasoning_budget("activate immediately with budget", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_COUNTING, 2, 3); } - printf("OK (5 tests passed)\n"); + // Test 6: Two-phase — model concludes naturally in conclusion window + { + std::vector start = {100}, end = {101}, forced = {101}, msg = {200}; + std::vector seq = {100, 50, 51, 200, 101, 52}; + test_reasoning_budget("two-phase natural end in conclusion window", seq, start, end, forced, msg, 2, 3, REASONING_BUDGET_IDLE, 3, 3); + } + + // Test 7: Two-phase — conclusion budget exhausted, safety net fires + { + std::vector start = {100}, end = {101}, forced = {101}, msg = {200}; + std::vector seq = {100, 50, 51, 200, 52, 101}; + test_reasoning_budget("two-phase conclusion budget exhausted safety net fires", seq, start, end, forced, msg, 2, 1, REASONING_BUDGET_IDLE, 3, 5); + } + + // Test 8: Two-phase — no message tokens, conclusion only (skips INJECTING) + { + std::vector start = {100}, end = {101}, forced = {101}, msg = {}; + std::vector seq = {100, 50, 51, 101, 52}; + test_reasoning_budget("two-phase no message tokens conclusion only", seq, start, end, forced, msg, 2, 5, REASONING_BUDGET_IDLE, SIZE_MAX, SIZE_MAX); + } + + // Test 9: Backward compat — conclusion_budget=0 identical to original + { + std::vector start = {100}, end = {101}, forced = {102, 101}, msg = {}; + std::vector seq = {100, 50, 51, 52, 53}; + test_reasoning_budget("backward compat 
conclusion_budget=0", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_IDLE, 3, 4); + } + + // Test 10: Two-phase — multi-token message (3 tokens all forced before CONCLUDING) + { + std::vector start = {100}, end = {101}, forced = {101}, msg = {200, 201, 202}; + std::vector seq = {100, 50, 51, 200, 201, 202, 101, 52}; + test_reasoning_budget("two-phase multi-token message injection", seq, start, end, forced, msg, 2, 5, REASONING_BUDGET_IDLE, 3, 5); + } + + printf("OK (10 tests passed)\n"); printf("Testing UTF-8 boundary detection... "); test_utf8_boundary_detection(); printf("OK\n"); return 0; -} +} \ No newline at end of file