Merge c25aed1f5c into ec2b787ebe
This commit is contained in:
commit
8e57e66a4d
|
|
@ -483,6 +483,13 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
|
|||
}
|
||||
}
|
||||
|
||||
void common_sampler_set_grammar_trigger_suppressed(struct common_sampler * gsmpl, bool suppressed) {
|
||||
if (!gsmpl || !gsmpl->grmr) {
|
||||
return;
|
||||
}
|
||||
llama_sampler_grammar_set_trigger_suppressed(gsmpl->grmr, suppressed);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
|
||||
if (!gsmpl) {
|
||||
return nullptr;
|
||||
|
|
|
|||
|
|
@ -87,6 +87,10 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
|
|||
|
||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||
|
||||
// suppress or un-suppress grammar trigger detection (e.g. during reasoning/thinking blocks)
|
||||
// when suppressed, the grammar still buffers tokens but does not check for triggers
|
||||
void common_sampler_set_grammar_trigger_suppressed(struct common_sampler * gsmpl, bool suppressed);
|
||||
|
||||
// helpers
|
||||
|
||||
// access the internal list of current candidate tokens
|
||||
|
|
|
|||
|
|
@ -1380,6 +1380,13 @@ extern "C" {
|
|||
const llama_token * trigger_tokens,
|
||||
size_t num_trigger_tokens);
|
||||
|
||||
/// @details Suppress or un-suppress trigger detection on a grammar sampler.
|
||||
/// When suppressed, the grammar still buffers tokens but does not check for triggers.
|
||||
/// Useful for suppressing grammar activation during reasoning/thinking blocks.
|
||||
/// No-op if the sampler is not a grammar sampler.
|
||||
LLAMA_API void llama_sampler_grammar_set_trigger_suppressed(
|
||||
struct llama_sampler * smpl,
|
||||
bool suppressed);
|
||||
|
||||
/// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
|
||||
|
|
|
|||
|
|
@ -1185,6 +1185,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||
/* .partial_utf8 = */ {},
|
||||
/* .lazy = */ false,
|
||||
/* .awaiting_trigger = */ false,
|
||||
/* .trigger_suppressed = */ false,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_buffer_positions = */ {},
|
||||
/* .trigger_tokens = */ {},
|
||||
|
|
@ -1291,6 +1292,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||
/* .partial_utf8 = */ {},
|
||||
/* .lazy = */ lazy,
|
||||
/* .awaiting_trigger = */ lazy,
|
||||
/* .trigger_suppressed = */ false,
|
||||
/* .trigger_buffer = */ "",
|
||||
/* .trigger_buffer_positions = */ {},
|
||||
std::move(vec_trigger_tokens),
|
||||
|
|
@ -1314,6 +1316,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
|
|||
grammar.partial_utf8,
|
||||
grammar.lazy,
|
||||
grammar.awaiting_trigger,
|
||||
grammar.trigger_suppressed,
|
||||
grammar.trigger_buffer,
|
||||
grammar.trigger_buffer_positions,
|
||||
grammar.trigger_tokens,
|
||||
|
|
@ -1385,6 +1388,15 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|||
const auto & piece = grammar.vocab->token_to_piece(token);
|
||||
|
||||
if (grammar.awaiting_trigger) {
|
||||
// When trigger is suppressed (e.g. during reasoning), still buffer tokens but skip trigger detection
|
||||
if (grammar.trigger_suppressed) {
|
||||
auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
|
||||
grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
|
||||
grammar.trigger_buffer += piece;
|
||||
LLAMA_LOG_DEBUG("Grammar trigger suppressed, buffering token %d (`%s`)\n", token, piece.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
|
||||
grammar.awaiting_trigger = false;
|
||||
grammar.trigger_buffer.clear();
|
||||
|
|
|
|||
|
|
@ -141,6 +141,7 @@ struct llama_grammar {
|
|||
// (useful e.g. for tool_choice=required)
|
||||
bool lazy = false;
|
||||
bool awaiting_trigger = false; // Initialized to true for lazy grammars only
|
||||
bool trigger_suppressed = false; // When true, trigger detection is suppressed (e.g. during reasoning)
|
||||
std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found.
|
||||
std::vector<token_pos> trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
|
||||
std::vector<llama_token> trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
|
||||
|
|
|
|||
|
|
@ -2529,6 +2529,16 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
|
|||
/* .backend_set_input = */ nullptr,
|
||||
};
|
||||
|
||||
void llama_sampler_grammar_set_trigger_suppressed(struct llama_sampler * smpl, bool suppressed) {
|
||||
if (!smpl || smpl->iface != &llama_sampler_grammar_i) {
|
||||
return;
|
||||
}
|
||||
auto * ctx = (llama_sampler_grammar *) smpl->ctx;
|
||||
if (ctx->grammar) {
|
||||
ctx->grammar->trigger_suppressed = suppressed;
|
||||
}
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
|
|
|
|||
|
|
@ -936,7 +936,71 @@ static void test_peg_parser(common_chat_templates * tmpls,
|
|||
throw std::runtime_error("Failed to build grammar: " + parser.params_.grammar);
|
||||
}
|
||||
|
||||
// Find the earliest trigger position to determine the constrained portion
|
||||
// Determine reasoning regions in tc.input so we can suppress grammar triggers inside them.
|
||||
// A reasoning region spans from thinking_start_tag to thinking_end_tag.
|
||||
// If generation_prompt contains the start tag (without a matching end), reasoning starts
|
||||
// before tc.input, so position 0 is already inside reasoning.
|
||||
std::vector<std::pair<size_t, size_t>> reasoning_regions; // [start, end) in tc.input
|
||||
{
|
||||
const auto & start_tag = parser.params_.thinking_start_tag;
|
||||
const auto & end_tag = parser.params_.thinking_end_tag;
|
||||
if (!end_tag.empty()) {
|
||||
// check if generation_prompt puts us inside reasoning at the start of tc.input
|
||||
bool in_reasoning = false;
|
||||
if (!start_tag.empty()) {
|
||||
const auto & gen_prompt = parser.params_.generation_prompt;
|
||||
auto last_start = gen_prompt.rfind(start_tag);
|
||||
if (last_start != std::string::npos) {
|
||||
auto last_end = gen_prompt.rfind(end_tag);
|
||||
if (last_end == std::string::npos || last_end < last_start) {
|
||||
in_reasoning = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t search_from = 0;
|
||||
size_t region_start = in_reasoning ? 0 : std::string::npos;
|
||||
|
||||
while (search_from < tc.input.size()) {
|
||||
if (in_reasoning) {
|
||||
auto end_pos = tc.input.find(end_tag, search_from);
|
||||
if (end_pos != std::string::npos) {
|
||||
reasoning_regions.push_back({region_start, end_pos + end_tag.size()});
|
||||
search_from = end_pos + end_tag.size();
|
||||
in_reasoning = false;
|
||||
} else {
|
||||
// reasoning extends to end of input
|
||||
reasoning_regions.push_back({region_start, tc.input.size()});
|
||||
break;
|
||||
}
|
||||
} else if (!start_tag.empty()) {
|
||||
auto start_pos = tc.input.find(start_tag, search_from);
|
||||
if (start_pos != std::string::npos) {
|
||||
region_start = start_pos;
|
||||
search_from = start_pos + start_tag.size();
|
||||
in_reasoning = true;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper: check if a position falls inside any reasoning region
|
||||
auto is_in_reasoning = [&reasoning_regions](size_t pos) -> bool {
|
||||
for (const auto & [start, end] : reasoning_regions) {
|
||||
if (pos >= start && pos < end) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Find the earliest trigger position to determine the constrained portion,
|
||||
// skipping triggers that fall inside reasoning regions.
|
||||
auto earliest_trigger_pos = std::string::npos;
|
||||
for (const auto & trigger : parser.params_.grammar_triggers) {
|
||||
size_t pos = std::string::npos;
|
||||
|
|
@ -945,14 +1009,34 @@ static void test_peg_parser(common_chat_templates * tmpls,
|
|||
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
||||
{
|
||||
const auto & word = trigger.value;
|
||||
pos = tc.input.find(word);
|
||||
// find first occurrence outside reasoning
|
||||
size_t search_from = 0;
|
||||
while (search_from < tc.input.size()) {
|
||||
auto found = tc.input.find(word, search_from);
|
||||
if (found == std::string::npos) {
|
||||
break;
|
||||
}
|
||||
if (!is_in_reasoning(found)) {
|
||||
pos = found;
|
||||
break;
|
||||
}
|
||||
search_from = found + 1;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
{
|
||||
const auto & pattern = std::regex(trigger.value);
|
||||
if (std::regex_search(tc.input, match, pattern)) {
|
||||
pos = match.position(pattern.mark_count());
|
||||
auto search_str = tc.input;
|
||||
size_t offset = 0;
|
||||
while (std::regex_search(search_str, match, pattern)) {
|
||||
auto found = offset + match.position(pattern.mark_count());
|
||||
if (!is_in_reasoning(found)) {
|
||||
pos = found;
|
||||
break;
|
||||
}
|
||||
offset += match.position(0) + match.length(0);
|
||||
search_str = tc.input.substr(offset);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -970,7 +1054,11 @@ static void test_peg_parser(common_chat_templates * tmpls,
|
|||
if (mpos == std::string::npos) {
|
||||
mpos = match.position(0);
|
||||
}
|
||||
pos = mpos;
|
||||
// PATTERN_FULL matches the entire input, so if the match position
|
||||
// is in reasoning, skip it entirely
|
||||
if (!is_in_reasoning(mpos)) {
|
||||
pos = mpos;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -1425,6 +1513,50 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
|||
.expect_reasoning("I need to output the invoice details in JSON")
|
||||
.expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
|
||||
.run();
|
||||
|
||||
// tool call segment in reasoning
|
||||
tst.test(
|
||||
"Let's call a tool: <tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call></think>\n"
|
||||
"<tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>"
|
||||
)
|
||||
.enable_thinking(true)
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.tools({
|
||||
python_tool
|
||||
})
|
||||
.expect_reasoning("Let's call a tool: <tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.expect_tool_calls({
|
||||
{ "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
}
|
||||
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1103,6 +1103,13 @@ json oaicompat_chat_params_parse(
|
|||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
// Always pass thinking tags so the slot can track reasoning state
|
||||
// (used to suppress grammar triggers during reasoning blocks)
|
||||
if (!chat_params.thinking_end_tag.empty()) {
|
||||
llama_params["thinking_start_tag"] = chat_params.thinking_start_tag;
|
||||
llama_params["thinking_end_tag"] = chat_params.thinking_end_tag;
|
||||
}
|
||||
|
||||
// Reasoning budget: pass parameters through to sampling layer
|
||||
{
|
||||
int reasoning_budget = opt.reasoning_budget;
|
||||
|
|
|
|||
|
|
@ -92,6 +92,8 @@ struct server_slot {
|
|||
bool has_next_token = true;
|
||||
bool has_new_line = false;
|
||||
bool truncated = false;
|
||||
bool in_reasoning = false; // true when inside a thinking/reasoning block
|
||||
llama_token thinking_end_first_token = LLAMA_TOKEN_NULL; // first token of thinking end tag (for EOG interception)
|
||||
|
||||
stop_type stop;
|
||||
|
||||
|
|
@ -173,6 +175,8 @@ struct server_slot {
|
|||
generated_text = "";
|
||||
has_new_line = false;
|
||||
truncated = false;
|
||||
in_reasoning = false;
|
||||
thinking_end_first_token = LLAMA_TOKEN_NULL;
|
||||
stop = STOP_TYPE_NONE;
|
||||
stopping_word = "";
|
||||
n_sent_text = 0;
|
||||
|
|
@ -1181,6 +1185,29 @@ private:
|
|||
}
|
||||
|
||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
||||
|
||||
// determine initial reasoning state from generation prompt
|
||||
// if the generation prompt ends inside a thinking block, suppress grammar triggers initially
|
||||
if (!task.params.thinking_end_tag.empty()) {
|
||||
const auto & gen_prompt = task.params.sampling.generation_prompt;
|
||||
const auto & start_tag = task.params.thinking_start_tag;
|
||||
const auto & end_tag = task.params.thinking_end_tag;
|
||||
|
||||
// tokenize the thinking end tag so we can intercept EOG during reasoning
|
||||
auto end_tag_tokens = common_tokenize(ctx, end_tag, false, true);
|
||||
if (!end_tag_tokens.empty()) {
|
||||
slot.thinking_end_first_token = end_tag_tokens[0];
|
||||
}
|
||||
|
||||
auto last_start = start_tag.empty() ? std::string::npos : gen_prompt.rfind(start_tag);
|
||||
auto last_end = gen_prompt.rfind(end_tag);
|
||||
if (last_start != std::string::npos
|
||||
&& (last_end == std::string::npos || last_end < last_start)) {
|
||||
slot.in_reasoning = true;
|
||||
common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), true);
|
||||
SLT_DBG(slot, "starting in reasoning state, grammar triggers suppressed\n%s", "");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
slot.smpl.reset();
|
||||
}
|
||||
|
|
@ -1209,6 +1236,34 @@ private:
|
|||
}
|
||||
slot.has_next_token = true;
|
||||
|
||||
// update reasoning state and propagate to grammar trigger suppression
|
||||
if (!slot.task->params.thinking_end_tag.empty() && slot.smpl) {
|
||||
const auto & end_tag = slot.task->params.thinking_end_tag;
|
||||
const auto & start_tag = slot.task->params.thinking_start_tag;
|
||||
if (slot.in_reasoning) {
|
||||
// check if the end tag just appeared at the end of generated_text
|
||||
if (slot.generated_text.size() >= end_tag.size()
|
||||
&& slot.generated_text.compare(
|
||||
slot.generated_text.size() - end_tag.size(),
|
||||
end_tag.size(), end_tag) == 0) {
|
||||
slot.in_reasoning = false;
|
||||
common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), false);
|
||||
SLT_DBG(slot, "reasoning ended, grammar triggers un-suppressed\n%s", "");
|
||||
}
|
||||
} else {
|
||||
// check if the start tag just appeared at the end of generated_text
|
||||
if (!start_tag.empty()
|
||||
&& slot.generated_text.size() >= start_tag.size()
|
||||
&& slot.generated_text.compare(
|
||||
slot.generated_text.size() - start_tag.size(),
|
||||
start_tag.size(), start_tag) == 0) {
|
||||
slot.in_reasoning = true;
|
||||
common_sampler_set_grammar_trigger_suppressed(slot.smpl.get(), true);
|
||||
SLT_DBG(slot, "reasoning started, grammar triggers suppressed\n%s", "");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check if there is incomplete UTF-8 character at the end
|
||||
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
|
||||
|
||||
|
|
@ -2835,6 +2890,16 @@ private:
|
|||
|
||||
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
|
||||
|
||||
// if the model emits EOG while still inside a reasoning block,
|
||||
// force the first token of the thinking end tag instead
|
||||
if (slot.in_reasoning
|
||||
&& slot.thinking_end_first_token != LLAMA_TOKEN_NULL
|
||||
&& llama_vocab_is_eog(vocab, id)) {
|
||||
SLT_DBG(slot, "intercepted EOG during reasoning, forcing thinking end token %d\n",
|
||||
slot.thinking_end_first_token);
|
||||
id = slot.thinking_end_first_token;
|
||||
}
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
|
|
|
|||
|
|
@ -500,6 +500,10 @@ task_params server_task::params_from_json_cmpl(
|
|||
}
|
||||
}
|
||||
|
||||
// Parse thinking tags for reasoning state tracking (used to suppress grammar triggers during reasoning)
|
||||
params.thinking_start_tag = json_value(data, "thinking_start_tag", std::string());
|
||||
params.thinking_end_tag = json_value(data, "thinking_end_tag", std::string());
|
||||
|
||||
{
|
||||
params.sampling.logit_bias.clear();
|
||||
|
||||
|
|
|
|||
|
|
@ -83,6 +83,10 @@ struct task_params {
|
|||
// per-request parameters for chat parsing
|
||||
common_chat_parser_params chat_parser_params;
|
||||
|
||||
// thinking/reasoning tags for tracking reasoning state in the slot
|
||||
std::string thinking_start_tag;
|
||||
std::string thinking_end_tag;
|
||||
|
||||
// Embeddings
|
||||
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue