common/parser: handle reasoning budget (#20297)

* v1

* Finished!

* Handle CLI

* Reasoning sampler

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Less explosive terminology :)

* Add utf-8 case and tests

* common : migrate reasoning budget sampler to common

* cont : clean up

* cont : expose state and allow passing as initial state

* cont : remove unused imports

* cont : update state machine doc string

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Alde Rojas <hello@alde.dev>
This commit is contained in:
Piotr Wilkin (ilintar) 2026-03-11 10:26:12 +01:00 committed by GitHub
parent 5f91b1d5d5
commit acb7c79069
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 670 additions and 10 deletions

View File

@ -81,6 +81,8 @@ add_library(${TARGET} STATIC
preset.cpp preset.cpp
preset.h preset.h
regex-partial.cpp regex-partial.cpp
reasoning-budget.cpp
reasoning-budget.h
regex-partial.h regex-partial.h
sampling.cpp sampling.cpp
sampling.h sampling.h

View File

@ -2913,6 +2913,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
auto parsed = json::parse(value); auto parsed = json::parse(value);
for (const auto & item : parsed.items()) { for (const auto & item : parsed.items()) {
if (item.key() == "enable_thinking") {
LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. "
"Use --reasoning on / --reasoning off instead.\n");
}
params.default_template_kwargs[item.key()] = item.value().dump(); params.default_template_kwargs[item.key()] = item.value().dump();
} }
} }
@ -3048,14 +3052,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.reasoning_format = common_reasoning_format_from_name(value); params.reasoning_format = common_reasoning_format_from_name(value);
} }
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK")); ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK"));
add_opt(common_arg(
{"-rea", "--reasoning"}, "[on|off|auto]",
"Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))",
[](common_params & params, const std::string & value) {
if (is_truthy(value)) {
params.enable_reasoning = 1;
params.default_template_kwargs["enable_thinking"] = "true";
} else if (is_falsey(value)) {
params.enable_reasoning = 0;
params.default_template_kwargs["enable_thinking"] = "false";
} else if (is_autoy(value)) {
params.enable_reasoning = -1;
} else {
throw std::invalid_argument(
string_format("error: unknown value for --reasoning: '%s'\n", value.c_str()));
}
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING"));
add_opt(common_arg( add_opt(common_arg(
{"--reasoning-budget"}, "N", {"--reasoning-budget"}, "N",
"controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)", "token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
[](common_params & params, int value) { [](common_params & params, int value) {
if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); } if (value < -1) { throw std::invalid_argument("invalid value"); }
params.reasoning_budget = value; params.reasoning_budget = value;
} }
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET")); ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
add_opt(common_arg(
{"--reasoning-budget-message"}, "MESSAGE",
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
[](common_params & params, const std::string & value) {
params.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg( add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE", {"--chat-template"}, "JINJA_TEMPLATE",
string_format( string_format(

View File

@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
if (thinking_forced_open || thinking_forced_closed) { if (thinking_forced_open || thinking_forced_closed) {
// Thinking is forced open OR forced closed with enable_thinking=true // Thinking is forced open OR forced closed with enable_thinking=true
// In both cases, expect only the closing tag (opening was in template) // In both cases, expect only the closing tag (opening was in template)
return p.reasoning(p.until(end)) + end; // However, since we might have incorrectly detected the open/close pattern,
// we admit an optional starting marker
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
} }
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) { if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools) // Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)

View File

@ -857,7 +857,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = true; auto include_grammar = true;
data.supports_thinking = true; data.supports_thinking = true;
data.thinking_start_tag = "[THINK]";
data.thinking_end_tag = "[/THINK]";
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages); data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.preserved_tokens = { data.preserved_tokens = {
@ -1165,9 +1167,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
const autoparser::templates_params & inputs) { const autoparser::templates_params & inputs) {
common_chat_params data; common_chat_params data;
data.prompt = common_chat_template_direct_apply(tmpl, inputs); data.prompt = common_chat_template_direct_apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true; data.supports_thinking = true;
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.preserved_tokens = { data.preserved_tokens = {
"<|tool_calls_section_begin|>", "<|tool_calls_section_begin|>",
"<|tool_calls_section_end|>", "<|tool_calls_section_end|>",
@ -1527,6 +1531,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
autoparser.analyze_template(tmpl); autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = autoparser.reasoning.start;
auto_params.thinking_end_tag = autoparser.reasoning.end;
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
// but forces <think> open when thinking is enabled)
auto_params.thinking_forced_open =
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
}
return auto_params; return auto_params;
} catch (const std::exception & e) { } catch (const std::exception & e) {
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what()); throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());

View File

@ -213,6 +213,8 @@ struct common_chat_params {
bool grammar_lazy = false; bool grammar_lazy = false;
bool thinking_forced_open = false; bool thinking_forced_open = false;
bool supports_thinking = false; bool supports_thinking = false;
std::string thinking_start_tag; // e.g., "<think>"
std::string thinking_end_tag; // e.g., "</think>"
std::vector<common_grammar_trigger> grammar_triggers; std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens; std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops; std::vector<std::string> additional_stops;

View File

@ -235,6 +235,14 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// reasoning budget sampler parameters
// these are populated by the server/CLI based on chat template params
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
bool reasoning_budget_activate_immediately = false;
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
bool backend_sampling = false; bool backend_sampling = false;
bool has_logit_bias() const { bool has_logit_bias() const {
@ -536,7 +544,9 @@ struct common_params {
bool use_jinja = true; // NOLINT bool use_jinja = true; // NOLINT
bool enable_chat_template = true; bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
int reasoning_budget = -1; int reasoning_budget = -1;
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time

219
common/reasoning-budget.cpp Normal file
View File

@ -0,0 +1,219 @@
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"
#include "log.h"
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
// Streaming matcher for a fixed token sequence.
// Feed tokens one at a time via advance(); it reports true exactly when the
// last fed token completes a full occurrence of `tokens`.
struct token_matcher {
    std::vector<llama_token> tokens; // sequence to match (empty = never matches)
    size_t pos = 0;                  // number of tokens of `tokens` currently matched

    // Advance the matcher by one token.
    // Returns true when the full sequence has just been matched; the matcher
    // then resets itself so the next occurrence can be detected.
    bool advance(llama_token token) {
        if (tokens.empty()) {
            return false;
        }
        if (token == tokens[pos]) {
            pos++;
            if (pos >= tokens.size()) {
                pos = 0;
                return true;
            }
            return false;
        }
        // Mismatch: fall back to the longest proper prefix of `tokens` that is
        // also a suffix of what we have consumed so far (tokens[0..pos-1] plus
        // the current token). Restarting only against tokens[0] would miss
        // overlapping matches for sequences with repeated prefixes, e.g.
        // tokens = {A, A, B} against the input A A A B.
        for (size_t k = pos; k > 0; k--) {
            // candidate suffix of length k: tokens[pos-k+1 .. pos-1] + token
            bool ok = token == tokens[k - 1];
            for (size_t j = 0; ok && j + 1 < k; j++) {
                ok = tokens[pos - k + 1 + j] == tokens[j];
            }
            if (ok) {
                pos = k;
                return false;
            }
        }
        pos = 0;
        return false;
    }

    // Forget any partial match in progress.
    void reset() { pos = 0; }
};
// Per-sampler state shared by all reasoning-budget callbacks.
// NOTE: field order must stay in sync with the positional braced initializer
// in common_reasoning_budget_init() below.
struct common_reasoning_budget_ctx {
    const llama_vocab * vocab;         // used to detokenize for UTF-8 boundary checks; may be nullptr
    token_matcher start_matcher;       // detects the start-of-thinking tag (IDLE -> COUNTING)
    token_matcher end_matcher;         // detects a naturally generated end tag (-> DONE)
    std::vector<llama_token> forced_tokens; // sequence forced once the budget runs out
    int32_t budget; // maximum tokens in reasoning block
    int32_t remaining; // tokens remaining in budget
    common_reasoning_budget_state state;
    // for forcing
    size_t force_pos; // next position in forced_tokens to force
};
// Name callback of the llama_sampler_i interface.
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    return "reasoning-budget";
}
// Accept callback: advances the budget state machine for each sampled token.
// All matching and budget accounting happens here; the actual logit
// manipulation is deferred to apply(). The order of checks matters: a natural
// end tag is honored before the budget decrement, so the end tag itself never
// triggers forcing.
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
    switch (ctx->state) {
        case REASONING_BUDGET_IDLE:
            {
                // Wait for the full start-of-thinking tag before counting.
                if (ctx->start_matcher.advance(token)) {
                    ctx->state = REASONING_BUDGET_COUNTING;
                    ctx->remaining = ctx->budget;
                    LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
                    // A non-positive budget leaves no room for thinking tokens:
                    // go straight to forcing the end sequence.
                    if (ctx->remaining <= 0) {
                        ctx->state = REASONING_BUDGET_FORCING;
                        ctx->force_pos = 0;
                        LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
                    }
                }
                break;
            }
        case REASONING_BUDGET_COUNTING:
        case REASONING_BUDGET_WAITING_UTF8:
            {
                // A naturally generated end tag deactivates the sampler for good.
                if (ctx->end_matcher.advance(token)) {
                    ctx->state = REASONING_BUDGET_DONE;
                    LOG_INF("reasoning-budget: deactivated (natural end)\n");
                    break;
                }
                // Without a vocab we cannot detokenize, so conservatively treat
                // every token as ending on a complete UTF-8 boundary.
                bool utf8_complete = true;
                if (ctx->vocab != nullptr) {
                    const std::string piece = common_token_to_piece(ctx->vocab, token, false);
                    utf8_complete = common_utf8_is_complete(piece);
                }
                if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
                    // Budget already exhausted: keep letting tokens through until
                    // the partial multi-byte character is finished.
                    if (utf8_complete) {
                        ctx->state = REASONING_BUDGET_FORCING;
                        ctx->force_pos = 0;
                        ctx->end_matcher.reset();
                        LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
                    }
                } else if (ctx->state == REASONING_BUDGET_COUNTING) {
                    ctx->remaining--;
                    if (ctx->remaining <= 0) {
                        if (utf8_complete) {
                            ctx->state = REASONING_BUDGET_FORCING;
                            ctx->force_pos = 0;
                            ctx->end_matcher.reset();
                            LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
                        } else {
                            // Budget spent mid-character: defer forcing until the
                            // UTF-8 sequence closes (handled in the branch above).
                            ctx->state = REASONING_BUDGET_WAITING_UTF8;
                            ctx->end_matcher.reset();
                            LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
                        }
                    }
                }
                break;
            }
        case REASONING_BUDGET_FORCING:
            // force_pos is advanced in apply(), not here.
            // This ensures the first forced token isn't skipped when the sampler
            // is initialized directly in FORCING state (e.g. COUNTING + budget=0)
            break;
        case REASONING_BUDGET_DONE:
            break;
    }
}
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
if (ctx->state != REASONING_BUDGET_FORCING) {
// passthrough — don't modify logits
return;
}
if (ctx->force_pos >= ctx->forced_tokens.size()) {
return;
}
const llama_token forced = ctx->forced_tokens[ctx->force_pos];
// set all logits to -inf except the forced token
for (size_t i = 0; i < cur_p->size; i++) {
if (cur_p->data[i].id != forced) {
cur_p->data[i].logit = -INFINITY;
}
}
// advance to next forced token (done here rather than in accept so that
// the first forced token isn't skipped when starting in FORCING state)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
}
// Reset callback: return to the initial idle state with a full budget.
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
    ctx->start_matcher.reset();
    ctx->end_matcher.reset();
    ctx->force_pos = 0;
    ctx->remaining = ctx->budget;
    ctx->state     = REASONING_BUDGET_IDLE;
}
// Clone callback: rebuild the sampler from the context's configuration,
// carrying over the current state enum.
// NOTE(review): init() zero-initializes matcher positions, remaining and
// force_pos, so in-flight match/forcing progress is not carried over — confirm
// this is acceptable for the clone use cases.
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
    const auto * src = (const common_reasoning_budget_ctx *) smpl->ctx;
    return common_reasoning_budget_init(src->vocab,
                                        src->start_matcher.tokens,
                                        src->end_matcher.tokens,
                                        src->forced_tokens,
                                        src->budget,
                                        src->state);
}
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
delete (common_reasoning_budget_ctx *) smpl->ctx;
}
// Sampler vtable. Positional initializer — entries must stay in the order
// declared by llama_sampler_i. The backend_* hooks are unused: this sampler
// runs on the CPU sampling path only.
static struct llama_sampler_i common_reasoning_budget_i = {
    /* .name = */ common_reasoning_budget_name,
    /* .accept = */ common_reasoning_budget_accept,
    /* .apply = */ common_reasoning_budget_apply,
    /* .reset = */ common_reasoning_budget_reset,
    /* .clone = */ common_reasoning_budget_clone,
    /* .free = */ common_reasoning_budget_free,
    /* .backend_init = */ nullptr,
    /* .backend_accept = */ nullptr,
    /* .backend_apply = */ nullptr,
    /* .backend_set_input = */ nullptr,
};
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
initial_state = REASONING_BUDGET_FORCING;
}
return llama_sampler_init(
/* .iface = */ &common_reasoning_budget_i,
/* .ctx = */ new common_reasoning_budget_ctx {
/* .vocab = */ vocab,
/* .start_matcher = */ { start_tokens, 0 },
/* .end_matcher = */ { end_tokens, 0 },
/* .forced_tokens = */ forced_tokens,
/* .budget = */ budget,
/* .remaining = */ budget,
/* .state = */ initial_state,
/* .force_pos = */ 0,
}
);
}

41
common/reasoning-budget.h Normal file
View File

@ -0,0 +1,41 @@
#pragma once
#include "llama.h"
#include <cstdint>
#include <vector>
// States of the reasoning-budget sampler. Transitions are driven by the
// accept/apply callbacks; see common_reasoning_budget_init below for the
// full state-machine description.
enum common_reasoning_budget_state {
    REASONING_BUDGET_IDLE, // waiting for start sequence
    REASONING_BUDGET_COUNTING, // counting down tokens
    REASONING_BUDGET_FORCING, // forcing budget message + end sequence
    REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
    REASONING_BUDGET_DONE, // passthrough forever
};
// Creates a reasoning budget sampler that limits token generation inside a
// reasoning block (e.g. between <think> and </think>).
//
// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
// IDLE: passthrough, watching for start_tokens sequence
// COUNTING: counting down remaining tokens, watching for natural end_tokens
// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
// FORCING: forces forced_tokens token-by-token (all other logits -> -inf)
// DONE: passthrough forever
//
// Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
// note: COUNTING with budget <= 0 is promoted to FORCING
//
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state);

View File

@ -2,6 +2,7 @@
#include "common.h" #include "common.h"
#include "log.h" #include "log.h"
#include "reasoning-budget.h"
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
@ -250,6 +251,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
} }
} }
// reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(common_reasoning_budget_init(
vocab,
params.reasoning_budget_start,
params.reasoning_budget_end,
params.reasoning_budget_forced,
params.reasoning_budget_tokens,
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
}
if (params.has_logit_bias()) { if (params.has_logit_bias()) {
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
} }

View File

@ -1,8 +1,10 @@
#include "unicode.h" #include "unicode.h"
#include <algorithm>
#include <cassert> #include <cassert>
#include <stdexcept> #include <stdexcept>
#include <vector>
#include <string> #include <string>
#include <vector>
// implementation adopted from src/unicode.cpp // implementation adopted from src/unicode.cpp
@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off
return utf8_parse_result(utf8_parse_result::INVALID); return utf8_parse_result(utf8_parse_result::INVALID);
} }
// Check whether `s` ends on a complete UTF-8 boundary, i.e. is not cut off in
// the middle of a multi-byte sequence. Scans at most the last 4 bytes for a
// lead byte; a trailing run of 4+ continuation bytes (invalid UTF-8) reports
// incomplete. The empty string counts as complete.
bool common_utf8_is_complete(const std::string & s) {
    if (s.empty()) {
        return true;
    }
    const int window = std::min(4, (int) s.size());
    for (int back = 1; back <= window; back++) {
        const unsigned char byte = s[s.size() - back];
        if ((byte & 0xC0) == 0x80) {
            continue; // continuation byte — keep looking for the lead byte
        }
        // Found the lead byte `back` positions from the end; derive how many
        // bytes its sequence requires from the high bits.
        int seq_len = 1; // ASCII by default
        if (byte >= 0xF0) {
            seq_len = 4;
        } else if (byte >= 0xE0) {
            seq_len = 3;
        } else if (byte >= 0xC0) {
            seq_len = 2;
        }
        return back >= seq_len;
    }
    return false;
}
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) { std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
std::string result; std::string result;
for (size_t i = 0; i < cps.size(); ++i) { for (size_t i = 0; i < cps.size(); ++i) {

View File

@ -20,6 +20,9 @@ struct utf8_parse_result {
// Returns 0 for invalid first bytes // Returns 0 for invalid first bytes
size_t common_utf8_sequence_length(unsigned char first_byte); size_t common_utf8_sequence_length(unsigned char first_byte);
// Check if a string ends with a complete UTF-8 sequence.
bool common_utf8_is_complete(const std::string & s);
// Parse a single UTF-8 codepoint from input // Parse a single UTF-8 codepoint from input
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset); utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);

View File

@ -149,6 +149,7 @@ endif ()
if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
# these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries) # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries)
llama_build_and_test(test-sampling.cpp) llama_build_and_test(test-sampling.cpp)
llama_build_and_test(test-reasoning-budget.cpp)
llama_build_and_test(test-grammar-parser.cpp) llama_build_and_test(test-grammar-parser.cpp)
llama_build_and_test(test-grammar-integration.cpp) llama_build_and_test(test-grammar-integration.cpp)
llama_build_and_test(test-llama-grammar.cpp) llama_build_and_test(test-llama-grammar.cpp)

View File

@ -0,0 +1,238 @@
#include "reasoning-budget.h"
#include "unicode.h"
#include "llama.h"
#include "ggml.h"
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>
// Reasoning budget sampler test helper
// These tests use nullptr vocab which safely falls back to treating all tokens as complete
// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection)
// Reasoning budget sampler test helper.
// Feeds `sequence` token-by-token into a freshly constructed reasoning-budget
// sampler and detects "forcing" by applying the sampler to a fresh candidate
// array after every accept: when exactly one candidate keeps a finite logit,
// the sampler is forcing that token. The observed forcing window is then
// checked against [expected_force_start, expected_force_end].
// These tests use nullptr vocab, which safely falls back to treating all
// tokens as UTF-8 complete (boundary detection is tested separately in
// test_utf8_boundary_detection).
static void test_reasoning_budget(
    const char * test_name,
    const std::vector<llama_token> & sequence,       // tokens fed to the sampler, in order
    const std::vector<llama_token> & start_tokens,   // start-of-thinking tag tokens
    const std::vector<llama_token> & end_tokens,     // end-of-thinking tag tokens
    const std::vector<llama_token> & forced_tokens,  // sequence forced on budget exhaustion
    int32_t budget,
    common_reasoning_budget_state initial_state,
    size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never)
    size_t expected_force_end    // token index where forcing should end (after this, no more forcing)
) {
    // Find the maximum token ID to ensure our candidate array covers all tokens
    llama_token max_token = 0;
    for (auto t : sequence) max_token = std::max(max_token, t);
    for (auto t : start_tokens) max_token = std::max(max_token, t);
    for (auto t : end_tokens) max_token = std::max(max_token, t);
    for (auto t : forced_tokens) max_token = std::max(max_token, t);
    // Create the sampler under test. vocab is nullptr: we only exercise state
    // transitions here, and the UTF-8 check treats all tokens as complete.
    auto * sampler = common_reasoning_budget_init(
        nullptr, // vocab - not used for basic state machine tests
        start_tokens,
        end_tokens,
        forced_tokens,
        budget,
        initial_state
    );
    // Candidate array large enough to include every token used by the test.
    std::vector<llama_token_data> cur;
    const size_t n_vocab = (size_t) max_token + 1;
    for (size_t i = 0; i < n_vocab; i++) {
        cur.emplace_back(llama_token_data{(llama_token) i, logf((float)(i + 1)), 0.0f});
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
    size_t actual_force_start = SIZE_MAX;
    size_t actual_force_end   = SIZE_MAX;
    // Feed the sequence and track when forcing occurs
    for (size_t i = 0; i < sequence.size(); i++) {
        llama_sampler_accept(sampler, sequence[i]);
        // Reset the candidate logits, then apply the sampler.
        cur_p.selected = -1;
        for (size_t j = 0; j < cur.size(); j++) {
            cur[j].logit = logf((float)(j + 1));
        }
        llama_sampler_apply(sampler, &cur_p);
        // Forcing is active iff all logits except exactly one are -INFINITY.
        size_t finite_count = 0;
        llama_token finite_token = -1;
        for (size_t j = 0; j < cur.size(); j++) {
            if (std::isfinite(cur[j].logit)) {
                finite_count++;
                finite_token = cur[j].id;
            }
        }
        fprintf(stderr, "  i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int) sequence[i], finite_count, (int) finite_token);
        if (finite_count == 1) {
            if (actual_force_start == SIZE_MAX) {
                actual_force_start = i;
            }
            actual_force_end = i;
        } else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) {
            // Forcing stopped — no need to feed the rest of the sequence.
            break;
        }
    }
    llama_sampler_free(sampler);
    // Verify forcing occurred at the expected positions.
    if (expected_force_start == SIZE_MAX) {
        if (actual_force_start != SIZE_MAX) {
            fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start);
            GGML_ASSERT(false && "Expected no forcing, but forcing occurred");
        }
    } else {
        if (actual_force_start == SIZE_MAX) {
            fprintf(stderr, "Test '%s' FAILED: Expected forcing but none occurred\n", test_name);
            GGML_ASSERT(false && "Expected forcing but none occurred");
        }
        if (actual_force_start != expected_force_start) {
            fprintf(stderr, "Test '%s' FAILED: Forcing started at %zu, expected %zu\n", test_name, actual_force_start, expected_force_start);
            GGML_ASSERT(false && "Forcing started at wrong position");
        }
    }
    if (expected_force_end != SIZE_MAX) {
        if (actual_force_end < expected_force_end) {
            fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
            GGML_ASSERT(false && "Forcing ended too early");
        }
    }
    fprintf(stderr, "  Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end);
    // note: removed a stale `(void)sequence;` unused-parameter suppression —
    // `sequence` is used throughout this function.
}
// UTF-8 boundary detection unit test
// Tests common_utf8_is_complete() from reasoning-budget.h
// UTF-8 boundary detection unit test.
// Table-driven check of common_utf8_is_complete(): strings ending on a
// complete UTF-8 boundary must report true, strings cut off mid-sequence
// (or consisting of orphan continuation bytes) must report false.
static void test_utf8_boundary_detection() {
    const std::vector<std::string> complete_cases = {
        "hello",                            // pure ASCII
        "",                                 // empty string counts as complete
        "\xC2\xA0",                         // complete 2-byte UTF-8 (U+00A0)
        "\xE2\x80\x9C",                     // complete 3-byte UTF-8 (left double quote)
        "\xF0\x9F\x98\x80",                 // complete 4-byte UTF-8 (emoji)
        "abc\xC3\xA9",                      // ASCII + complete 2-byte
        std::string("hello\xC3\xA9", 7),    // ASCII + complete 2-byte (explicit length)
    };
    const std::vector<std::string> incomplete_cases = {
        std::string("\xC2", 1),             // 2-byte start, missing continuation
        std::string("\xE2\x80", 2),         // 3-byte start + 1 cont, missing 1
        std::string("\xE2", 1),             // 3-byte start, missing 2
        std::string("\xF0\x9F\x98", 3),     // 4-byte start + 2 cont, missing 1
        std::string("\xF0\x9F", 2),         // 4-byte start + 1 cont, missing 2
        std::string("\xF0", 1),             // 4-byte start, missing 3
        std::string("\x80", 1),             // orphan continuation byte
        std::string("hello\xC3", 6),        // ASCII + incomplete 2-byte
    };
    for (const auto & s : complete_cases) {
        GGML_ASSERT(common_utf8_is_complete(s));
    }
    for (const auto & s : incomplete_cases) {
        GGML_ASSERT(!common_utf8_is_complete(s));
    }
}
// Entry point: exercises the reasoning-budget sampler state machine (5 cases)
// and then the UTF-8 boundary helper. The expected force_start/force_end
// indices below depend on accept() consuming the token BEFORE apply() masks
// logits, and on apply() advancing force_pos (see reasoning-budget.cpp).
int main(void) {
    // Reasoning budget sampler tests
    printf("Testing reasoning budget sampler... ");
    // Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted)
    {
        const std::vector<llama_token> start = {100}; // start token
        const std::vector<llama_token> end = {101}; // end token
        const std::vector<llama_token> forced = {102}; // forced token (not used in this test)
        const std::vector<llama_token> sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more
        test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced,
            5, // budget of 5 tokens
            REASONING_BUDGET_IDLE,
            SIZE_MAX, SIZE_MAX); // no forcing expected (natural end)
    }
    // Test 2: Budget exhausted, forcing should occur
    // Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
    // Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
    // At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
    {
        const std::vector<llama_token> start = {100};
        const std::vector<llama_token> end = {101};
        const std::vector<llama_token> forced = {102, 101}; // forced message + end
        const std::vector<llama_token> sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2)
        test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
            2, // budget of 2 tokens
            REASONING_BUDGET_IDLE,
            2, // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
            3); // forcing continues through i=3 (at i=4 state becomes DONE)
    }
    // Test 3: Activate immediately with budget=0, forcing should start right away
    // Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
    // This test needs start token to be in the sequence or use activate_immediately with start token present
    {
        const std::vector<llama_token> start = {100};
        const std::vector<llama_token> end = {101};
        const std::vector<llama_token> forced = {102, 101};
        const std::vector<llama_token> sequence = {100, 50, 51, 52}; // start token first, then 3 tokens
        test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
            0, // budget of 0 tokens
            REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
            0, // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
            1); // forcing continues through i=1 (at i=2 state becomes DONE)
    }
    // Test 4: No start/end tokens configured - passthrough (no forcing)
    // (token_matcher with an empty sequence never matches, so the sampler stays IDLE)
    {
        const std::vector<llama_token> start = {};
        const std::vector<llama_token> end = {};
        const std::vector<llama_token> forced = {102};
        const std::vector<llama_token> sequence = {50, 51, 52, 53};
        test_reasoning_budget("no start/end configured", sequence, start, end, forced,
            2, // budget
            REASONING_BUDGET_IDLE,
            SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured)
    }
    // Test 5: Activate immediately with budget > 0, count down then force
    // Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
    // So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
    {
        const std::vector<llama_token> start = {100};
        const std::vector<llama_token> end = {101};
        const std::vector<llama_token> forced = {102, 101};
        const std::vector<llama_token> sequence = {50, 51, 52, 53};
        test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
            2, // budget of 2 tokens
            REASONING_BUDGET_COUNTING,
            1, // forcing starts at i=1 (after 2 accepts deplete budget)
            2); // forcing continues through i=2
    }
    printf("OK (5 tests passed)\n");
    printf("Testing UTF-8 boundary detection... ");
    test_utf8_boundary_detection();
    printf("OK\n");
    return 0;
}

