common: add two-phase graceful reasoning budget termination ...

This commit is contained in:
Zeel 2026-03-28 23:00:38 -04:00
parent afe65aa282
commit be5cd55750
5 changed files with 136 additions and 37 deletions

View File

@ -3100,8 +3100,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
[](common_params & params, const std::string & value) {
params.reasoning_budget_message = value;
params.sampling.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
{"--reasoning-budget-conclusion"}, "N",
"token budget for conclusion phase after message injection (0 = disabled, default: 0).\n"
"When set, the model is given N tokens to conclude naturally after the budget message\n"
"is injected, instead of immediately forcing the end-of-thinking tag.",
[](common_params & params, int value) {
if (value < 0) { throw std::invalid_argument("invalid value for --reasoning-budget-conclusion: must be >= 0"); }
params.reasoning_budget_conclusion = value;
params.sampling.reasoning_budget_conclusion = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_CONCLUSION"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
@ -3894,4 +3906,4 @@ void common_params_add_preset_options(std::vector<common_arg> & args) {
// "in server router mode, do not unload this model if models_max is exceeded",
// [](common_params &) { /* unused */ }
// ).set_preset_only());
}
}

View File

@ -283,10 +283,12 @@ struct common_params_sampling {
// reasoning budget sampler parameters
// these are populated by the server/CLI based on chat template params
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
int32_t reasoning_budget_conclusion = 0; // tokens reserved for conclusion phase (0 = disabled)
std::string reasoning_budget_message; // message injected at start of conclusion phase
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (end tag, hard-cutoff safety net)
bool backend_sampling = false;
@ -593,6 +595,7 @@ struct common_params {
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
int reasoning_budget = -1;
int reasoning_budget_conclusion = 0; // tokens reserved for conclusion phase (0 = disabled)
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
@ -995,4 +998,4 @@ inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
// "adamw" or "sgd" (case insensitive)
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);

View File

@ -41,15 +41,18 @@ struct common_reasoning_budget_ctx {
token_matcher start_matcher;
token_matcher end_matcher;
std::vector<llama_token> forced_tokens;
std::vector<llama_token> forced_tokens; // end-of-thinking sequence (hard-cutoff safety net)
std::vector<llama_token> message_tokens; // message injected at conclusion phase start
int32_t budget; // maximum tokens in reasoning block
int32_t remaining; // tokens remaining in budget
int32_t budget; // maximum tokens in thinking phase
int32_t conclusion_budget; // tokens reserved for conclusion phase (0 = disabled)
int32_t remaining; // tokens remaining in current phase
common_reasoning_budget_state state;
// for forcing
size_t force_pos; // next position in forced_tokens to force
size_t force_pos; // next position in forced_tokens to force
size_t message_pos; // next position in message_tokens to force (during message injection)
};
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
@ -76,6 +79,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
break;
}
case REASONING_BUDGET_COUNTING:
case REASONING_BUDGET_CONCLUDING:
case REASONING_BUDGET_WAITING_UTF8:
{
if (ctx->end_matcher.advance(token)) {
@ -97,10 +101,36 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
}
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
} else if (ctx->state == REASONING_BUDGET_CONCLUDING) {
ctx->remaining--;
if (ctx->remaining <= 0) {
if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: conclusion budget exhausted, forcing end sequence\n");
} else {
ctx->state = REASONING_BUDGET_WAITING_UTF8;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: conclusion budget exhausted, waiting for UTF-8 completion\n");
}
}
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
ctx->remaining--;
if (ctx->remaining <= 0) {
if (ctx->conclusion_budget > 0 && !ctx->message_tokens.empty()) {
// Two-phase: force message tokens first, then let model conclude
ctx->state = REASONING_BUDGET_INJECTING;
ctx->message_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: thinking budget exhausted, injecting conclusion message\n");
} else if (ctx->conclusion_budget > 0) {
// No message, but conclusion budget set — go straight to free conclusion
ctx->state = REASONING_BUDGET_CONCLUDING;
ctx->remaining = ctx->conclusion_budget;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: thinking budget exhausted, entering conclusion phase (%d tokens)\n", ctx->conclusion_budget);
} else if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
@ -114,6 +144,16 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
}
break;
}
case REASONING_BUDGET_INJECTING:
ctx->message_pos++;
if (ctx->message_pos >= ctx->message_tokens.size()) {
// Message fully injected — enter free conclusion phase
ctx->state = REASONING_BUDGET_CONCLUDING;
ctx->remaining = ctx->conclusion_budget;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: message injected, entering conclusion phase (%d tokens)\n", ctx->conclusion_budget);
}
break;
case REASONING_BUDGET_FORCING:
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
@ -129,6 +169,19 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
if (ctx->state == REASONING_BUDGET_INJECTING) {
if (ctx->message_pos >= ctx->message_tokens.size()) {
return;
}
const llama_token forced = ctx->message_tokens[ctx->message_pos];
for (size_t i = 0; i < cur_p->size; i++) {
if (cur_p->data[i].id != forced) {
cur_p->data[i].logit = -INFINITY;
}
}
return;
}
if (ctx->state != REASONING_BUDGET_FORCING) {
// passthrough — don't modify logits
return;
@ -155,13 +208,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
ctx->start_matcher.reset();
ctx->end_matcher.reset();
ctx->force_pos = 0;
ctx->message_pos = 0;
}
// forward declaration for use in clone
static struct llama_sampler * common_reasoning_budget_init_state(
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
int32_t budget, common_reasoning_budget_state initial_state);
const std::vector<llama_token> & message_tokens,
int32_t budget, int32_t conclusion_budget, common_reasoning_budget_state initial_state);
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
@ -170,7 +225,9 @@ static struct llama_sampler * common_reasoning_budget_clone(const struct llama_s
ctx->start_matcher.tokens,
ctx->end_matcher.tokens,
ctx->forced_tokens,
ctx->message_tokens,
ctx->budget,
ctx->conclusion_budget,
ctx->state);
}
@ -196,7 +253,9 @@ static struct llama_sampler * common_reasoning_budget_init_state(
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::vector<llama_token> & message_tokens,
int32_t budget,
int32_t conclusion_budget,
common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
@ -206,14 +265,17 @@ static struct llama_sampler * common_reasoning_budget_init_state(
return llama_sampler_init(
/* .iface = */ &common_reasoning_budget_i,
/* .ctx = */ new common_reasoning_budget_ctx {
/* .vocab = */ vocab,
/* .start_matcher = */ { start_tokens, 0 },
/* .end_matcher = */ { end_tokens, 0 },
/* .forced_tokens = */ forced_tokens,
/* .budget = */ budget,
/* .remaining = */ budget,
/* .state = */ initial_state,
/* .force_pos = */ 0,
/* .vocab = */ vocab,
/* .start_matcher = */ { start_tokens, 0 },
/* .end_matcher = */ { end_tokens, 0 },
/* .forced_tokens = */ forced_tokens,
/* .message_tokens = */ message_tokens,
/* .budget = */ budget,
/* .conclusion_budget = */ conclusion_budget,
/* .remaining = */ budget,
/* .state = */ initial_state,
/* .force_pos = */ 0,
/* .message_pos = */ 0,
}
);
}
@ -223,7 +285,9 @@ struct llama_sampler * common_reasoning_budget_init(
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::vector<llama_token> & message_tokens,
int32_t budget,
int32_t conclusion_budget,
const std::vector<llama_token> & prefill_tokens) {
// Determine initial state from prefill: COUNTING if the prefill begins with
// the start sequence but does not also contain the end sequence after it.
@ -243,7 +307,7 @@ struct llama_sampler * common_reasoning_budget_init(
}
}
}
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, message_tokens, budget, conclusion_budget, initial_state);
}
struct llama_sampler * common_reasoning_budget_init(
@ -251,9 +315,11 @@ struct llama_sampler * common_reasoning_budget_init(
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::vector<llama_token> & message_tokens,
int32_t budget,
int32_t conclusion_budget,
common_reasoning_budget_state initial_state) {
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, message_tokens, budget, conclusion_budget, initial_state);
}
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
@ -261,4 +327,4 @@ common_reasoning_budget_state common_reasoning_budget_get_state(const struct lla
return REASONING_BUDGET_IDLE;
}
return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
}
}

