(wip) redirect minja calls
This commit is contained in:
parent
a66e4a4f5d
commit
4b71c285db
111
common/chat.cpp
111
common/chat.cpp
|
|
@ -7,8 +7,12 @@
|
|||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <minja/chat-template.hpp>
|
||||
#include <minja/minja.hpp>
|
||||
// #include <minja/chat-template.hpp>
|
||||
// #include <minja/minja.hpp>
|
||||
|
||||
#include "jinja/jinja-parser.h"
|
||||
#include "jinja/jinja-value.h"
|
||||
#include "jinja/jinja-vm.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
|
|
@ -135,7 +139,46 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
|
|||
return diffs;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
// Chat template wrapper built on the jinja engine: holds the parsed program,
// the preprocessed template source, and the BOS/EOS token strings.
struct common_chat_template {
|
||||
// Parsed template; assigned in the constructor from the lexer output.
jinja::program prog;
|
||||
// BOS token string as passed to the constructor.
std::string bos_tok;
|
||||
// EOS token string as passed to the constructor.
std::string eos_tok;
|
||||
// Preprocessed template source as reported by the lexer (not the raw input).
std::string src;
|
||||
// Tokenizes `src` with trim_blocks and lstrip_blocks disabled, parses the
// tokens into `prog`, then stores the lexer's preprocessed source and the
// given BOS/EOS tokens.
// NOTE(review): tokenize/parse presumably throw on malformed templates --
// construction sites visible in this file wrap it in try/catch; confirm.
common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) {
|
||||
jinja::lexer lexer;
|
||||
jinja::preprocess_options options;
|
||||
options.trim_blocks = false;
|
||||
options.lstrip_blocks = false;
|
||||
auto lexer_res = lexer.tokenize(src, options);
|
||||
prog = jinja::parse_from_tokens(lexer_res);
|
||||
|
||||
// The parameter `src` shadows the member, hence the explicit `this->`.
this->src = lexer_res.preprocessed_source;
|
||||
this->bos_tok = bos_token;
|
||||
this->eos_tok = eos_token;
|
||||
}
|
||||
|
||||
// Read-only accessors for the stored source and token strings.
const std::string & source() const { return src; }
|
||||
const std::string & bos_token() const { return bos_tok; }
|
||||
const std::string & eos_token() const { return eos_tok; }
|
||||
// Placeholder: ignores both arguments and always throws std::runtime_error.
static json add_system(const json &, const std::string &) {
|
||||
throw std::runtime_error("common_chat_template::add_system not implemented");
|
||||
}
|
||||
|
||||
|
||||
// this is just for testing. it will be removed later
|
||||
// Capability flags; every flag defaults to true.
struct chat_template_caps {
|
||||
bool supports_tools = true;
|
||||
bool supports_tool_calls = true;
|
||||
bool supports_tool_responses = true;
|
||||
bool supports_system_role = true;
|
||||
bool supports_parallel_tool_calls = true;
|
||||
bool requires_typed_content = true;
|
||||
};
|
||||
// Returns a default-constructed caps object (all flags true); nothing is
// derived from the template itself.
chat_template_caps original_caps() const {
|
||||
return chat_template_caps();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct common_chat_templates {
|
||||
bool add_bos;
|
||||
|
|
@ -627,14 +670,14 @@ common_chat_templates_ptr common_chat_templates_init(
|
|||
tmpls->add_bos = add_bos;
|
||||
tmpls->add_eos = add_eos;
|
||||
try {
|
||||
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
|
||||
tmpls->template_default = std::make_unique<common_chat_template>(default_template_src, token_bos, token_eos);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
|
||||
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
|
||||
tmpls->template_default = std::make_unique<common_chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
|
||||
}
|
||||
if (!template_tool_use_src.empty()) {
|
||||
try {
|
||||
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
|
||||
tmpls->template_tool_use = std::make_unique<common_chat_template>(template_tool_use_src, token_bos, token_eos);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
|
||||
}
|
||||
|
|
@ -737,34 +780,40 @@ static std::string apply(
|
|||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt)
|
||||
{
|
||||
minja::chat_template_inputs tmpl_inputs;
|
||||
tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages;
|
||||
if (tools_override) {
|
||||
tmpl_inputs.tools = *tools_override;
|
||||
} else {
|
||||
tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools;
|
||||
}
|
||||
tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
|
||||
tmpl_inputs.extra_context = inputs.extra_context;
|
||||
tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking;
|
||||
if (additional_context) {
|
||||
tmpl_inputs.extra_context.merge_patch(*additional_context);
|
||||
}
|
||||
// TODO: add flag to control date/time, if only for testing purposes.
|
||||
// tmpl_inputs.now = std::chrono::system_clock::now();
|
||||
// TODO IMPORTANT: IMPROVE THIS
|
||||
|
||||
minja::chat_template_options tmpl_opts;
|
||||
// To avoid double BOS / EOS tokens, we're manually removing beginning / trailing tokens
|
||||
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
|
||||
// may be needed inside the template / between messages too.
|
||||
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
|
||||
if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
|
||||
result = result.substr(tmpl.bos_token().size());
|
||||
jinja::context ctx;
|
||||
ctx.source = tmpl.source(); // for debugging
|
||||
|
||||
nlohmann::json inp = nlohmann::json{
|
||||
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
|
||||
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
|
||||
};
|
||||
if (additional_context.has_value()) {
|
||||
// TODO: merge properly instead of overwriting
|
||||
for (const auto & [k, v] : additional_context->items()) {
|
||||
inp[k] = v;
|
||||
}
|
||||
}
|
||||
if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
|
||||
result = result.substr(0, result.size() - tmpl.eos_token().size());
|
||||
if (inputs.add_generation_prompt) {
|
||||
inp["add_generation_prompt"] = true;
|
||||
}
|
||||
return result;
|
||||
if (inputs.add_bos) {
|
||||
inp["bos_token"] = tmpl.bos_token();
|
||||
}
|
||||
if (inputs.add_eos) {
|
||||
inp["eos_token"] = tmpl.eos_token();
|
||||
}
|
||||
// TODO: more inputs?
|
||||
|
||||
jinja::global_from_json(ctx, inp);
|
||||
|
||||
// render
|
||||
jinja::vm vm(ctx);
|
||||
const jinja::value results = vm.execute(tmpl.prog);
|
||||
auto parts = vm.gather_string_parts(results);
|
||||
|
||||
return parts->as_string().str();
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
|
|
|
|||
|
|
@ -125,6 +125,7 @@ struct expression : public statement {
|
|||
struct program : public statement {
|
||||
statements body;
|
||||
|
||||
program() = default;
|
||||
explicit program(statements && body) : body(std::move(body)) {}
|
||||
std::string type() const override { return "Program"; }
|
||||
value execute_impl(context &) override {
|
||||
|
|
@ -562,7 +563,7 @@ struct vm {
|
|||
context & ctx;
|
||||
explicit vm(context & ctx) : ctx(ctx) {}
|
||||
|
||||
value_array execute(program & prog) {
|
||||
value_array execute(const program & prog) {
|
||||
value_array results = mk_val<value_array>();
|
||||
for (auto & stmt : prog.body) {
|
||||
value res = stmt->execute(ctx);
|
||||
|
|
|
|||
Loading…
Reference in New Issue