View File

@ -57,6 +57,8 @@ struct cli_context {
std::vector<raw_buffer> input_files; std::vector<raw_buffer> input_files;
task_params defaults; task_params defaults;
bool verbose_prompt; bool verbose_prompt;
int reasoning_budget = -1;
std::string reasoning_budget_message;
// thread for showing "loading" animation // thread for showing "loading" animation
std::atomic<bool> loading_show; std::atomic<bool> loading_show;
@ -73,6 +75,8 @@ struct cli_context {
// defaults.return_progress = true; // TODO: show progress // defaults.return_progress = true; // TODO: show progress
verbose_prompt = params.verbose_prompt; verbose_prompt = params.verbose_prompt;
reasoning_budget = params.reasoning_budget;
reasoning_budget_message = params.reasoning_budget_message;
} }
std::string generate_completion(result_timings & out_timings) { std::string generate_completion(result_timings & out_timings) {
@ -95,6 +99,24 @@ struct cli_context {
task.params.chat_parser_params.parser.load(chat_params.parser); task.params.chat_parser_params.parser.load(chat_params.parser);
} }
// reasoning budget sampler
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
const llama_vocab * vocab = llama_model_get_vocab(
llama_get_model(ctx_server.get_llama_context()));
task.params.sampling.reasoning_budget_tokens = reasoning_budget;
task.params.sampling.reasoning_budget_activate_immediately = chat_params.thinking_forced_open;
if (!chat_params.thinking_start_tag.empty()) {
task.params.sampling.reasoning_budget_start =
common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
}
task.params.sampling.reasoning_budget_end =
common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
task.params.sampling.reasoning_budget_forced =
common_tokenize(vocab, reasoning_budget_message + chat_params.thinking_end_tag, false, true);
}
rd.post_task({std::move(task)}); rd.post_task({std::move(task)});
} }

