This commit is contained in:
Piotr Wilkin (ilintar) 2026-03-23 00:43:04 +00:00 committed by GitHub
commit 8e57e66a4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 258 additions and 5 deletions

View File

@ -483,6 +483,13 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}
}
// Enable or disable grammar trigger suppression on the sampler's grammar chain
// (e.g. while inside a reasoning/thinking block).
// Safe to call with a null sampler or a sampler without a grammar: both are no-ops.
void common_sampler_set_grammar_trigger_suppressed(struct common_sampler * gsmpl, bool suppressed) {
    if (gsmpl == nullptr) {
        return;
    }
    if (gsmpl->grmr == nullptr) {
        return;
    }
    llama_sampler_grammar_set_trigger_suppressed(gsmpl->grmr, suppressed);
}
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
if (!gsmpl) {
return nullptr;

View File

@ -87,6 +87,10 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// suppress or un-suppress grammar trigger detection (e.g. during reasoning/thinking blocks)
// when suppressed, the grammar still buffers tokens but does not check for triggers
void common_sampler_set_grammar_trigger_suppressed(struct common_sampler * gsmpl, bool suppressed);
// helpers
// access the internal list of current candidate tokens

View File

@ -1380,6 +1380,13 @@ extern "C" {
const llama_token * trigger_tokens,
size_t num_trigger_tokens);
/// @details Suppress or un-suppress trigger detection on a grammar sampler.
/// When suppressed, the grammar still buffers tokens but does not check for triggers.
/// Useful for suppressing grammar activation during reasoning/thinking blocks.
/// No-op if the sampler is not a grammar sampler.
LLAMA_API void llama_sampler_grammar_set_trigger_suppressed(
struct llama_sampler * smpl,
bool suppressed);
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(

View File

@ -1185,6 +1185,7 @@ struct llama_grammar * llama_grammar_init_impl(
/* .partial_utf8 = */ {},
/* .lazy = */ false,
/* .awaiting_trigger = */ false,
/* .trigger_suppressed = */ false,
/* .trigger_buffer = */ "",
/* .trigger_buffer_positions = */ {},
/* .trigger_tokens = */ {},
@ -1291,6 +1292,7 @@ struct llama_grammar * llama_grammar_init_impl(
/* .partial_utf8 = */ {},
/* .lazy = */ lazy,
/* .awaiting_trigger = */ lazy,
/* .trigger_suppressed = */ false,
/* .trigger_buffer = */ "",
/* .trigger_buffer_positions = */ {},
std::move(vec_trigger_tokens),
@ -1314,6 +1316,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
grammar.partial_utf8,
grammar.lazy,
grammar.awaiting_trigger,
grammar.trigger_suppressed,
grammar.trigger_buffer,
grammar.trigger_buffer_positions,
grammar.trigger_tokens,
@ -1385,6 +1388,15 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
const auto & piece = grammar.vocab->token_to_piece(token);
if (grammar.awaiting_trigger) {
// When trigger is suppressed (e.g. during reasoning), still buffer tokens but skip trigger detection
if (grammar.trigger_suppressed) {
auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
grammar.trigger_buffer += piece;
LLAMA_LOG_DEBUG("Grammar trigger suppressed, buffering token %d (`%s`)\n", token, piece.c_str());
return;
}
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
grammar.awaiting_trigger = false;
grammar.trigger_buffer.clear();

View File

@ -141,6 +141,7 @@ struct llama_grammar {
// (useful e.g. for tool_choice=required)
bool lazy = false;
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
bool trigger_suppressed = false; // When true, trigger detection is suppressed (e.g. during reasoning)
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
std::vector<token_pos> trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).

View File

@ -2529,6 +2529,16 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
/* .backend_set_input = */ nullptr,
};
// Toggle trigger suppression on a grammar sampler.
// No-op when smpl is null, when it is not a grammar sampler (interface pointer
// does not match llama_sampler_grammar_i), or when no grammar is loaded.
void llama_sampler_grammar_set_trigger_suppressed(struct llama_sampler * smpl, bool suppressed) {
    const bool is_grammar_sampler = smpl != nullptr && smpl->iface == &llama_sampler_grammar_i;
    if (!is_grammar_sampler) {
        return;
    }
    auto * gctx = (llama_sampler_grammar *) smpl->ctx;
    if (gctx->grammar == nullptr) {
        return;
    }
    gctx->grammar->trigger_suppressed = suppressed;
}
static struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab * vocab,
const char * grammar_str,

View File

@ -936,7 +936,71 @@ static void test_peg_parser(common_chat_templates * tmpls,
throw std::runtime_error("Failed to build grammar: " + parser.params_.grammar);
}
// Find the earliest trigger position to determine the constrained portion
// Determine reasoning regions in tc.input so we can suppress grammar triggers inside them.
// A reasoning region spans from thinking_start_tag to thinking_end_tag.
// If generation_prompt contains the start tag (without a matching end), reasoning starts
// before tc.input, so position 0 is already inside reasoning.
std::vector<std::pair<size_t, size_t>> reasoning_regions; // [start, end) in tc.input
{
const auto & start_tag = parser.params_.thinking_start_tag;
const auto & end_tag = parser.params_.thinking_end_tag;
if (!end_tag.empty()) {
// check if generation_prompt puts us inside reasoning at the start of tc.input
bool in_reasoning = false;
if (!start_tag.empty()) {
const auto & gen_prompt = parser.params_.generation_prompt;
auto last_start = gen_prompt.rfind(start_tag);
if (last_start != std::string::npos) {
auto last_end = gen_prompt.rfind(end_tag);
if (last_end == std::string::npos || last_end < last_start) {
in_reasoning = true;
}
}
}
size_t search_from = 0;
size_t region_start = in_reasoning ? 0 : std::string::npos;
while (search_from < tc.input.size()) {
if (in_reasoning) {
auto end_pos = tc.input.find(end_tag, search_from);
if (end_pos != std::string::npos) {
reasoning_regions.push_back({region_start, end_pos + end_tag.size()});
search_from = end_pos + end_tag.size();
in_reasoning = false;
} else {
// reasoning extends to end of input
reasoning_regions.push_back({region_start, tc.input.size()});
break;
}
} else if (!start_tag.empty()) {
auto start_pos = tc.input.find(start_tag, search_from);
if (start_pos != std::string::npos) {
region_start = start_pos;
search_from = start_pos + start_tag.size();
in_reasoning = true;
} else {
break;
}
} else {
break;
}
}
}
}
// Helper: check if a position falls inside any reasoning region
auto is_in_reasoning = [&reasoning_regions](size_t pos) -> bool {
for (const auto & [start, end] : reasoning_regions) {
if (pos >= start && pos < end) {
return true;
}
}
return false;
};
// Find the earliest trigger position to determine the constrained portion,
// skipping triggers that fall inside reasoning regions.
auto earliest_trigger_pos = std::string::npos;
for (const auto & trigger : parser.params_.grammar_triggers) {
size_t pos = std::string::npos;
@ -945,14 +1009,34 @@ static void test_peg_parser(common_chat_templates * tmpls,
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
pos = tc.input.find(word);
// find first occurrence outside reasoning
size_t search_from = 0;
while (search_from < tc.input.size()) {
auto found = tc.input.find(word, search_from);
if (found == std::string::npos) {
break;
}
if (!is_in_reasoning(found)) {
pos = found;
break;
}
search_from = found + 1;
}
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
{
const auto & pattern = std::regex(trigger.value);
if (std::regex_search(tc.input, match, pattern)) {
pos = match.position(pattern.mark_count());
auto search_str = tc.input;
size_t offset = 0;
while (std::regex_search(search_str, match, pattern)) {
auto found = offset + match.position(pattern.mark_count());
if (!is_in_reasoning(found)) {
pos = found;
break;
}
offset += match.position(0) + match.length(0);
search_str = tc.input.substr(offset);
}
break;
}
@ -970,7 +1054,11 @@ static void test_peg_parser(common_chat_templates * tmpls,
if (mpos == std::string::npos) {
mpos = match.position(0);
}
pos = mpos;
// PATTERN_FULL matches the entire input, so if the match position
// is in reasoning, skip it entirely
if (!is_in_reasoning(mpos)) {
pos = mpos;
}
}
break;
}
@ -1425,6 +1513,50 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_reasoning("I need to output the invoice details in JSON")
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
.run();
// tool call segment in reasoning
tst.test(
"Let's call a tool: <tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Hello, world!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call></think>\n"
"<tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Hello, world!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call>"
)
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({
python_tool
})
.expect_reasoning("Let's call a tool: <tool_call>\n"
"<function=python>\n"
"<parameter=code>\n"
"def hello():\n"
" print(\"Hello, world!\")\n"
"\n"
"hello()\n"
"</parameter>\n"
"</function>\n"
"</tool_call>")
.expect_tool_calls({
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
})
.run();
}
{

View File

@ -1103,6 +1103,13 @@ json oaicompat_chat_params_parse(
llama_params["chat_parser"] = chat_params.parser;
}
// Pass thinking tags (when an end tag is defined) so the slot can track reasoning state
// (used to suppress grammar triggers during reasoning blocks)
if (!chat_params.thinking_end_tag.empty()) {
llama_params["thinking_start_tag"] = chat_params.thinking_start_tag;
llama_params["thinking_end_tag"] = chat_params.thinking_end_tag;
}
// Reasoning budget: pass parameters through to sampling layer
{
int reasoning_budget = opt.reasoning_budget;

View File

@ -92,6 +92,8 @@ struct server_slot {
bool has_next_token = true;
bool has_new_line = false;
bool truncated = false;
bool in_reasoning = false; // true when inside a thinking/reasoning block
llama_token thinking_end_first_token = LLAMA_TOKEN_NULL; // first token of thinking end tag (for EOG interception)
stop_type stop;
@ -173,6 +175,8 @@ struct server_slot {
generated_text = "";
has_new_line = false;
truncated = false;
in_reasoning = false;
thinking_end_first_token = LLAMA_TOKEN_NULL;
stop = STOP_TYPE_NONE;
stopping_word = "";
n_sent_text = 0;
@ -1181,6 +1185,29 @@ private:
}
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
// determine initial reasoning state from generation prompt
// if the generation prompt ends inside a thinking block, suppress grammar triggers initially
if (!task.params.thinking_end_tag.empty()) {
const auto & gen_prompt = task.params.sampling.generation_prompt;
const auto & start_tag = task.params.thinking_start_tag;
const auto & end_tag = task.params.thinking_end_tag;
// tokenize the thinking end tag so we can intercept EOG during reasoning
auto end_tag_tokens = common_tokenize(ctx, end_tag, false, true);
if (!end_tag_tokens.empty()) {
slot.thinking_end_first_token = end_tag_tokens[0];
}
auto last_start = start_tag.empty() ? std::string::npos : gen_prompt.rfind(start_tag);
auto last_end = gen_prompt.rfind(end_tag);
if (last_start != std::string::npos
&& (last_end == std::string::npos || last_end < last_start)) {
slot.in_reasoning = true;
common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), true);
SLT_DBG(slot, "starting in reasoning state, grammar triggers suppressed\n%s", "");
}
}
} else {
slot.smpl.reset();
}
@ -1209,6 +1236,34 @@ private:
}
slot.has_next_token = true;
// update reasoning state and propagate to grammar trigger suppression
if (!slot.task->params.thinking_end_tag.empty() && slot.smpl) {
const auto & end_tag = slot.task->params.thinking_end_tag;
const auto & start_tag = slot.task->params.thinking_start_tag;
if (slot.in_reasoning) {
// check if the end tag just appeared at the end of generated_text
if (slot.generated_text.size() >= end_tag.size()
&& slot.generated_text.compare(
slot.generated_text.size() - end_tag.size(),
end_tag.size(), end_tag) == 0) {
slot.in_reasoning = false;
common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), false);
SLT_DBG(slot, "reasoning ended, grammar triggers un-suppressed\n%s", "");
}
} else {
// check if the start tag just appeared at the end of generated_text
if (!start_tag.empty()
&& slot.generated_text.size() >= start_tag.size()
&& slot.generated_text.compare(
slot.generated_text.size() - start_tag.size(),
start_tag.size(), start_tag) == 0) {
slot.in_reasoning = true;
common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), true);
SLT_DBG(slot, "reasoning started, grammar triggers suppressed\n%s", "");
}
}
}
// check if there is incomplete UTF-8 character at the end
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
@ -2835,6 +2890,16 @@ private:
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
// if the model emits EOG while still inside a reasoning block,
// force the first token of the thinking end tag instead
if (slot.in_reasoning
&& slot.thinking_end_first_token != LLAMA_TOKEN_NULL
&& llama_vocab_is_eog(vocab, id)) {
SLT_DBG(slot, "intercepted EOG during reasoning, forcing thinking end token %d\n",
slot.thinking_end_first_token);
id = slot.thinking_end_first_token;
}
slot.i_batch = -1;
common_sampler_accept(slot.smpl.get(), id, true);

View File

@ -500,6 +500,10 @@ task_params server_task::params_from_json_cmpl(
}
}
// Parse thinking tags for reasoning state tracking (used to suppress grammar triggers during reasoning)
params.thinking_start_tag = json_value(data, "thinking_start_tag", std::string());
params.thinking_end_tag = json_value(data, "thinking_end_tag", std::string());
{
params.sampling.logit_bias.clear();

View File

@ -83,6 +83,10 @@ struct task_params {
// per-request parameters for chat parsing
common_chat_parser_params chat_parser_params;
// thinking/reasoning tags for tracking reasoning state in the slot
std::string thinking_start_tag;
std::string thinking_end_tag;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)