From 4b71c285dbfefb22f1e2a0b86609351f3bfa2333 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 1 Jan 2026 23:33:23 +0100 Subject: [PATCH] (wip) redirect minja calls --- common/chat.cpp | 111 +++++++++++++++++++++++++++++----------- common/jinja/jinja-vm.h | 3 +- 2 files changed, 82 insertions(+), 32 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 0a426f4478..82c742ee18 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -7,8 +7,12 @@ #include "log.h" #include "regex-partial.h" -#include -#include +// #include +// #include + +#include "jinja/jinja-parser.h" +#include "jinja/jinja-value.h" +#include "jinja/jinja-vm.h" #include #include @@ -135,7 +139,46 @@ std::vector common_chat_msg_diff::compute_diffs(const comm return diffs; } -typedef minja::chat_template common_chat_template; +struct common_chat_template { + jinja::program prog; + std::string bos_tok; + std::string eos_tok; + std::string src; + common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) { + jinja::lexer lexer; + jinja::preprocess_options options; + options.trim_blocks = false; + options.lstrip_blocks = false; + auto lexer_res = lexer.tokenize(src, options); + prog = jinja::parse_from_tokens(lexer_res); + + this->src = lexer_res.preprocessed_source; + this->bos_tok = bos_token; + this->eos_tok = eos_token; + } + + const std::string & source() const { return src; } + const std::string & bos_token() const { return bos_tok; } + const std::string & eos_token() const { return eos_tok; } + static json add_system(const json &, const std::string &) { + throw std::runtime_error("common_chat_template::add_system not implemented"); + } + + + // this is just for testing. it will be removed later + struct chat_template_caps { + bool supports_tools = true; + bool supports_tool_calls = true; + bool supports_tool_responses = true; + bool supports_system_role = true; + bool supports_parallel_tool_calls = true; + bool requires_typed_content = true; + }; + chat_template_caps original_caps() const { + return chat_template_caps(); + } + +}; struct common_chat_templates { bool add_bos; @@ -627,14 +670,14 @@ common_chat_templates_ptr common_chat_templates_init( tmpls->add_bos = add_bos; tmpls->add_eos = add_eos; try { - tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); + tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); } catch (const std::exception & e) { LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what()); - tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); + tmpls->template_default = std::make_unique(CHATML_TEMPLATE_SRC, token_bos, token_eos); } if (!template_tool_use_src.empty()) { try { - tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); + tmpls->template_tool_use = std::make_unique(template_tool_use_src, token_bos, token_eos); } catch (const std::exception & e) { LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what()); } @@ -737,34 +780,40 @@ static std::string apply( const std::optional & tools_override = std::nullopt, const std::optional & additional_context = std::nullopt) { - minja::chat_template_inputs tmpl_inputs; - tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages; - if (tools_override) { - tmpl_inputs.tools = *tools_override; - } else { - tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools; - } - tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt; - tmpl_inputs.extra_context = inputs.extra_context; - tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking; - if (additional_context) { - tmpl_inputs.extra_context.merge_patch(*additional_context); - } - // TODO: add flag to control date/time, if only for testing purposes. - // tmpl_inputs.now = std::chrono::system_clock::now(); + // TODO IMPORTANT: IMPORVE THIS - minja::chat_template_options tmpl_opts; - // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens - // instead of using `chat_template_options.use_bos_token = false`, since these tokens - // may be needed inside the template / between messages too. - auto result = tmpl.apply(tmpl_inputs, tmpl_opts); - if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) { - result = result.substr(tmpl.bos_token().size()); + jinja::context ctx; + ctx.source = tmpl.source(); // for debugging + + nlohmann::json inp = nlohmann::json{ + {"messages", messages_override.has_value() ? *messages_override : inputs.messages}, + {"tools", tools_override.has_value() ? *tools_override : inputs.tools}, + }; + if (additional_context.has_value()) { + // TODO: merge properly instead of overwriting + for (const auto & [k, v] : additional_context->items()) { + inp[k] = v; + } } - if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) { - result = result.substr(0, result.size() - tmpl.eos_token().size()); + if (inputs.add_generation_prompt) { + inp["add_generation_prompt"] = true; } - return result; + if (inputs.add_bos) { + inp["bos_token"] = tmpl.bos_token(); + } + if (inputs.add_eos) { + inp["eos_token"] = tmpl.eos_token(); + } + // TODO: more inputs? + + jinja::global_from_json(ctx, inp); + + // render + jinja::vm vm(ctx); + const jinja::value results = vm.execute(tmpl.prog); + auto parts = vm.gather_string_parts(results); + + return parts->as_string().str(); } static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { diff --git a/common/jinja/jinja-vm.h b/common/jinja/jinja-vm.h index faee1559cf..c1f91dd81f 100644 --- a/common/jinja/jinja-vm.h +++ b/common/jinja/jinja-vm.h @@ -125,6 +125,7 @@ struct expression : public statement { struct program : public statement { statements body; + program() = default; explicit program(statements && body) : body(std::move(body)) {} std::string type() const override { return "Program"; } value execute_impl(context &) override { @@ -562,7 +563,7 @@ struct vm { context & ctx; explicit vm(context & ctx) : ctx(ctx) {} - value_array execute(program & prog) { + value_array execute(const program & prog) { value_array results = mk_val(); for (auto & stmt : prog.body) { value res = stmt->execute(ctx);