View File

@ -1101,6 +1101,22 @@ json oaicompat_chat_params_parse(
llama_params["chat_parser"] = chat_params.parser; llama_params["chat_parser"] = chat_params.parser;
} }
// Reasoning budget: pass parameters through to sampling layer
{
int reasoning_budget = opt.reasoning_budget;
if (reasoning_budget == -1 && body.contains("thinking_budget_tokens")) {
reasoning_budget = json_value(body, "thinking_budget_tokens", -1);
}
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
llama_params["reasoning_budget_tokens"] = reasoning_budget;
llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
llama_params["reasoning_budget_message"] = opt.reasoning_budget_message;
llama_params["reasoning_budget_activate_immediately"] = chat_params.thinking_forced_open;
}
}
// Handle "logprobs" field // Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
if (json_value(body, "logprobs", false)) { if (json_value(body, "logprobs", false)) {

View File

@ -287,6 +287,8 @@ struct server_chat_params {
bool allow_image; bool allow_image;
bool allow_audio; bool allow_audio;
bool enable_thinking = true; bool enable_thinking = true;
int reasoning_budget = -1;
std::string reasoning_budget_message;
std::string media_path; std::string media_path;
}; };

View File

@ -893,9 +893,10 @@ private:
} }
// thinking is enabled if: // thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0) // 1. It's not explicitly disabled via --reasoning off
// 2. The chat template supports it // 2. The chat template supports it
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); const bool template_supports_thinking = params_base.use_jinja && common_chat_templates_support_enable_thinking(chat_templates.get());
const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking;
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking); SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
chat_params = { chat_params = {
@ -907,6 +908,8 @@ private:
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false, /* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking, /* enable_thinking */ enable_thinking,
/* reasoning_budget */ params_base.reasoning_budget,
/* reasoning_budget_msg */ params_base.reasoning_budget_message,
/* media_path */ params_base.media_path, /* media_path */ params_base.media_path,
}; };
} }

View File

@ -462,6 +462,34 @@ task_params server_task::params_from_json_cmpl(
} }
} }
// Parse reasoning budget sampler parameters
{
const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
if (budget >= 0) {
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
const auto message = json_value(data, "reasoning_budget_message", std::string());
const bool activate_imm = json_value(data, "reasoning_budget_activate_immediately", false);
params.sampling.reasoning_budget_tokens = budget;
params.sampling.reasoning_budget_activate_immediately = activate_imm;
if (!start_tag.empty()) {
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
}
if (!end_tag.empty()) {
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
}
SRV_DBG("reasoning budget: tokens=%d, activate_immediately=%s, start=%zu toks, end=%zu toks, forced=%zu toks\n",
budget, activate_imm ? "true" : "false",
params.sampling.reasoning_budget_start.size(),
params.sampling.reasoning_budget_end.size(),
params.sampling.reasoning_budget_forced.size());
}
}
{ {
params.sampling.logit_bias.clear(); params.sampling.logit_bias.clear();