(wip) redirect minja calls

This commit is contained in:
Xuan Son Nguyen 2026-01-01 23:33:23 +01:00
parent a66e4a4f5d
commit 4b71c285db
2 changed files with 82 additions and 32 deletions

View File

@ -7,8 +7,12 @@
#include "log.h"
#include "regex-partial.h"
#include <minja/chat-template.hpp>
#include <minja/minja.hpp>
// #include <minja/chat-template.hpp>
// #include <minja/minja.hpp>
#include "jinja/jinja-parser.h"
#include "jinja/jinja-value.h"
#include "jinja/jinja-vm.h"
#include <algorithm>
#include <cstdio>
@ -135,7 +139,46 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
return diffs;
}
typedef minja::chat_template common_chat_template;
struct common_chat_template {
jinja::program prog;
std::string bos_tok;
std::string eos_tok;
std::string src;
common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) {
jinja::lexer lexer;
jinja::preprocess_options options;
options.trim_blocks = false;
options.lstrip_blocks = false;
auto lexer_res = lexer.tokenize(src, options);
prog = jinja::parse_from_tokens(lexer_res);
this->src = lexer_res.preprocessed_source;
this->bos_tok = bos_token;
this->eos_tok = eos_token;
}
const std::string & source() const { return src; }
const std::string & bos_token() const { return bos_tok; }
const std::string & eos_token() const { return eos_tok; }
static json add_system(const json &, const std::string &) {
throw std::runtime_error("common_chat_template::add_system not implemented");
}
// this is just for testing. it will be removed later
struct chat_template_caps {
bool supports_tools = true;
bool supports_tool_calls = true;
bool supports_tool_responses = true;
bool supports_system_role = true;
bool supports_parallel_tool_calls = true;
bool requires_typed_content = true;
};
chat_template_caps original_caps() const {
return chat_template_caps();
}
};
struct common_chat_templates {
bool add_bos;
@ -627,14 +670,14 @@ common_chat_templates_ptr common_chat_templates_init(
tmpls->add_bos = add_bos;
tmpls->add_eos = add_eos;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
tmpls->template_default = std::make_unique<common_chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
tmpls->template_default = std::make_unique<common_chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
}
if (!template_tool_use_src.empty()) {
try {
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
tmpls->template_tool_use = std::make_unique<common_chat_template>(template_tool_use_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
}
@ -737,34 +780,40 @@ static std::string apply(
const std::optional<json> & tools_override = std::nullopt,
const std::optional<json> & additional_context = std::nullopt)
{
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
if (tools_override) {
tmpl_inputs.tools = *tools_override;
} else {
tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
}
tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
tmpl_inputs.extra_context = inputs.extra_context;
tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking;
if (additional_context) {
tmpl_inputs.extra_context.merge_patch(*additional_context);
}
// TODO: add flag to control date/time, if only for testing purposes.
// tmpl_inputs.now = std::chrono::system_clock::now();
// TODO IMPORTANT: IMPORVE THIS
minja::chat_template_options tmpl_opts;
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
jinja::context ctx;
ctx.source = tmpl.source(); // for debugging
nlohmann::json inp = nlohmann::json{
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
};
if (additional_context.has_value()) {
// TODO: merge properly instead of overwriting
for (const auto & [k, v] : additional_context->items()) {
inp[k] = v;
}
}
if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());
if (inputs.add_generation_prompt) {
inp["add_generation_prompt"] = true;
}
return result;
if (inputs.add_bos) {
inp["bos_token"] = tmpl.bos_token();
}
if (inputs.add_eos) {
inp["eos_token"] = tmpl.eos_token();
}
// TODO: more inputs?
jinja::global_from_json(ctx, inp);
// render
jinja::vm vm(ctx);
const jinja::value results = vm.execute(tmpl.prog);
auto parts = vm.gather_string_parts(results);
return parts->as_string().str();
}
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {

View File

@ -125,6 +125,7 @@ struct expression : public statement {
struct program : public statement {
statements body;
program() = default;
explicit program(statements && body) : body(std::move(body)) {}
std::string type() const override { return "Program"; }
value execute_impl(context &) override {
@ -562,7 +563,7 @@ struct vm {
context & ctx;
explicit vm(context & ctx) : ctx(ctx) {}
value_array execute(program & prog) {
value_array execute(const program & prog) {
value_array results = mk_val<value_array>();
for (auto & stmt : prog.body) {
value res = stmt->execute(ctx);