This is getting more and more complicated by the minute...

This commit is contained in:
Piotr Wilkin 2026-03-15 16:02:39 +01:00
parent a5b51007c1
commit a21d219a0d
8 changed files with 38 additions and 7 deletions

View File

@ -1833,6 +1833,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
[](common_params & params, const std::string & value) {
params.sampling.grammar = value;
params.sampling.grammar_external = true;
}
).set_sparam());
add_opt(common_arg(
@ -1840,6 +1841,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"file to read grammar from",
[](common_params & params, const std::string & value) {
params.sampling.grammar = read_file(value);
params.sampling.grammar_external = true;
}
).set_sparam());
add_opt(common_arg(

View File

@ -1629,7 +1629,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
}
try {
LOG_DBG("Using differential autoparser\n");
LOG_DBG("%s: using differential autoparser\n", __func__);
struct autoparser::autoparser autoparser;
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
@ -1639,6 +1639,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
auto_params.thinking_end_tag = autoparser.reasoning.end;
}
auto_params.generation_prompt = params.generation_prompt;
common_peg_arena arena;
arena.load(auto_params.parser);
LOG_DBG("%s: generated parser:\n%s\n\nparser generation prompt: %s\n", __func__, arena.dump(arena.root()).c_str(), auto_params.generation_prompt.c_str());
return auto_params;
} catch (const std::exception & e) {
throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());

View File

@ -231,12 +231,14 @@ struct common_params_sampling {
std::string grammar; // optional BNF-like grammar to constrain sampling
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
bool grammar_external = false; // is the grammar set by the user explicitly?
// if so, we must not pass extra grammar prefill to it
std::set<llama_token> preserved_tokens;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// Grammar prefill: reasoning markers already present in the prompt suffix.
// Grammar prefill: tokens already present in the prompt generation message.
// Fed to the grammar sampler (to advance past pre-existing tokens) and used
// to determine the reasoning budget sampler's initial state.
std::string grammar_prefill;

View File

@ -5,9 +5,11 @@
#include "reasoning-budget.h"
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstring>
#include <unordered_map>
#include <vector>
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
@ -254,11 +256,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
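// Feed grammar prefill tokens to the grammar sampler so it advances past
// tokens (e.g. reasoning markers) that the template already placed in the prompt.
// This is skipped when the grammar was set explicitly by the user (grammar_external).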
std::vector<llama_token> prefill_tokens;
if (!params.grammar_prefill.empty() && vocab) {
if (!params.grammar_prefill.empty() && vocab && !params.grammar_external) {
prefill_tokens = common_tokenize(vocab, params.grammar_prefill, false, true);
if (!prefill_tokens.empty()) {
std::string first_token = common_token_to_piece(vocab, prefill_tokens[0], true);
if (std::isspace(first_token[0]) && !std::isspace(params.grammar_prefill[0])) {
// Some tokenizers add a space before the first special token; when the first
// token starts with whitespace but the prefill text does not, drop that token
prefill_tokens = std::vector<llama_token>(prefill_tokens.begin() + 1, prefill_tokens.end());
}
}
if (grmr) {
for (const auto & token : prefill_tokens) {
llama_sampler_accept(grmr, token);
LOG_DBG("%s: accepted prefill token (%d)\n", __func__, token);
}
}
}

View File

@ -1085,7 +1085,8 @@ json oaicompat_chat_params_parse(
if (!chat_params.grammar.empty()) {
llama_params["grammar"] = chat_params.grammar;
}
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
llama_params["grammar_external"] = body.contains("grammar");
auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) {
server_grammar_trigger ct(trigger);

View File

@ -15,6 +15,7 @@
#include <algorithm>
#include <cstddef>
#include <cinttypes>
#include <exception>
#include <memory>
#include <filesystem>
@ -1151,10 +1152,19 @@ private:
// initialize samplers
if (task.need_sampling()) {
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
try {
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
} catch (std::exception & e) {
LOG_ERR("%s: error initializing samplers. Grammar was:\n%s\n\nGrammar prefill:\n'%s'\n", __func__,
task.params.sampling.grammar.c_str(), task.params.sampling.grammar_prefill.c_str());
std::string err_msg = std::string("Failed to initialize samplers: ") + e.what();
send_error(task, err_msg, ERROR_TYPE_INVALID_REQUEST);
return false;
}
if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar
LOG_ERR("%s: error in parsing grammar. Grammar was:\n%s\n\nGrammar prefill:\n'%s'\n", __func__,
task.params.sampling.grammar.c_str(), task.params.sampling.grammar_prefill.c_str());
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false;
}

View File

@ -382,7 +382,8 @@ task_params server_task::params_from_json_cmpl(
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
params.sampling.grammar_external = json_value(data, "grammar_external", params.sampling.grammar_external);
SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");

View File

@ -210,6 +210,7 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.debug = True
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,