View File

@ -8,7 +8,9 @@
enum common_reasoning_budget_state {
REASONING_BUDGET_IDLE, // waiting for start sequence
REASONING_BUDGET_COUNTING, // counting down tokens
REASONING_BUDGET_FORCING, // forcing budget message + end sequence
REASONING_BUDGET_INJECTING, // forcing message tokens before conclusion phase
REASONING_BUDGET_CONCLUDING, // conclusion phase: model free to conclude naturally
REASONING_BUDGET_FORCING, // forcing end sequence (hard cutoff safety net)
REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
REASONING_BUDGET_DONE, // passthrough forever
};
@ -16,31 +18,36 @@ enum common_reasoning_budget_state {
// Creates a reasoning budget sampler that limits token generation inside a
// reasoning block (e.g. between <think> and </think>).
//
// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
// State machine: IDLE -> COUNTING -> CONCLUDING -> WAITING_UTF8 -> FORCING -> DONE
// IDLE: passthrough, watching for start_tokens sequence
// COUNTING: counting down remaining tokens, watching for natural end_tokens
// CONCLUDING: conclusion phase after message injection; model generates freely
// until it produces end_tokens naturally or conclusion_budget runs out
// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
// FORCING: forces forced_tokens token-by-token (all other logits -> -inf)
// DONE: passthrough forever
//
// Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// prefill_tokens - tokens already present in the prompt (generation prompt);
// used to determine the initial state: COUNTING if they begin
// with start_tokens (but don't also end with end_tokens),
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires (hard-cutoff safety net)
// budget - max tokens allowed in the thinking phase
// conclusion_budget - tokens reserved for conclusion phase (0 = disabled, original behavior)
// prefill_tokens - tokens already present in the prompt (generation prompt);
// used to determine the initial state: COUNTING if they begin
// with start_tokens (but don't also end with end_tokens),
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
//
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::vector<llama_token> & message_tokens,
int32_t budget,
const std::vector<llama_token> & prefill_tokens = {});
int32_t conclusion_budget = 0,
const std::vector<llama_token> & prefill_tokens = {});
// Variant that takes an explicit initial state (used by tests and clone).
// COUNTING with budget <= 0 is promoted to FORCING.
@ -49,7 +56,9 @@ struct llama_sampler * common_reasoning_budget_init(
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::vector<llama_token> & message_tokens,
int32_t budget,
int32_t conclusion_budget,
common_reasoning_budget_state initial_state);
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);

View File

@ -289,12 +289,21 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
// reasoning budget sampler
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
// Tokenize the budget message separately so it can be injected before the
// conclusion phase (two-phase graceful termination). The forced_tokens
// sequence (message + end tag) is retained as the hard-cutoff safety net.
std::vector<llama_token> message_tokens;
if (!params.reasoning_budget_message.empty() && params.reasoning_budget_conclusion > 0) {
message_tokens = common_tokenize(vocab, params.reasoning_budget_message, false, true);
}
rbudget = common_reasoning_budget_init(
vocab,
params.reasoning_budget_start,
params.reasoning_budget_end,
params.reasoning_budget_forced,
message_tokens,
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
params.reasoning_budget_conclusion,
prefill_tokens);
}
@ -823,4 +832,4 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
}
return samplers;
}
}