From a21d219a0d8d11bbc51fd98475a812e980b667d2 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Sun, 15 Mar 2026 16:02:39 +0100 Subject: [PATCH] This is getting more and more complicated by the minute... --- common/arg.cpp | 2 ++ common/chat.cpp | 5 ++++- common/common.h | 4 +++- common/sampling.cpp | 13 ++++++++++++- tools/server/server-common.cpp | 3 ++- tools/server/server-context.cpp | 14 ++++++++++++-- tools/server/server-task.cpp | 3 ++- tools/server/tests/unit/test_chat_completion.py | 1 + 8 files changed, 38 insertions(+), 7 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 10aa1b5e4f..9999c59969 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1833,6 +1833,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()), [](common_params & params, const std::string & value) { params.sampling.grammar = value; + params.sampling.grammar_external = true; } ).set_sparam()); add_opt(common_arg( @@ -1840,6 +1841,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "file to read grammar from", [](common_params & params, const std::string & value) { params.sampling.grammar = read_file(value); + params.sampling.grammar_external = true; } ).set_sparam()); add_opt(common_arg( diff --git a/common/chat.cpp b/common/chat.cpp index 947f8bf41c..82071096ca 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1629,7 +1629,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ } try { - LOG_DBG("Using differential autoparser\n"); + LOG_DBG("%s: using differential autoparser\n", __func__); struct autoparser::autoparser autoparser; autoparser.analyze_template(tmpl); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); @@ -1639,6 +1639,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ auto_params.thinking_end_tag = autoparser.reasoning.end; } auto_params.generation_prompt = params.generation_prompt; + common_peg_arena arena; + arena.load(auto_params.parser); + LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str()); return auto_params; } catch (const std::exception & e) { throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what()); diff --git a/common/common.h b/common/common.h index c9ab673d66..27d8bc342a 100644 --- a/common/common.h +++ b/common/common.h @@ -231,12 +231,14 @@ struct common_params_sampling { std::string grammar; // optional BNF-like grammar to constrain sampling bool grammar_lazy = false; std::vector grammar_triggers; // optional triggers (for lazy grammars) + bool grammar_external = false; // is the grammar set by the user explicitly? + // if so, we must not pass extra grammar prefill to it std::set preserved_tokens; std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens - // Grammar prefill: reasoning markers already present in the prompt suffix. + // Grammar prefill: tokens already present in the prompt generation message. // Fed to the grammar sampler (to advance past pre-existing tokens) and used // to determine the reasoning budget sampler's initial state. std::string grammar_prefill; diff --git a/common/sampling.cpp b/common/sampling.cpp index ebed0f19e6..6237e40a4c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -5,9 +5,11 @@ #include "reasoning-budget.h" #include +#include #include #include #include +#include // the ring buffer works similarly to std::deque, but with a fixed capacity // TODO: deduplicate with llama-impl.h @@ -254,11 +256,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st // Feed grammar prefill tokens to the grammar sampler so it advances past // reasoning markers that the template already placed in the prompt. std::vector prefill_tokens; - if (!params.grammar_prefill.empty() && vocab) { + if (!params.grammar_prefill.empty() && vocab && !params.grammar_external) { prefill_tokens = common_tokenize(vocab, params.grammar_prefill, false, true); + if (!prefill_tokens.empty()) { + std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true); + if (std::isspace(first_token[0]) && !std::isspace(params.grammar_prefill[0])) { + // Some tokenizers will add a space before the first special token, need to remove + prefill_tokens = std::vector(prefill_tokens.begin() + 1, prefill_tokens.end()); + } + } + if (grmr) { for (const auto & token : prefill_tokens) { llama_sampler_accept(grmr, token); + LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token); } } } diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index e0a6e86e7f..2f80bac1c1 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1085,7 +1085,8 @@ json oaicompat_chat_params_parse( if (!chat_params.grammar.empty()) { llama_params["grammar"] = chat_params.grammar; } - llama_params["grammar_lazy"] = chat_params.grammar_lazy; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + llama_params["grammar_external"] = body.contains("grammar"); auto grammar_triggers = json::array(); for (const auto & trigger : chat_params.grammar_triggers) { server_grammar_trigger ct(trigger); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index c47ad876cb..569a0b02f5 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -1151,10 +1152,19 @@ private: // initialize samplers if (task.need_sampling()) { - slot.smpl.reset(common_sampler_init(model, task.params.sampling)); + try { + slot.smpl.reset(common_sampler_init(model, task.params.sampling)); + } catch (std::exception & e) { + LOG_ERR("%s: error initializing samplers. Grammar was:\n%s\n\nGrammar prefill:\n'%s'\n", __func__, + task.params.sampling.grammar.c_str(), task.params.sampling.grammar_prefill.c_str()); + std::string err_msg = std::string("Failed to initialize samplers: ") + e.what(); + send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST); + return false; + } if (slot.smpl == nullptr) { - // for now, the only error that may happen here is invalid grammar + LOG_ERR("%s: error in parsing grammar. Grammar was:\n%s\n\nGrammar prefill:\n'%s'\n", __func__, + task.params.sampling.grammar.c_str(), task.params.sampling.grammar_prefill.c_str()); send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; } diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index a60ade594a..50b184e127 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -382,7 +382,8 @@ task_params server_task::params_from_json_cmpl( throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); } } else { - params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + params.sampling.grammar_external = json_value(data, "grammar_external", params.sampling.grammar_external); SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index d56a930f7c..1bcffd91b6 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -210,6 +210,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str): global server server.jinja = jinja + server.debug = True server.start() res = server.make_request("POST", "/chat/completions", data={ "max_tokens": n_predicted,