From 982cf3b6a28644e8c53976930a2987d4639a235f Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Sat, 21 Mar 2026 20:16:27 +0100 Subject: [PATCH 1/2] Do not trigger grammar inside tool calling section + force thinking close on premature EOG inside thinking. --- common/sampling.cpp | 10 +++ common/sampling.h | 4 + src/llama-grammar.cpp | 12 +++ src/llama-grammar.h | 1 + src/llama-sampler.cpp | 10 +++ src/llama-sampler.h | 4 + tests/test-chat.cpp | 142 ++++++++++++++++++++++++++++++-- tools/server/server-common.cpp | 7 ++ tools/server/server-context.cpp | 66 +++++++++++++++ tools/server/server-task.cpp | 4 + tools/server/server-task.h | 4 + 11 files changed, 259 insertions(+), 5 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 012e212660..c3194ba3f0 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -483,6 +483,16 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } } +// forward declaration of internal function (defined in llama-sampler.cpp) +void llama_sampler_grammar_set_trigger_suppressed(struct llama_sampler * smpl, bool suppressed); + +void common_sampler_set_grammar_trigger_suppressed(struct common_sampler * gsmpl, bool suppressed) { + if (!gsmpl || !gsmpl->grmr) { + return; + } + llama_sampler_grammar_set_trigger_suppressed(gsmpl->grmr, suppressed); +} + struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) { if (!gsmpl) { return nullptr; diff --git a/common/sampling.h b/common/sampling.h index 5b57ad6581..543e47f5cc 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -87,6 +87,10 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); +// suppress or un-suppress grammar trigger detection (e.g. 
during reasoning/thinking blocks) +// when suppressed, the grammar still buffers tokens but does not check for triggers +void common_sampler_set_grammar_trigger_suppressed(struct common_sampler * gsmpl, bool suppressed); + // helpers // access the internal list of current candidate tokens diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index badcbfd0fb..a1a8b2db54 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1185,6 +1185,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .partial_utf8 = */ {}, /* .lazy = */ false, /* .awaiting_trigger = */ false, + /* .trigger_suppressed = */ false, /* .trigger_buffer = */ "", /* .trigger_buffer_positions = */ {}, /* .trigger_tokens = */ {}, @@ -1291,6 +1292,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .partial_utf8 = */ {}, /* .lazy = */ lazy, /* .awaiting_trigger = */ lazy, + /* .trigger_suppressed = */ false, /* .trigger_buffer = */ "", /* .trigger_buffer_positions = */ {}, std::move(vec_trigger_tokens), @@ -1314,6 +1316,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.partial_utf8, grammar.lazy, grammar.awaiting_trigger, + grammar.trigger_suppressed, grammar.trigger_buffer, grammar.trigger_buffer_positions, grammar.trigger_tokens, @@ -1385,6 +1388,15 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto & piece = grammar.vocab->token_to_piece(token); if (grammar.awaiting_trigger) { + // When trigger is suppressed (e.g. 
during reasoning), still buffer tokens but skip trigger detection + if (grammar.trigger_suppressed) { + auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size()); + grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); + grammar.trigger_buffer += piece; + LLAMA_LOG_DEBUG("Grammar trigger suppressed, buffering token %d (`%s`)\n", token, piece.c_str()); + return; + } + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { grammar.awaiting_trigger = false; grammar.trigger_buffer.clear(); diff --git a/src/llama-grammar.h b/src/llama-grammar.h index b5a0e588e9..c8179a203b 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -141,6 +141,7 @@ struct llama_grammar { // (useful e.g. for tool_choice=required) bool lazy = false; bool awaiting_trigger = false; // Initialized to true for lazy grammars only + bool trigger_suppressed = false; // When true, trigger detection is suppressed (e.g. during reasoning) std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. std::vector trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). 
diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index 9bbc5dbde2..4bad507754 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -2529,6 +2529,16 @@ static struct llama_sampler_i llama_sampler_grammar_i = { /* .backend_set_input = */ nullptr, }; +void llama_sampler_grammar_set_trigger_suppressed(struct llama_sampler * smpl, bool suppressed) { + if (!smpl || smpl->iface != &llama_sampler_grammar_i) { + return; + } + auto * ctx = (llama_sampler_grammar *) smpl->ctx; + if (ctx->grammar) { + ctx->grammar->trigger_suppressed = suppressed; + } +} + static struct llama_sampler * llama_sampler_init_grammar_impl( const struct llama_vocab * vocab, const char * grammar_str, diff --git a/src/llama-sampler.h b/src/llama-sampler.h index b9bfc20d25..11ad399fb8 100644 --- a/src/llama-sampler.h +++ b/src/llama-sampler.h @@ -33,6 +33,10 @@ struct llama_sampler_chain { mutable int32_t n_sample; }; +// set trigger_suppressed on a grammar sampler (e.g. to suppress triggers during reasoning) +// the sampler must have been created by llama_sampler_init_grammar* or this is a no-op +void llama_sampler_grammar_set_trigger_suppressed(struct llama_sampler * smpl, bool suppressed); + struct llama_sampler * llama_sampler_init_dry_testing( int32_t context_size, float dry_multiplier, diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 575d240791..bdfc326590 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -936,7 +936,71 @@ static void test_peg_parser(common_chat_templates * tmpls, throw std::runtime_error("Failed to build grammar: " + parser.params_.grammar); } - // Find the earliest trigger position to determine the constrained portion + // Determine reasoning regions in tc.input so we can suppress grammar triggers inside them. + // A reasoning region spans from thinking_start_tag to thinking_end_tag. 
+    // If generation_prompt contains the start tag (without a matching end), reasoning starts
+    // before tc.input, so position 0 is already inside reasoning.
+    std::vector<std::pair<size_t, size_t>> reasoning_regions; // half-open [start, end) offsets into tc.input
+    {
+        const auto & start_tag = parser.params_.thinking_start_tag;
+        const auto & end_tag = parser.params_.thinking_end_tag;
+        if (!end_tag.empty()) {
+            // check if generation_prompt puts us inside reasoning at the start of tc.input
+            bool in_reasoning = false;
+            if (!start_tag.empty()) {
+                const auto & gen_prompt = parser.params_.generation_prompt;
+                auto last_start = gen_prompt.rfind(start_tag);
+                if (last_start != std::string::npos) {
+                    auto last_end = gen_prompt.rfind(end_tag);
+                    if (last_end == std::string::npos || last_end < last_start) {
+                        in_reasoning = true;
+                    }
+                }
+            }
+
+            size_t search_from = 0;
+            size_t region_start = in_reasoning ? 0 : std::string::npos;
+
+            while (search_from < tc.input.size()) {
+                if (in_reasoning) {
+                    auto end_pos = tc.input.find(end_tag, search_from);
+                    if (end_pos != std::string::npos) {
+                        reasoning_regions.push_back({region_start, end_pos + end_tag.size()});
+                        search_from = end_pos + end_tag.size();
+                        in_reasoning = false;
+                    } else {
+                        // reasoning extends to end of input
+                        reasoning_regions.push_back({region_start, tc.input.size()});
+                        break;
+                    }
+                } else if (!start_tag.empty()) {
+                    auto start_pos = tc.input.find(start_tag, search_from);
+                    if (start_pos != std::string::npos) {
+                        region_start = start_pos;
+                        search_from = start_pos + start_tag.size();
+                        in_reasoning = true;
+                    } else {
+                        break;
+                    }
+                } else {
+                    break;
+                }
+            }
+        }
+    }
+
+    // Helper: check if a position falls inside any reasoning region
+    auto is_in_reasoning = [&reasoning_regions](size_t pos) -> bool {
+        for (const auto & [start, end] : reasoning_regions) {
+            if (pos >= start && pos < end) {
+                return true;
+            }
+        }
+        return false;
+    };
+
+    // Find the earliest trigger position to determine the constrained portion,
+    // skipping triggers that fall inside
reasoning regions.
     auto earliest_trigger_pos = std::string::npos;
     for (const auto & trigger : parser.params_.grammar_triggers) {
         size_t pos = std::string::npos;
@@ -945,14 +1009,34 @@ static void test_peg_parser(common_chat_templates * tmpls,
             case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                 {
                     const auto & word = trigger.value;
-                    pos = tc.input.find(word);
+                    // find first occurrence outside reasoning
+                    size_t search_from = 0;
+                    while (search_from < tc.input.size()) {
+                        auto found = tc.input.find(word, search_from);
+                        if (found == std::string::npos) {
+                            break;
+                        }
+                        if (!is_in_reasoning(found)) {
+                            pos = found;
+                            break;
+                        }
+                        search_from = found + 1;
+                    }
                     break;
                 }
             case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                 {
                     const auto & pattern = std::regex(trigger.value);
-                    if (std::regex_search(tc.input, match, pattern)) {
-                        pos = match.position(pattern.mark_count());
+                    // search tc.input in place (no per-iteration copies); match offsets are relative to the search start
+                    size_t offset = 0;
+                    while (offset <= tc.input.size() && std::regex_search(tc.input.cbegin() + offset, tc.input.cend(), match, pattern)) {
+                        auto found = offset + match.position(pattern.mark_count());
+                        if (!is_in_reasoning(found)) {
+                            pos = found;
+                            break;
+                        }
+                        // resume after this match; advance at least one char so an empty match cannot stall the loop
+                        offset += match.position(0) + (match.length(0) > 0 ? match.length(0) : 1);
                     }
                     break;
                 }
@@ -970,7 +1054,11 @@ static void test_peg_parser(common_chat_templates * tmpls,
                         if (mpos == std::string::npos) {
                             mpos = match.position(0);
                         }
-                        pos = mpos;
+                        // PATTERN_FULL matches the entire input, so if the match position
+                        // is in reasoning, skip it entirely
+                        if (!is_in_reasoning(mpos)) {
+                            pos = mpos;
+                        }
                     }
                     break;
                 }
@@ -1425,6 +1513,50 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
             .expect_reasoning("I need to output the invoice details in JSON")
             .expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
             .run();
+
+        // tool call segment in reasoning
+        tst.test(
+            "Let's call a tool: \n"
+            "\n"
+            "\n"
+            "def hello():\n"
+            " print(\"Hello, world!\")\n"
+            "\n"
+            "hello()\n"
+            "\n"
+            "\n"
+            "\n"
+            "\n"
+            "\n"
+            "\n"
+            "def hello():\n"
+            " print(\"Hello, world!\")\n"
+            "\n"
+ "hello()\n" + "\n" + "\n" + "" + ) + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ + python_tool + }) + .expect_reasoning("Let's call a tool: \n" + "\n" + "\n" + "def hello():\n" + " print(\"Hello, world!\")\n" + "\n" + "hello()\n" + "\n" + "\n" + "") + .expect_tool_calls({ + { "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} }, + }) + .run(); + } { diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index e01c8c53df..8ed1f0287e 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1103,6 +1103,13 @@ json oaicompat_chat_params_parse( llama_params["chat_parser"] = chat_params.parser; } + // Always pass thinking tags so the slot can track reasoning state + // (used to suppress grammar triggers during reasoning blocks) + if (!chat_params.thinking_end_tag.empty()) { + llama_params["thinking_start_tag"] = chat_params.thinking_start_tag; + llama_params["thinking_end_tag"] = chat_params.thinking_end_tag; + } + // Reasoning budget: pass parameters through to sampling layer { int reasoning_budget = opt.reasoning_budget; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9de554e900..ca4d15ec88 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -92,6 +92,8 @@ struct server_slot { bool has_next_token = true; bool has_new_line = false; bool truncated = false; + bool in_reasoning = false; // true when inside a thinking/reasoning block + llama_token thinking_end_first_token = LLAMA_TOKEN_NULL; // first token of thinking end tag (for EOG interception) stop_type stop; @@ -173,6 +175,8 @@ struct server_slot { generated_text = ""; has_new_line = false; truncated = false; + in_reasoning = false; + thinking_end_first_token = LLAMA_TOKEN_NULL; stop = STOP_TYPE_NONE; stopping_word = ""; n_sent_text = 0; @@ -1181,6 +1185,29 @@ private: } SLT_INF(slot, "sampler chain: %s\n", 
common_sampler_print(slot.smpl.get()).c_str()); + + // determine initial reasoning state from generation prompt + // if the generation prompt ends inside a thinking block, suppress grammar triggers initially + if (!task.params.thinking_end_tag.empty()) { + const auto & gen_prompt = task.params.sampling.generation_prompt; + const auto & start_tag = task.params.thinking_start_tag; + const auto & end_tag = task.params.thinking_end_tag; + + // tokenize the thinking end tag so we can intercept EOG during reasoning + auto end_tag_tokens = common_tokenize(ctx, end_tag, false, true); + if (!end_tag_tokens.empty()) { + slot.thinking_end_first_token = end_tag_tokens[0]; + } + + auto last_start = start_tag.empty() ? std::string::npos : gen_prompt.rfind(start_tag); + auto last_end = gen_prompt.rfind(end_tag); + if (last_start != std::string::npos + && (last_end == std::string::npos || last_end < last_start)) { + slot.in_reasoning = true; + common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), true); + SLT_DBG(slot, "starting in reasoning state, grammar triggers suppressed\n%s", ""); + } + } } else { slot.smpl.reset(); } @@ -1209,6 +1236,35 @@ private: } slot.has_next_token = true; + // update reasoning state and propagate to grammar trigger suppression + if (!slot.task->params.thinking_end_tag.empty() && slot.smpl) { + const auto & end_tag = slot.task->params.thinking_end_tag; + const auto & start_tag = slot.task->params.thinking_start_tag; + if (slot.in_reasoning) { + // check if the end tag just appeared + if (slot.generated_text.size() >= end_tag.size()) { + auto tail = std::string_view(slot.generated_text).substr( + slot.generated_text.size() - end_tag.size()); + if (tail == end_tag) { + slot.in_reasoning = false; + common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), false); + SLT_DBG(slot, "reasoning ended, grammar triggers un-suppressed\n%s", ""); + } + } + } else { + // check if the start tag just appeared + if (!start_tag.empty() && 
slot.generated_text.size() >= start_tag.size()) { + auto tail = std::string_view(slot.generated_text).substr( + slot.generated_text.size() - start_tag.size()); + if (tail == start_tag) { + slot.in_reasoning = true; + common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), true); + SLT_DBG(slot, "reasoning started, grammar triggers suppressed\n%s", ""); + } + } + } + } + // check if there is incomplete UTF-8 character at the end bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); @@ -2835,6 +2891,16 @@ private: llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx); + // if the model emits EOG while still inside a reasoning block, + // force the first token of the thinking end tag instead + if (slot.in_reasoning + && slot.thinking_end_first_token != LLAMA_TOKEN_NULL + && llama_vocab_is_eog(vocab, id)) { + SLT_DBG(slot, "intercepted EOG during reasoning, forcing thinking end token %d\n", + slot.thinking_end_first_token); + id = slot.thinking_end_first_token; + } + slot.i_batch = -1; common_sampler_accept(slot.smpl.get(), id, true); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 7d543b9292..2573e66eb0 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -500,6 +500,10 @@ task_params server_task::params_from_json_cmpl( } } + // Parse thinking tags for reasoning state tracking (used to suppress grammar triggers during reasoning) + params.thinking_start_tag = json_value(data, "thinking_start_tag", std::string()); + params.thinking_end_tag = json_value(data, "thinking_end_tag", std::string()); + { params.sampling.logit_bias.clear(); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index a49ddb594b..d51d113c21 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -83,6 +83,10 @@ struct task_params { // per-request parameters for chat parsing common_chat_parser_params chat_parser_params; + // thinking/reasoning tags for 
tracking reasoning state in the slot + std::string thinking_start_tag; + std::string thinking_end_tag; + // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) From c25aed1f5c6ccb030cf761634847e612f5a7b6a2 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Sat, 21 Mar 2026 22:00:33 +0100 Subject: [PATCH 2/2] Compilation fixes --- common/sampling.cpp | 3 --- include/llama.h | 7 +++++++ src/llama-sampler.h | 4 ---- tools/server/server-context.cpp | 35 ++++++++++++++++----------------- 4 files changed, 24 insertions(+), 25 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index c3194ba3f0..db7d2912be 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -483,9 +483,6 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } } -// forward declaration of internal function (defined in llama-sampler.cpp) -void llama_sampler_grammar_set_trigger_suppressed(struct llama_sampler * smpl, bool suppressed); - void common_sampler_set_grammar_trigger_suppressed(struct common_sampler * gsmpl, bool suppressed) { if (!gsmpl || !gsmpl->grmr) { return; diff --git a/include/llama.h b/include/llama.h index 6e72db7e3c..34ceeefdf1 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1380,6 +1380,13 @@ extern "C" { const llama_token * trigger_tokens, size_t num_trigger_tokens); + /// @details Suppress or un-suppress trigger detection on a grammar sampler. + /// When suppressed, the grammar still buffers tokens but does not check for triggers. + /// Useful for suppressing grammar activation during reasoning/thinking blocks. + /// No-op if the sampler is not a grammar sampler. + LLAMA_API void llama_sampler_grammar_set_trigger_suppressed( + struct llama_sampler * smpl, + bool suppressed); /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. 
LLAMA_API struct llama_sampler * llama_sampler_init_penalties( diff --git a/src/llama-sampler.h b/src/llama-sampler.h index 11ad399fb8..b9bfc20d25 100644 --- a/src/llama-sampler.h +++ b/src/llama-sampler.h @@ -33,10 +33,6 @@ struct llama_sampler_chain { mutable int32_t n_sample; }; -// set trigger_suppressed on a grammar sampler (e.g. to suppress triggers during reasoning) -// the sampler must have been created by llama_sampler_init_grammar* or this is a no-op -void llama_sampler_grammar_set_trigger_suppressed(struct llama_sampler * smpl, bool suppressed); - struct llama_sampler * llama_sampler_init_dry_testing( int32_t context_size, float dry_multiplier, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ca4d15ec88..0e137447b7 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1241,26 +1241,25 @@ private: const auto & end_tag = slot.task->params.thinking_end_tag; const auto & start_tag = slot.task->params.thinking_start_tag; if (slot.in_reasoning) { - // check if the end tag just appeared - if (slot.generated_text.size() >= end_tag.size()) { - auto tail = std::string_view(slot.generated_text).substr( - slot.generated_text.size() - end_tag.size()); - if (tail == end_tag) { - slot.in_reasoning = false; - common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), false); - SLT_DBG(slot, "reasoning ended, grammar triggers un-suppressed\n%s", ""); - } + // check if the end tag just appeared at the end of generated_text + if (slot.generated_text.size() >= end_tag.size() + && slot.generated_text.compare( + slot.generated_text.size() - end_tag.size(), + end_tag.size(), end_tag) == 0) { + slot.in_reasoning = false; + common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), false); + SLT_DBG(slot, "reasoning ended, grammar triggers un-suppressed\n%s", ""); } } else { - // check if the start tag just appeared - if (!start_tag.empty() && slot.generated_text.size() >= start_tag.size()) { - auto 
tail = std::string_view(slot.generated_text).substr( - slot.generated_text.size() - start_tag.size()); - if (tail == start_tag) { - slot.in_reasoning = true; - common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), true); - SLT_DBG(slot, "reasoning started, grammar triggers suppressed\n%s", ""); - } + // check if the start tag just appeared at the end of generated_text + if (!start_tag.empty() + && slot.generated_text.size() >= start_tag.size() + && slot.generated_text.compare( + slot.generated_text.size() - start_tag.size(), + start_tag.size(), start_tag) == 0) { + slot.in_reasoning = true; + common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), true); + SLT_DBG(slot, "reasoning started, grammar triggers suppressed\n%s", ""); } } }