common: add two-phase graceful reasoning budget termination

Add --reasoning-budget-conclusion N flag that splits the reasoning budget
into a thinking phase and a conclusion phase:

- At end of thinking budget, inject --reasoning-budget-message and enter
  INJECTING state (forces message tokens token-by-token)
- After message is injected, enter CONCLUDING state giving the model N
  free tokens to terminate naturally
- If model does not self-terminate, fall through to FORCING (hard cutoff)
  as a safety net

New states added to the sampler state machine:
  IDLE -> COUNTING -> INJECTING -> CONCLUDING -> FORCING -> DONE

Setting --reasoning-budget-conclusion 0 (the default) preserves existing
behavior exactly — fully backward compatible.

Add 5 new tests to test-reasoning-budget.cpp covering:
- natural end in conclusion window (no FORCING)
- conclusion budget exhausted, safety net fires
- no message tokens, conclusion budget only
- backward compat with conclusion_budget=0
- multi-token message injection

Implements Option B from issue #20632.
This commit is contained in:
Zeel 2026-03-29 21:42:31 -04:00
parent be5cd55750
commit 02d4c32517
3 changed files with 95 additions and 133 deletions

View File

@ -3110,7 +3110,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"is injected, instead of immediately forcing the end-of-thinking tag.",
[](common_params & params, int value) {
if (value < 0) { throw std::invalid_argument("invalid value for --reasoning-budget-conclusion: must be >= 0"); }
params.reasoning_budget_conclusion = value;
params.sampling.reasoning_budget_conclusion = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_CONCLUSION"));

View File

@ -595,7 +595,6 @@ struct common_params {
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
int reasoning_budget = -1;
int reasoning_budget_conclusion = 0; // tokens reserved for conclusion phase (0 = disabled)
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time

View File

@ -14,89 +14,69 @@
#include <string>
#include <vector>
// Reasoning budget sampler test helper
// These tests use nullptr vocab which safely falls back to treating all tokens as complete
// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection)
static void test_reasoning_budget(
const char * test_name,
const std::vector<llama_token> & sequence,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::vector<llama_token> & message_tokens,
int32_t budget,
int32_t conclusion_budget,
common_reasoning_budget_state initial_state,
size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never)
size_t expected_force_end // token index where forcing should end (after this, no more forcing)
size_t expected_force_start,
size_t expected_force_end
) {
// Find the maximum token ID to ensure our vocab covers all tokens
llama_token max_token = 0;
for (auto t : sequence) max_token = std::max(max_token, t);
for (auto t : start_tokens) max_token = std::max(max_token, t);
for (auto t : end_tokens) max_token = std::max(max_token, t);
for (auto t : forced_tokens) max_token = std::max(max_token, t);
for (size_t k = 0; k < sequence.size(); k++) { if (sequence[k] > max_token) max_token = sequence[k]; }
for (size_t k = 0; k < start_tokens.size(); k++) { if (start_tokens[k] > max_token) max_token = start_tokens[k]; }
for (size_t k = 0; k < end_tokens.size(); k++) { if (end_tokens[k] > max_token) max_token = end_tokens[k]; }
for (size_t k = 0; k < forced_tokens.size(); k++) { if (forced_tokens[k] > max_token) max_token = forced_tokens[k]; }
for (size_t k = 0; k < message_tokens.size();k++) { if (message_tokens[k]> max_token) max_token = message_tokens[k];}
// Create a minimal sampler with mock vocabulary
// For this test, we use nullptr as vocab since we're testing state transitions
// The UTF-8 boundary check will treat all tokens as complete (safe fallback)
auto * sampler = common_reasoning_budget_init(
nullptr, // vocab - not used for basic state machine tests
start_tokens,
end_tokens,
forced_tokens,
budget,
initial_state
nullptr,
start_tokens, end_tokens, forced_tokens, message_tokens,
budget, conclusion_budget, initial_state
);
// Create a test token data array for checking forcing behavior
// Vocab size must be large enough to include all tokens (start, end, forced, sequence)
std::vector<llama_token_data> cur;
const size_t n_vocab = (size_t)max_token + 1;
const size_t n_vocab = (size_t)(max_token + 1);
for (size_t i = 0; i < n_vocab; i++) {
cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f});
llama_token_data d;
d.id = (llama_token)i; d.logit = logf((float)(i+1)); d.p = 0.0f;
cur.push_back(d);
}
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
size_t actual_force_start = SIZE_MAX;
size_t actual_force_end = SIZE_MAX;
size_t actual_force_end = SIZE_MAX;
// Feed the sequence and track when forcing occurs
for (size_t i = 0; i < sequence.size(); i++) {
// Check if we're in forcing state by applying and seeing if logits are modified
cur_p.selected = -1;
for (size_t j = 0; j < cur.size(); j++) {
cur[j].logit = logf((float)(j+1)); // reset logits
}
for (size_t j = 0; j < cur.size(); j++) { cur[j].logit = logf((float)(j+1)); }
llama_sampler_apply(sampler, &cur_p);
// Check if forcing is active (all logits except one should be -INFINITY)
size_t finite_count = 0;
llama_token finite_token = -1;
for (size_t j = 0; j < cur.size(); j++) {
if (std::isfinite(cur[j].logit)) {
finite_count++;
finite_token = cur[j].id;
}
if (std::isfinite(cur[j].logit)) { finite_count++; finite_token = cur[j].id; }
}
llama_sampler_accept(sampler, sequence[i]);
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n",
i, (int)sequence[i], finite_count, (int)finite_token);
if (finite_count == 1) {
if (actual_force_start == SIZE_MAX) {
actual_force_start = i;
}
if (actual_force_start == SIZE_MAX) { actual_force_start = i; }
actual_force_end = i;
} else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) {
// Forcing stopped
break;
}
}
llama_sampler_free(sampler);
// Verify forcing occurred at expected positions
if (expected_force_start == SIZE_MAX) {
if (actual_force_start != SIZE_MAX) {
fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start);
@ -112,126 +92,110 @@ static void test_reasoning_budget(
GGML_ASSERT(false && "Forcing started at wrong position");
}
}
if (expected_force_end != SIZE_MAX) {
if (actual_force_end < expected_force_end) {
fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
GGML_ASSERT(false && "Forcing ended too early");
}
if (expected_force_end != SIZE_MAX && actual_force_end < expected_force_end) {
fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
GGML_ASSERT(false && "Forcing ended too early");
}
fprintf(stderr, " Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end);
(void)sequence;
}
// UTF-8 boundary detection unit test
// Tests common_utf8_is_complete() from reasoning-budget.h
static void test_utf8_boundary_detection() {
// Complete sequences
GGML_ASSERT(common_utf8_is_complete("hello"));
GGML_ASSERT(common_utf8_is_complete(""));
GGML_ASSERT(common_utf8_is_complete("\xC2\xA0")); // complete 2-byte UTF-8 (U+00A0)
GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C")); // complete 3-byte UTF-8 (left double quote)
GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80")); // complete 4-byte UTF-8 (emoji)
GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9")); // ASCII + complete 2-byte
// Incomplete sequences
GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1))); // 2-byte start, missing continuation
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2))); // 3-byte start + 1 cont, missing 1
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1))); // 3-byte start, missing 2
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3))); // 4-byte start + 2 cont, missing 1
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2))); // 4-byte start + 1 cont, missing 2
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1))); // 4-byte start, missing 3
GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1))); // orphan continuation byte
// Mixed: ASCII followed by start of multi-byte
GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6))); // ASCII + incomplete 2-byte
GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7))); // ASCII + complete 2-byte
GGML_ASSERT(common_utf8_is_complete("\xC2\xA0"));
GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C"));
GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80"));
GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9"));
GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1)));
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2)));
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1)));
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3)));
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2)));
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1)));
GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1)));
GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6)));
GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7)));
}
int main(void) {
// Reasoning budget sampler tests
printf("Testing reasoning budget sampler... ");
// Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted)
// Test 1: Natural end before budget exhausted
{
const std::vector<llama_token> start = {100}; // start token
const std::vector<llama_token> end = {101}; // end token
const std::vector<llama_token> forced = {102}; // forced token (not used in this test)
const std::vector<llama_token> sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more
test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced,
5, // budget of 5 tokens
REASONING_BUDGET_IDLE,
SIZE_MAX, SIZE_MAX); // no forcing expected (natural end)
std::vector<llama_token> start = {100}, end = {101}, forced = {102}, msg = {};
std::vector<llama_token> seq = {100, 50, 51, 101, 52};
test_reasoning_budget("natural end before budget exhausted", seq, start, end, forced, msg, 5, 0, REASONING_BUDGET_IDLE, SIZE_MAX, SIZE_MAX);
}
// Test 2: Budget exhausted, forcing should occur
// Flow: i=0 apply()->passthrough, accept(100)->COUNTING; i=1 accept(50)->remaining=1
// i=2 accept(51)->remaining=0->FORCING; i=3 apply() forces token[0]; i=4 apply() forces token[1]
// At i=4, accept() advances force_pos to 2 which equals forced_tokens.size(), so state becomes DONE
// Test 2: Budget exhausted, forcing occurs
{
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101}; // forced message + end
const std::vector<llama_token> sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2)
test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
2, // budget of 2 tokens
REASONING_BUDGET_IDLE,
3, // forcing starts at i=3 (accept at i=2 depletes budget, apply at i=3 forces)
4); // forcing continues through i=4 (accept at i=4 transitions to DONE)
std::vector<llama_token> start = {100}, end = {101}, forced = {102, 101}, msg = {};
std::vector<llama_token> seq = {100, 50, 51, 52, 53};
test_reasoning_budget("budget exhausted forcing", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_IDLE, 3, 4);
}
// Test 3: Activate immediately with budget=0, forcing should start right away
// Flow: init promotes COUNTING+budget=0 to FORCING, so apply() sees FORCING at i=0
// Test 3: Budget=0 forces immediately
{
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};
const std::vector<llama_token> sequence = {100, 50, 51, 52}; // start token first, then 3 tokens
test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
0, // budget of 0 tokens
REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
0, // forcing starts at i=0 (initialized in FORCING, apply forces immediately)
1); // forcing continues through i=1 (accept at i=1 transitions to DONE)
std::vector<llama_token> start = {100}, end = {101}, forced = {102, 101}, msg = {};
std::vector<llama_token> seq = {100, 50, 51, 52};
test_reasoning_budget("activate immediately budget=0", seq, start, end, forced, msg, 0, 0, REASONING_BUDGET_COUNTING, 0, 1);
}
// Test 4: No start/end tokens configured - passthrough (no forcing)
// Test 4: No start/end — passthrough
{
const std::vector<llama_token> start = {};
const std::vector<llama_token> end = {};
const std::vector<llama_token> forced = {102};
const std::vector<llama_token> sequence = {50, 51, 52, 53};
test_reasoning_budget("no start/end configured", sequence, start, end, forced,
2, // budget
REASONING_BUDGET_IDLE,
SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured)
std::vector<llama_token> start = {}, end = {}, forced = {102}, msg = {};
std::vector<llama_token> seq = {50, 51, 52, 53};
test_reasoning_budget("no start/end configured", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_IDLE, SIZE_MAX, SIZE_MAX);
}
// Test 5: Activate immediately with budget > 0, count down then force
// Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
// Forcing starts at i=2 (apply sees FORCING after accept at i=1 transitioned)
// Test 5: Start in COUNTING state, count down then force
{
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};
const std::vector<llama_token> sequence = {50, 51, 52, 53};
test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
2, // budget of 2 tokens
REASONING_BUDGET_COUNTING,
2, // forcing starts at i=2 (after 2 accepts deplete budget, apply at i=2 forces)
3); // forcing continues through i=3
std::vector<llama_token> start = {100}, end = {101}, forced = {102, 101}, msg = {};
std::vector<llama_token> seq = {50, 51, 52, 53};
test_reasoning_budget("activate immediately with budget", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_COUNTING, 2, 3);
}
printf("OK (5 tests passed)\n");
// Test 6: Two-phase — model concludes naturally in conclusion window
{
std::vector<llama_token> start = {100}, end = {101}, forced = {101}, msg = {200};
std::vector<llama_token> seq = {100, 50, 51, 200, 101, 52};
test_reasoning_budget("two-phase natural end in conclusion window", seq, start, end, forced, msg, 2, 3, REASONING_BUDGET_IDLE, 3, 3);
}
// Test 7: Two-phase — conclusion budget exhausted, safety net fires
{
std::vector<llama_token> start = {100}, end = {101}, forced = {101}, msg = {200};
std::vector<llama_token> seq = {100, 50, 51, 200, 52, 101};
test_reasoning_budget("two-phase conclusion budget exhausted safety net fires", seq, start, end, forced, msg, 2, 1, REASONING_BUDGET_IDLE, 3, 5);
}
// Test 8: Two-phase — no message tokens, conclusion only (skips INJECTING)
{
std::vector<llama_token> start = {100}, end = {101}, forced = {101}, msg = {};
std::vector<llama_token> seq = {100, 50, 51, 101, 52};
test_reasoning_budget("two-phase no message tokens conclusion only", seq, start, end, forced, msg, 2, 5, REASONING_BUDGET_IDLE, SIZE_MAX, SIZE_MAX);
}
// Test 9: Backward compat — conclusion_budget=0 identical to original
{
std::vector<llama_token> start = {100}, end = {101}, forced = {102, 101}, msg = {};
std::vector<llama_token> seq = {100, 50, 51, 52, 53};
test_reasoning_budget("backward compat conclusion_budget=0", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_IDLE, 3, 4);
}
// Test 10: Two-phase — multi-token message (3 tokens all forced before CONCLUDING)
{
std::vector<llama_token> start = {100}, end = {101}, forced = {101}, msg = {200, 201, 202};
std::vector<llama_token> seq = {100, 50, 51, 200, 201, 202, 101, 52};
test_reasoning_budget("two-phase multi-token message injection", seq, start, end, forced, msg, 2, 5, REASONING_BUDGET_IDLE, 3, 5);
}
printf("OK (10 tests passed)\n");
printf("Testing UTF-8 boundary detection... ");
test_utf8_boundary_detection();
printf("OK\n");
return 0;
}
}