common/parser: handle reasoning budget (#20297)
* v1 * Finished! * Handlie cli * Reasoning sampler * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Less explosive terminology :) * Add utf-8 case and tests * common : migrate reasoning budget sampler to common * cont : clean up * cont : expose state and allow passing as initial state * cont : remove unused imports * cont : update state machine doc string --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: Alde Rojas <hello@alde.dev>
This commit is contained in:
parent
5f91b1d5d5
commit
acb7c79069
|
|
@ -81,6 +81,8 @@ add_library(${TARGET} STATIC
|
|||
preset.cpp
|
||||
preset.h
|
||||
regex-partial.cpp
|
||||
reasoning-budget.cpp
|
||||
reasoning-budget.h
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
|
|
|
|||
|
|
@ -2913,6 +2913,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params, const std::string & value) {
|
||||
auto parsed = json::parse(value);
|
||||
for (const auto & item : parsed.items()) {
|
||||
if (item.key() == "enable_thinking") {
|
||||
LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. "
|
||||
"Use --reasoning on / --reasoning off instead.\n");
|
||||
}
|
||||
params.default_template_kwargs[item.key()] = item.value().dump();
|
||||
}
|
||||
}
|
||||
|
|
@ -3048,14 +3052,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.reasoning_format = common_reasoning_format_from_name(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK"));
|
||||
add_opt(common_arg(
|
||||
{"-rea", "--reasoning"}, "[on|off|auto]",
|
||||
"Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))",
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
params.enable_reasoning = 1;
|
||||
params.default_template_kwargs["enable_thinking"] = "true";
|
||||
} else if (is_falsey(value)) {
|
||||
params.enable_reasoning = 0;
|
||||
params.default_template_kwargs["enable_thinking"] = "false";
|
||||
} else if (is_autoy(value)) {
|
||||
params.enable_reasoning = -1;
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unknown value for --reasoning: '%s'\n", value.c_str()));
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-budget"}, "N",
|
||||
"controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
|
||||
"token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
|
||||
[](common_params & params, int value) {
|
||||
if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
|
||||
if (value < -1) { throw std::invalid_argument("invalid value"); }
|
||||
params.reasoning_budget = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-budget-message"}, "MESSAGE",
|
||||
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.reasoning_budget_message = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
|
|
|
|||
|
|
@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
|
|||
if (thinking_forced_open || thinking_forced_closed) {
|
||||
// Thinking is forced open OR forced closed with enable_thinking=true
|
||||
// In both cases, expect only the closing tag (opening was in template)
|
||||
return p.reasoning(p.until(end)) + end;
|
||||
// However, since we might have incorrectly detected the open/close pattern,
|
||||
// we admit an optional starting marker
|
||||
return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
|
||||
}
|
||||
if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
|
||||
// Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
|
||||
|
|
|
|||
|
|
@ -857,7 +857,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
|
||||
data.supports_thinking = true;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "[THINK]";
|
||||
data.thinking_end_tag = "[/THINK]";
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
|
|
@ -1165,9 +1167,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
const autoparser::templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
data.preserved_tokens = {
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_calls_section_end|>",
|
||||
|
|
@ -1527,6 +1531,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
auto_params.thinking_start_tag = autoparser.reasoning.start;
|
||||
auto_params.thinking_end_tag = autoparser.reasoning.end;
|
||||
// FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
|
||||
// (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
|
||||
// but forces <think> open when thinking is enabled)
|
||||
auto_params.thinking_forced_open =
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
|
||||
autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
|
||||
}
|
||||
return auto_params;
|
||||
} catch (const std::exception & e) {
|
||||
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
|
||||
|
|
|
|||
|
|
@ -213,6 +213,8 @@ struct common_chat_params {
|
|||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool supports_thinking = false;
|
||||
std::string thinking_start_tag; // e.g., "<think>"
|
||||
std::string thinking_end_tag; // e.g., "</think>"
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
|
|
|
|||
|
|
@ -235,6 +235,14 @@ struct common_params_sampling {
|
|||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||
|
||||
// reasoning budget sampler parameters
|
||||
// these are populated by the server/CLI based on chat template params
|
||||
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
|
||||
bool reasoning_budget_activate_immediately = false;
|
||||
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
|
||||
|
||||
bool backend_sampling = false;
|
||||
|
||||
bool has_logit_bias() const {
|
||||
|
|
@ -536,7 +544,9 @@ struct common_params {
|
|||
bool use_jinja = true; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
|
||||
int reasoning_budget = -1;
|
||||
std::string reasoning_budget_message; // message injected before end tag when budget exhausted
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,219 @@
|
|||
#include "reasoning-budget.h"
|
||||
#include "common.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
struct token_matcher {
|
||||
std::vector<llama_token> tokens;
|
||||
size_t pos = 0;
|
||||
|
||||
bool advance(llama_token token) {
|
||||
if (tokens.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (token == tokens[pos]) {
|
||||
pos++;
|
||||
if (pos >= tokens.size()) {
|
||||
pos = 0;
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
pos = 0;
|
||||
if (token == tokens[0]) {
|
||||
pos = 1;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void reset() { pos = 0; }
|
||||
};
|
||||
|
||||
struct common_reasoning_budget_ctx {
|
||||
const llama_vocab * vocab;
|
||||
|
||||
token_matcher start_matcher;
|
||||
token_matcher end_matcher;
|
||||
std::vector<llama_token> forced_tokens;
|
||||
|
||||
int32_t budget; // maximum tokens in reasoning block
|
||||
int32_t remaining; // tokens remaining in budget
|
||||
|
||||
common_reasoning_budget_state state;
|
||||
|
||||
// for forcing
|
||||
size_t force_pos; // next position in forced_tokens to force
|
||||
};
|
||||
|
||||
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
|
||||
return "reasoning-budget";
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
|
||||
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
|
||||
switch (ctx->state) {
|
||||
case REASONING_BUDGET_IDLE:
|
||||
{
|
||||
if (ctx->start_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case REASONING_BUDGET_COUNTING:
|
||||
case REASONING_BUDGET_WAITING_UTF8:
|
||||
{
|
||||
if (ctx->end_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: deactivated (natural end)\n");
|
||||
break;
|
||||
}
|
||||
|
||||
bool utf8_complete = true;
|
||||
if (ctx->vocab != nullptr) {
|
||||
const std::string piece = common_token_to_piece(ctx->vocab, token, false);
|
||||
utf8_complete = common_utf8_is_complete(piece);
|
||||
}
|
||||
|
||||
if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
|
||||
if (utf8_complete) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
|
||||
}
|
||||
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
|
||||
ctx->remaining--;
|
||||
if (ctx->remaining <= 0) {
|
||||
if (utf8_complete) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
|
||||
} else {
|
||||
ctx->state = REASONING_BUDGET_WAITING_UTF8;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case REASONING_BUDGET_FORCING:
|
||||
// force_pos is advanced in apply(), not here.
|
||||
// This ensures the first forced token isn't skipped when the sampler
|
||||
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
|
||||
break;
|
||||
case REASONING_BUDGET_DONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
|
||||
if (ctx->state != REASONING_BUDGET_FORCING) {
|
||||
// passthrough — don't modify logits
|
||||
return;
|
||||
}
|
||||
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_token forced = ctx->forced_tokens[ctx->force_pos];
|
||||
|
||||
// set all logits to -inf except the forced token
|
||||
for (size_t i = 0; i < cur_p->size; i++) {
|
||||
if (cur_p->data[i].id != forced) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
// advance to next forced token (done here rather than in accept so that
|
||||
// the first forced token isn't skipped when starting in FORCING state)
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
||||
}
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
ctx->state = REASONING_BUDGET_IDLE;
|
||||
ctx->remaining = ctx->budget;
|
||||
ctx->start_matcher.reset();
|
||||
ctx->end_matcher.reset();
|
||||
ctx->force_pos = 0;
|
||||
}
|
||||
|
||||
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
|
||||
return common_reasoning_budget_init(
|
||||
ctx->vocab,
|
||||
ctx->start_matcher.tokens,
|
||||
ctx->end_matcher.tokens,
|
||||
ctx->forced_tokens,
|
||||
ctx->budget,
|
||||
ctx->state);
|
||||
}
|
||||
|
||||
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
|
||||
delete (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler_i common_reasoning_budget_i = {
|
||||
/* .name = */ common_reasoning_budget_name,
|
||||
/* .accept = */ common_reasoning_budget_accept,
|
||||
/* .apply = */ common_reasoning_budget_apply,
|
||||
/* .reset = */ common_reasoning_budget_reset,
|
||||
/* .clone = */ common_reasoning_budget_clone,
|
||||
/* .free = */ common_reasoning_budget_free,
|
||||
/* .backend_init = */ nullptr,
|
||||
/* .backend_accept = */ nullptr,
|
||||
/* .backend_apply = */ nullptr,
|
||||
/* .backend_set_input = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
// promote COUNTING with budget <= 0 to FORCING
|
||||
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
|
||||
initial_state = REASONING_BUDGET_FORCING;
|
||||
}
|
||||
|
||||
return llama_sampler_init(
|
||||
/* .iface = */ &common_reasoning_budget_i,
|
||||
/* .ctx = */ new common_reasoning_budget_ctx {
|
||||
/* .vocab = */ vocab,
|
||||
/* .start_matcher = */ { start_tokens, 0 },
|
||||
/* .end_matcher = */ { end_tokens, 0 },
|
||||
/* .forced_tokens = */ forced_tokens,
|
||||
/* .budget = */ budget,
|
||||
/* .remaining = */ budget,
|
||||
/* .state = */ initial_state,
|
||||
/* .force_pos = */ 0,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
enum common_reasoning_budget_state {
|
||||
REASONING_BUDGET_IDLE, // waiting for start sequence
|
||||
REASONING_BUDGET_COUNTING, // counting down tokens
|
||||
REASONING_BUDGET_FORCING, // forcing budget message + end sequence
|
||||
REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
|
||||
REASONING_BUDGET_DONE, // passthrough forever
|
||||
};
|
||||
|
||||
// Creates a reasoning budget sampler that limits token generation inside a
|
||||
// reasoning block (e.g. between <think> and </think>).
|
||||
//
|
||||
// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
|
||||
// IDLE: passthrough, watching for start_tokens sequence
|
||||
// COUNTING: counting down remaining tokens, watching for natural end_tokens
|
||||
// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
|
||||
// FORCING: forces forced_tokens token-by-token (all other logits -> -inf)
|
||||
// DONE: passthrough forever
|
||||
//
|
||||
// Parameters:
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
|
||||
// note: COUNTING with budget <= 0 is promoted to FORCING
|
||||
//
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state);
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "reasoning-budget.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
|
@ -250,6 +251,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler — added first so it can force tokens before other samplers
|
||||
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
|
||||
samplers.push_back(common_reasoning_budget_init(
|
||||
vocab,
|
||||
params.reasoning_budget_start,
|
||||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
params.reasoning_budget_tokens,
|
||||
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
|
||||
}
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// implementation adopted from src/unicode.cpp
|
||||
|
||||
|
|
@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off
|
|||
return utf8_parse_result(utf8_parse_result::INVALID);
|
||||
}
|
||||
|
||||
bool common_utf8_is_complete(const std::string & s) {
|
||||
if (s.empty()) {
|
||||
return true;
|
||||
}
|
||||
for (int i = 1; i <= std::min(4, (int)s.size()); i++) {
|
||||
unsigned char c = s[s.size() - i];
|
||||
if ((c & 0xC0) != 0x80) {
|
||||
int expected = (c >= 0xF0) ? 4 : (c >= 0xE0) ? 3 : (c >= 0xC0) ? 2 : 1;
|
||||
return i >= expected;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
|
||||
std::string result;
|
||||
for (size_t i = 0; i < cps.size(); ++i) {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@ struct utf8_parse_result {
|
|||
// Returns 0 for invalid first bytes
|
||||
size_t common_utf8_sequence_length(unsigned char first_byte);
|
||||
|
||||
// Check if a string ends with a complete UTF-8 sequence.
|
||||
bool common_utf8_is_complete(const std::string & s);
|
||||
|
||||
// Parse a single UTF-8 codepoint from input
|
||||
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);
|
||||
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ endif ()
|
|||
if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
|
||||
# these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries)
|
||||
llama_build_and_test(test-sampling.cpp)
|
||||
llama_build_and_test(test-reasoning-budget.cpp)
|
||||
llama_build_and_test(test-grammar-parser.cpp)
|
||||
llama_build_and_test(test-grammar-integration.cpp)
|
||||
llama_build_and_test(test-llama-grammar.cpp)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,238 @@
|
|||
#include "reasoning-budget.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Reasoning budget sampler test helper
|
||||
// These tests use nullptr vocab which safely falls back to treating all tokens as complete
|
||||
// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection)
|
||||
static void test_reasoning_budget(
|
||||
const char * test_name,
|
||||
const std::vector<llama_token> & sequence,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget,
|
||||
common_reasoning_budget_state initial_state,
|
||||
size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never)
|
||||
size_t expected_force_end // token index where forcing should end (after this, no more forcing)
|
||||
) {
|
||||
// Find the maximum token ID to ensure our vocab covers all tokens
|
||||
llama_token max_token = 0;
|
||||
for (auto t : sequence) max_token = std::max(max_token, t);
|
||||
for (auto t : start_tokens) max_token = std::max(max_token, t);
|
||||
for (auto t : end_tokens) max_token = std::max(max_token, t);
|
||||
for (auto t : forced_tokens) max_token = std::max(max_token, t);
|
||||
|
||||
// Create a minimal sampler with mock vocabulary
|
||||
// For this test, we use nullptr as vocab since we're testing state transitions
|
||||
// The UTF-8 boundary check will treat all tokens as complete (safe fallback)
|
||||
auto * sampler = common_reasoning_budget_init(
|
||||
nullptr, // vocab - not used for basic state machine tests
|
||||
start_tokens,
|
||||
end_tokens,
|
||||
forced_tokens,
|
||||
budget,
|
||||
initial_state
|
||||
);
|
||||
|
||||
// Create a test token data array for checking forcing behavior
|
||||
// Vocab size must be large enough to include all tokens (start, end, forced, sequence)
|
||||
std::vector<llama_token_data> cur;
|
||||
const size_t n_vocab = (size_t)max_token + 1;
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f});
|
||||
}
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||
|
||||
size_t actual_force_start = SIZE_MAX;
|
||||
size_t actual_force_end = SIZE_MAX;
|
||||
|
||||
// Feed the sequence and track when forcing occurs
|
||||
for (size_t i = 0; i < sequence.size(); i++) {
|
||||
llama_sampler_accept(sampler, sequence[i]);
|
||||
|
||||
// Check if we're in forcing state by applying and seeing if logits are modified
|
||||
cur_p.selected = -1;
|
||||
for (size_t j = 0; j < cur.size(); j++) {
|
||||
cur[j].logit = logf((float)(j+1)); // reset logits
|
||||
}
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
|
||||
// Check if forcing is active (all logits except one should be -INFINITY)
|
||||
size_t finite_count = 0;
|
||||
llama_token finite_token = -1;
|
||||
for (size_t j = 0; j < cur.size(); j++) {
|
||||
if (std::isfinite(cur[j].logit)) {
|
||||
finite_count++;
|
||||
finite_token = cur[j].id;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
|
||||
|
||||
if (finite_count == 1) {
|
||||
if (actual_force_start == SIZE_MAX) {
|
||||
actual_force_start = i;
|
||||
}
|
||||
actual_force_end = i;
|
||||
} else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) {
|
||||
// Forcing stopped
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampler_free(sampler);
|
||||
|
||||
// Verify forcing occurred at expected positions
|
||||
if (expected_force_start == SIZE_MAX) {
|
||||
if (actual_force_start != SIZE_MAX) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start);
|
||||
GGML_ASSERT(false && "Expected no forcing, but forcing occurred");
|
||||
}
|
||||
} else {
|
||||
if (actual_force_start == SIZE_MAX) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Expected forcing but none occurred\n", test_name);
|
||||
GGML_ASSERT(false && "Expected forcing but none occurred");
|
||||
}
|
||||
if (actual_force_start != expected_force_start) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Forcing started at %zu, expected %zu\n", test_name, actual_force_start, expected_force_start);
|
||||
GGML_ASSERT(false && "Forcing started at wrong position");
|
||||
}
|
||||
}
|
||||
|
||||
if (expected_force_end != SIZE_MAX) {
|
||||
if (actual_force_end < expected_force_end) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
|
||||
GGML_ASSERT(false && "Forcing ended too early");
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, " Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end);
|
||||
(void)sequence;
|
||||
}
|
||||
|
||||
// UTF-8 boundary detection unit test
|
||||
// Tests common_utf8_is_complete() from reasoning-budget.h
|
||||
static void test_utf8_boundary_detection() {
|
||||
// Complete sequences
|
||||
GGML_ASSERT(common_utf8_is_complete("hello"));
|
||||
GGML_ASSERT(common_utf8_is_complete(""));
|
||||
GGML_ASSERT(common_utf8_is_complete("\xC2\xA0")); // complete 2-byte UTF-8 (U+00A0)
|
||||
GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C")); // complete 3-byte UTF-8 (left double quote)
|
||||
GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80")); // complete 4-byte UTF-8 (emoji)
|
||||
GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9")); // ASCII + complete 2-byte
|
||||
|
||||
// Incomplete sequences
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1))); // 2-byte start, missing continuation
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2))); // 3-byte start + 1 cont, missing 1
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1))); // 3-byte start, missing 2
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3))); // 4-byte start + 2 cont, missing 1
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2))); // 4-byte start + 1 cont, missing 2
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1))); // 4-byte start, missing 3
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1))); // orphan continuation byte
|
||||
|
||||
// Mixed: ASCII followed by start of multi-byte
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6))); // ASCII + incomplete 2-byte
|
||||
GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7))); // ASCII + complete 2-byte
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
// Reasoning budget sampler tests
|
||||
printf("Testing reasoning budget sampler... ");
|
||||
|
||||
// Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted)
|
||||
{
|
||||
const std::vector<llama_token> start = {100}; // start token
|
||||
const std::vector<llama_token> end = {101}; // end token
|
||||
const std::vector<llama_token> forced = {102}; // forced token (not used in this test)
|
||||
const std::vector<llama_token> sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more
|
||||
|
||||
test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced,
|
||||
5, // budget of 5 tokens
|
||||
REASONING_BUDGET_IDLE,
|
||||
SIZE_MAX, SIZE_MAX); // no forcing expected (natural end)
|
||||
}
|
||||
|
||||
// Test 2: Budget exhausted, forcing should occur
|
||||
// Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
|
||||
// Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
|
||||
// At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101}; // forced message + end
|
||||
const std::vector<llama_token> sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2)
|
||||
|
||||
test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_IDLE,
|
||||
2, // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
|
||||
3); // forcing continues through i=3 (at i=4 state becomes DONE)
|
||||
}
|
||||
|
||||
// Test 3: Activate immediately with budget=0, forcing should start right away
|
||||
// Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
|
||||
// This test needs start token to be in the sequence or use activate_immediately with start token present
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
const std::vector<llama_token> sequence = {100, 50, 51, 52}; // start token first, then 3 tokens
|
||||
|
||||
test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
|
||||
0, // budget of 0 tokens
|
||||
REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
|
||||
0, // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
|
||||
1); // forcing continues through i=1 (at i=2 state becomes DONE)
|
||||
}
|
||||
|
||||
// Test 4: No start/end tokens configured - passthrough (no forcing)
|
||||
{
|
||||
const std::vector<llama_token> start = {};
|
||||
const std::vector<llama_token> end = {};
|
||||
const std::vector<llama_token> forced = {102};
|
||||
const std::vector<llama_token> sequence = {50, 51, 52, 53};
|
||||
|
||||
test_reasoning_budget("no start/end configured", sequence, start, end, forced,
|
||||
2, // budget
|
||||
REASONING_BUDGET_IDLE,
|
||||
SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured)
|
||||
}
|
||||
|
||||
// Test 5: Activate immediately with budget > 0, count down then force
|
||||
// Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
|
||||
// So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
const std::vector<llama_token> sequence = {50, 51, 52, 53};
|
||||
|
||||
test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_COUNTING,
|
||||
1, // forcing starts at i=1 (after 2 accepts deplete budget)
|
||||
2); // forcing continues through i=2
|
||||
}
|
||||
|
||||
printf("OK (5 tests passed)\n");
|
||||
|
||||
printf("Testing UTF-8 boundary detection... ");
|
||||
test_utf8_boundary_detection();
|
||||
printf("OK\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -57,6 +57,8 @@ struct cli_context {
|
|||
std::vector<raw_buffer> input_files;
|
||||
task_params defaults;
|
||||
bool verbose_prompt;
|
||||
int reasoning_budget = -1;
|
||||
std::string reasoning_budget_message;
|
||||
|
||||
// thread for showing "loading" animation
|
||||
std::atomic<bool> loading_show;
|
||||
|
|
@ -73,6 +75,8 @@ struct cli_context {
|
|||
// defaults.return_progress = true; // TODO: show progress
|
||||
|
||||
verbose_prompt = params.verbose_prompt;
|
||||
reasoning_budget = params.reasoning_budget;
|
||||
reasoning_budget_message = params.reasoning_budget_message;
|
||||
}
|
||||
|
||||
std::string generate_completion(result_timings & out_timings) {
|
||||
|
|
@ -95,6 +99,24 @@ struct cli_context {
|
|||
task.params.chat_parser_params.parser.load(chat_params.parser);
|
||||
}
|
||||
|
||||
// reasoning budget sampler
|
||||
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(
|
||||
llama_get_model(ctx_server.get_llama_context()));
|
||||
|
||||
task.params.sampling.reasoning_budget_tokens = reasoning_budget;
|
||||
task.params.sampling.reasoning_budget_activate_immediately = chat_params.thinking_forced_open;
|
||||
|
||||
if (!chat_params.thinking_start_tag.empty()) {
|
||||
task.params.sampling.reasoning_budget_start =
|
||||
common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
|
||||
}
|
||||
task.params.sampling.reasoning_budget_end =
|
||||
common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
|
||||
task.params.sampling.reasoning_budget_forced =
|
||||
common_tokenize(vocab, reasoning_budget_message + chat_params.thinking_end_tag, false, true);
|
||||
}
|
||||
|
||||
rd.post_task({std::move(task)});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1101,6 +1101,22 @@ json oaicompat_chat_params_parse(
|
|||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
// Reasoning budget: pass parameters through to sampling layer
|
||||
{
|
||||
int reasoning_budget = opt.reasoning_budget;
|
||||
if (reasoning_budget == -1 && body.contains("thinking_budget_tokens")) {
|
||||
reasoning_budget = json_value(body, "thinking_budget_tokens", -1);
|
||||
}
|
||||
|
||||
if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
|
||||
llama_params["reasoning_budget_tokens"] = reasoning_budget;
|
||||
llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
|
||||
llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
|
||||
llama_params["reasoning_budget_message"] = opt.reasoning_budget_message;
|
||||
llama_params["reasoning_budget_activate_immediately"] = chat_params.thinking_forced_open;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle "logprobs" field
|
||||
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
|
||||
if (json_value(body, "logprobs", false)) {
|
||||
|
|
|
|||
|
|
@ -287,6 +287,8 @@ struct server_chat_params {
|
|||
bool allow_image;
|
||||
bool allow_audio;
|
||||
bool enable_thinking = true;
|
||||
int reasoning_budget = -1;
|
||||
std::string reasoning_budget_message;
|
||||
std::string media_path;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -893,9 +893,10 @@ private:
|
|||
}
|
||||
|
||||
// thinking is enabled if:
|
||||
// 1. It's not explicitly disabled (reasoning_budget == 0)
|
||||
// 1. It's not explicitly disabled via --reasoning off
|
||||
// 2. The chat template supports it
|
||||
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
|
||||
const bool template_supports_thinking = params_base.use_jinja && common_chat_templates_support_enable_thinking(chat_templates.get());
|
||||
const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking;
|
||||
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
|
||||
|
||||
chat_params = {
|
||||
|
|
@ -907,6 +908,8 @@ private:
|
|||
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
|
||||
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
|
||||
/* enable_thinking */ enable_thinking,
|
||||
/* reasoning_budget */ params_base.reasoning_budget,
|
||||
/* reasoning_budget_msg */ params_base.reasoning_budget_message,
|
||||
/* media_path */ params_base.media_path,
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -462,6 +462,34 @@ task_params server_task::params_from_json_cmpl(
|
|||
}
|
||||
}
|
||||
|
||||
// Parse reasoning budget sampler parameters
|
||||
{
|
||||
const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
|
||||
if (budget >= 0) {
|
||||
const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
|
||||
const auto end_tag = json_value(data, "reasoning_budget_end_tag", std::string());
|
||||
const auto message = json_value(data, "reasoning_budget_message", std::string());
|
||||
const bool activate_imm = json_value(data, "reasoning_budget_activate_immediately", false);
|
||||
|
||||
params.sampling.reasoning_budget_tokens = budget;
|
||||
params.sampling.reasoning_budget_activate_immediately = activate_imm;
|
||||
|
||||
if (!start_tag.empty()) {
|
||||
params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
|
||||
}
|
||||
if (!end_tag.empty()) {
|
||||
params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
|
||||
params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
|
||||
}
|
||||
|
||||
SRV_DBG("reasoning budget: tokens=%d, activate_immediately=%s, start=%zu toks, end=%zu toks, forced=%zu toks\n",
|
||||
budget, activate_imm ? "true" : "false",
|
||||
params.sampling.reasoning_budget_start.size(),
|
||||
params.sampling.reasoning_budget_end.size(),
|
||||
params.sampling.reasoning_budget_forced.size());
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
params.sampling.logit_bias.clear();
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue