#include "chat.h"

#include "chat-auto-parser-helpers.h"
#include "chat-auto-parser.h"
#include "chat-peg-parser.h"
#include "common.h"
#include "ggml.h"
#include "json-schema-to-grammar.h"
#include "log.h"
#include "regex-partial.h"

#include "jinja/parser.h"
#include "jinja/value.h"
#include "jinja/runtime.h"
#include "jinja/caps.h"

#include <cstdio>
#include <cstdlib>
#include <exception>
#include <functional>

#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

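// Renders a time_point as local time using an strftime-style format string,
// e.g. format_time(now, "%d %b %Y") -> "26 Jul 2024" (illustrative output).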
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
    auto time = std::chrono::system_clock::to_time_t(now);
    auto local_time = *std::localtime(&time);
    std::ostringstream ss;
    ss << std::put_time(&local_time, format.c_str());
    auto res = ss.str();
    return res;
}

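// Best-effort parse of tool call arguments: strip one pair of surrounding quotes
// if present, then try to parse the result as JSON; if that fails, return the
// stripped string unchanged (e.g. '{"a": 1}' parses to an object, 'not json'
// stays a plain string).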
static json safe_args_parse(const std::string & to_parse) {
    std::string stripped = to_parse;
    if (to_parse.length() >= 2 && to_parse.front() == '"' && to_parse.back() == '"') {
        stripped = to_parse.substr(1, to_parse.length() - 2);
    }
    try {
        return json::parse(stripped);
    } catch (json::exception & e) {
        return stripped;
    }
}

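// Returns the suffix of `current` that was appended since `last`, for streaming
// deltas, e.g. string_diff("Hello", "Hello, world") == ", world". If `current` is
// a prefix of `last` (the previous generation ended on a partial stop word), an
// empty diff is returned; any other mismatch is an error.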
static std::string string_diff(const std::string & last, const std::string & current) {
    if (last.empty()) {
        return current;
    }
    if (!string_starts_with(current, last)) {
        if (string_starts_with(last, current)) {
            // This happens if the last generation ended on a partial stop word (not erased),
            // and the current ended on a stop word (erased).
            return "";
        }
        throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'");
    }
    return current.substr(last.size());
}

static bool has_content_or_tool_calls(const common_chat_msg & msg) {
    return !msg.content.empty() || !msg.tool_calls.empty();
}

json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
    if (!content.empty() && !content_parts.empty()) {
        throw std::runtime_error("Cannot specify both content and content_parts");
    }
    json jmsg {
        {"role", role},
    };
    if (!content.empty()) {
        jmsg["content"] = content;
    } else if (!content_parts.empty()) {
        if (concat_typed_text) {
            std::string text;
            for (const auto & part : content_parts) {
                if (part.type != "text") {
                    LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
                    continue;
                }
                if (!text.empty()) {
                    text += '\n';
                }
                text += part.text;
            }
            jmsg["content"] = text;
        } else {
            auto & parts = jmsg["content"] = json::array();
            for (const auto & part : content_parts) {
                parts.push_back({
                    {"type", part.type},
                    {"text", part.text},
                });
            }
        }
    } else {
        jmsg["content"] = "";
    }
    if (!reasoning_content.empty()) {
        jmsg["reasoning_content"] = reasoning_content;
    }
    if (!tool_name.empty()) {
        jmsg["name"] = tool_name;
    }
    if (!tool_call_id.empty()) {
        jmsg["tool_call_id"] = tool_call_id;
    }
    if (!tool_calls.empty()) {
        jmsg["tool_calls"] = json::array();
        auto & jtool_calls = jmsg["tool_calls"];
        for (const auto & tool_call : tool_calls) {
            json tc {
                {"type", "function"},
                {"function", {
                    {"name", tool_call.name},
                    {"arguments", tool_call.arguments},
                }},
            };
            if (!tool_call.id.empty()) {
                tc["id"] = tool_call.id;
            }
            // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
            // We only generate a random id for the ones that don't generate one by themselves
            // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
            // {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
            jtool_calls.push_back(tc);
        }
    }

    return jmsg;
}

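// Computes the streaming deltas between two snapshots of the same in-progress
// assistant message. For example, if the previous snapshot had content "Hel" and
// the new one has "Hello", a single diff with content_delta == "lo" is produced;
// new or growing tool calls are reported the same way, keyed by tool_call_index.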
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv,
                                                                      const common_chat_msg & msg_new) {
    std::vector<common_chat_msg_diff> diffs;
    if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) {
        diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3);
    } else {
        diffs.reserve(3);
    }

    // TODO: these can become expensive for long messages - how to optimize?
    if (msg_prv.reasoning_content != msg_new.reasoning_content) {
        auto & diff = diffs.emplace_back();
        diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content);
    }
    if (msg_prv.content != msg_new.content) {
        auto & diff = diffs.emplace_back();
        diff.content_delta = string_diff(msg_prv.content, msg_new.content);
    }

    if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) {
        std::string err = "Invalid diff: now finding fewer tool calls than before!\n";
        err += " Previous (" + std::to_string(msg_prv.tool_calls.size()) + "):\n";
        for (const auto & tc : msg_prv.tool_calls) {
            err += " - name: '" + tc.name + "', args: '" + tc.arguments + "'\n";
        }
        err += " Current (" + std::to_string(msg_new.tool_calls.size()) + "):\n";
        for (const auto & tc : msg_new.tool_calls) {
            err += " - name: '" + tc.name + "', args: '" + tc.arguments + "'\n";
        }
        err += " Current msg text content:\n" + msg_new.content + "\n";
        throw std::runtime_error(err);
    }

    if (!msg_prv.tool_calls.empty()) {
        const auto idx = msg_prv.tool_calls.size() - 1;
        const auto & pref = msg_prv.tool_calls[idx];
        const auto & newf = msg_new.tool_calls[idx];
        // Allow tool name to change during incremental parsing:
        // - empty -> non-empty (initial discovery)
        // - prefix -> longer string (name grows as more input is parsed)
        if (pref.name != newf.name && !pref.name.empty() && !newf.name.empty()) {
            // Check if one is a prefix of the other (for incremental parsing where names grow or shrink)
            bool is_prefix = (newf.name.rfind(pref.name, 0) == 0);
            if (!is_prefix) {
                LOG_ERR("Tool call mismatch: prev='%s' new='%s'\n", pref.name.c_str(), newf.name.c_str());
                throw std::runtime_error("Invalid diff: tool call mismatch!");
            }
        }
        const auto args_diff = string_diff(pref.arguments, newf.arguments);
        if (!args_diff.empty() || pref.id != newf.id || pref.name != newf.name) {
            auto & diff = diffs.emplace_back();
            diff.tool_call_index = idx;
            if (pref.id != newf.id || pref.name != newf.name) {
                diff.tool_call_delta.id = newf.id;
                diff.tool_call_delta.name = newf.name;
            }
            diff.tool_call_delta.arguments = args_diff;
        }
    }
    for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) {
        auto & diff = diffs.emplace_back();
        diff.tool_call_index = idx;
        diff.tool_call_delta = msg_new.tool_calls[idx];
    }

    return diffs;
}

using chat_template_caps = jinja::caps;

struct common_chat_templates {
    bool add_bos;
    bool add_eos;
    bool has_explicit_template; // Model had a built-in template, or a template override was specified.
    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
    std::unique_ptr<common_chat_template> template_tool_use;
};

common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
    if (tool_choice == "auto") {
        return COMMON_CHAT_TOOL_CHOICE_AUTO;
    }
    if (tool_choice == "none") {
        return COMMON_CHAT_TOOL_CHOICE_NONE;
    }
    if (tool_choice == "required") {
        return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
    }
    throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
}

bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
    common_chat_templates_inputs dummy_inputs;
    common_chat_msg msg;
    msg.role = "user";
    msg.content = "test";
    dummy_inputs.messages = { msg };
    dummy_inputs.enable_thinking = false;
    const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
    dummy_inputs.enable_thinking = true;
    const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
    return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
}

std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
    std::vector<common_chat_msg> msgs;

    try {
        if (!messages.is_array()) {
            throw std::invalid_argument("Expected 'messages' to be an array, got " + messages.dump());
        }

        for (const auto & message : messages) {
            if (!message.is_object()) {
                throw std::invalid_argument("Expected 'message' to be an object, got " + message.dump());
            }

            common_chat_msg msg;
            if (!message.contains("role")) {
                throw std::invalid_argument("Missing 'role' in message: " + message.dump());
            }
            msg.role = message.at("role");

            auto has_content = message.contains("content");
            auto has_tool_calls = message.contains("tool_calls");
            if (has_content) {
                const auto & content = message.at("content");
                if (content.is_string()) {
                    msg.content = content;
                } else if (content.is_array()) {
                    for (const auto & part : content) {
                        if (!part.contains("type")) {
                            throw std::invalid_argument("Missing content part type: " + part.dump());
                        }
                        const auto & type = part.at("type");
                        if (type != "text") {
                            throw std::invalid_argument("Unsupported content part type: " + type.dump());
                        }
                        common_chat_msg_content_part msg_part;
                        msg_part.type = type;
                        msg_part.text = part.at("text");
                        msg.content_parts.push_back(msg_part);
                    }
                } else if (!content.is_null()) {
                    throw std::invalid_argument("Invalid 'content' type: expected string or array, got " +
                                                content.dump() +
                                                " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
                }
            }
            if (has_tool_calls) {
                for (const auto & tool_call : message.at("tool_calls")) {
                    common_chat_tool_call tc;
                    if (!tool_call.contains("type")) {
                        throw std::invalid_argument("Missing tool call type: " + tool_call.dump());
                    }
                    const auto & type = tool_call.at("type");
                    if (type != "function") {
                        throw std::invalid_argument("Unsupported tool call type: " + tool_call.dump());
                    }
                    if (!tool_call.contains("function")) {
                        throw std::invalid_argument("Missing tool call function: " + tool_call.dump());
                    }
                    const auto & fc = tool_call.at("function");
                    if (!fc.contains("name")) {
                        throw std::invalid_argument("Missing tool call name: " + tool_call.dump());
                    }
                    tc.name = fc.at("name");
                    const auto & args = fc.at("arguments");
                    if (args.is_string()) {
                        tc.arguments = args;
                    } else {
                        tc.arguments = args.dump();
                    }
                    if (tool_call.contains("id")) {
                        tc.id = tool_call.at("id");
                    }
                    msg.tool_calls.push_back(tc);
                }
            }
            if (!has_content && !has_tool_calls) {
                throw std::invalid_argument(
                    "Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & "
                    "https://github.com/ggml-org/llama.cpp/issues/12279)");
            }
            if (message.contains("reasoning_content")) {
                msg.reasoning_content = message.at("reasoning_content");
            }
            if (message.contains("name")) {
                msg.tool_name = message.at("name");
            }
            if (message.contains("tool_call_id")) {
                msg.tool_call_id = message.at("tool_call_id");
            }

            msgs.push_back(msg);
        }
    } catch (const std::exception & e) {
        // @ngxson : disable otherwise it's bloating the API response
        // printf("%s\n", std::string("; messages = ") + messages.dump(2));
        throw std::runtime_error("Failed to parse messages: " + std::string(e.what()));
    }

    return msgs;
}

json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
    json messages = json::array();
    for (const auto & msg : msgs) {
        json jmsg = msg.to_json_oaicompat(concat_typed_text);
        messages.push_back(jmsg);
    }
    return messages;
}

std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
    std::vector<common_chat_tool> result;

    try {
        if (!tools.is_null()) {
            if (!tools.is_array()) {
                throw std::invalid_argument("Expected 'tools' to be an array, got " + tools.dump());
            }
            for (const auto & tool : tools) {
                if (!tool.contains("type")) {
                    throw std::invalid_argument("Missing tool type: " + tool.dump());
                }
                const auto & type = tool.at("type");
                if (!type.is_string() || type != "function") {
                    throw std::invalid_argument("Unsupported tool type: " + tool.dump());
                }
                if (!tool.contains("function")) {
                    throw std::invalid_argument("Missing tool function: " + tool.dump());
                }

                const auto & function = tool.at("function");
                result.push_back({
                    /* .name = */ function.at("name"),
                    /* .description = */ function.value("description", ""),
                    /* .parameters = */ function.value("parameters", json::object()).dump(),
                });
            }
        }
    } catch (const std::exception & e) {
        throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
    }

    return result;
}

json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
    if (tools.empty()) {
        return json();
    }

    auto result = json::array();
    for (const auto & tool : tools) {
        result.push_back({
            { "type", "function" },
            { "function",
              {
                  { "name", tool.name },
                  { "description", tool.description },
                  { "parameters", json::parse(tool.parameters) },
              } },
        });
    }
    return result;
}

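// Converts a single message diff into an OpenAI-compatible streaming "delta"
// object, e.g. {"content": "lo"} or
// {"tool_calls": [{"index": 0, "id": "...", "type": "function",
//                  "function": {"name": "get_weather", "arguments": "{\"city\""}}]}
// (shape shown for illustration only, "get_weather" is a made-up tool name).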
json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
    json delta = json::object();
    if (!diff.reasoning_content_delta.empty()) {
        delta["reasoning_content"] = diff.reasoning_content_delta;
    }
    if (!diff.content_delta.empty()) {
        delta["content"] = diff.content_delta;
    }
    if (diff.tool_call_index != std::string::npos) {
        json tool_call;
        tool_call["index"] = diff.tool_call_index;
        if (!diff.tool_call_delta.id.empty()) {
            tool_call["id"] = diff.tool_call_delta.id;
            tool_call["type"] = "function";
        }
        if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) {
            json function = json::object();
            if (!diff.tool_call_delta.name.empty()) {
                function["name"] = diff.tool_call_delta.name;
            }
            if (!diff.tool_call_delta.arguments.empty()) {
                function["arguments"] = diff.tool_call_delta.arguments;
            }
            tool_call["function"] = function;
        }
        delta["tool_calls"] = json::array({ tool_call });
    }
    return delta;
}

bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
    if (use_jinja) {
        try {
            common_chat_msg msg;
            msg.role = "user";
            msg.content = "test";

            auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);

            common_chat_templates_inputs inputs;
            inputs.messages = { msg };

            common_chat_templates_apply(tmpls.get(), inputs);
            return true;
        } catch (const std::exception & e) {
            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
            return false;
        }
    }
    llama_chat_message chat[] = {
        { "user", "test" }
    };
    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
    return res >= 0;
}

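// Formats only the suffix that a single new message adds to an already-formatted
// conversation: the template is applied to the past messages and to past + new,
// and the common prefix is stripped. For a ChatML-style template and a new user
// message, the result would look roughly like "<|im_start|>user\nHi<|im_end|>\n"
// plus the assistant generation prompt (illustrative).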
std::string common_chat_format_single(const struct common_chat_templates * tmpls,
                                      const std::vector<common_chat_msg> & past_msg,
                                      const common_chat_msg & new_msg,
                                      bool add_ass,
                                      bool use_jinja) {
    common_chat_templates_inputs inputs;
    inputs.use_jinja = use_jinja;
    inputs.add_bos = tmpls->add_bos;
    inputs.add_eos = tmpls->add_eos;

    std::string fmt_past_msg;
    if (!past_msg.empty()) {
        inputs.messages = past_msg;
        inputs.add_generation_prompt = false;
        fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
    }
    std::ostringstream ss;
    // if the past_msg ends with a newline, we must preserve it in the formatted version
    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
        ss << "\n";
    }
    // format chat with new_msg
    inputs.messages.push_back(new_msg);
    inputs.add_generation_prompt = add_ass;
    auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
    // get the diff part
    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
    return ss.str();
}

std::string common_chat_format_example(const struct common_chat_templates * tmpls,
                                       bool use_jinja,
                                       const std::map<std::string, std::string> & chat_template_kwargs) {
    common_chat_templates_inputs inputs;
    inputs.use_jinja = use_jinja;
    inputs.add_bos = tmpls->add_bos;
    inputs.add_eos = tmpls->add_eos;
    inputs.chat_template_kwargs = chat_template_kwargs;
    auto add_simple_msg = [&](auto role, auto content) {
        common_chat_msg msg;
        msg.role = role;
        msg.content = content;
        inputs.messages.push_back(msg);
    };
    add_simple_msg("system", "You are a helpful assistant");
    add_simple_msg("user", "Hello");
    add_simple_msg("assistant", "Hi there");
    add_simple_msg("user", "How are you?");
    return common_chat_templates_apply(tmpls, inputs).prompt;
}

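// Fallback ChatML template, used when the model ships no template (or "chatml" is
// requested explicitly). For the example conversation above it renders roughly:
//   <|im_start|>system
//   You are a helpful assistant<|im_end|>
//   <|im_start|>user
//   Hello<|im_end|>
//   <|im_start|>assistant
// (the trailing assistant header is only added when add_generation_prompt is set).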
#define CHATML_TEMPLATE_SRC \
    "{%- for message in messages -%}\n" \
    " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
    "{%- endfor -%}\n" \
    "{%- if add_generation_prompt -%}\n" \
    " {{- '<|im_start|>assistant\n' -}}\n" \
    "{%- endif -%}"

void common_chat_templates_free(struct common_chat_templates * tmpls) {
    delete tmpls;
}

bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
    return tmpls->has_explicit_template;
}

std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) {
    if (!variant.empty()) {
        if (variant == "tool_use") {
            if (tmpls->template_tool_use) {
                return tmpls->template_tool_use->source();
            }
            return "";
        }
        LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str());
    }
    return tmpls->template_default->source();
}

common_chat_templates_ptr common_chat_templates_init(const struct llama_model * model,
                                                     const std::string & chat_template_override,
                                                     const std::string & bos_token_override,
                                                     const std::string & eos_token_override) {
    std::string default_template_src;
    std::string template_tool_use_src;

    bool has_explicit_template = !chat_template_override.empty();
    if (chat_template_override.empty()) {
        GGML_ASSERT(model != nullptr);
        const auto * str = llama_model_chat_template(model, /* name */ nullptr);
        if (str) {
            default_template_src = str;
            has_explicit_template = true;
        }
        str = llama_model_chat_template(model, /* name */ "tool_use");
        if (str) {
            template_tool_use_src = str;
            has_explicit_template = true;
        }
    } else {
        default_template_src = chat_template_override;
    }
    if (default_template_src.empty() || default_template_src == "chatml") {
        if (!template_tool_use_src.empty()) {
            default_template_src = template_tool_use_src;
        } else {
            default_template_src = CHATML_TEMPLATE_SRC;
        }
    }

    // TODO @ngxson : this is a temporary hack to prevent chat template from throwing an error
    // Ref: https://github.com/ggml-org/llama.cpp/pull/15230#issuecomment-3173959633
    if (default_template_src.find("<|channel|>") != std::string::npos
        // search for the error message and patch it
        && default_template_src.find("in message.content or") != std::string::npos) {
        string_replace_all(default_template_src,
            "{%- if \"<|channel|>analysis<|message|>\" in message.content or "
            "\"<|channel|>final<|message|>\" in message.content %}",
            "{%- if false %}");
    }

    // TODO @aldehir : this is a temporary fix, pending Minja changes
    // Ref: https://github.com/ggml-org/llama.cpp/pull/17713#issuecomment-3631342664
    if (default_template_src.find("[TOOL_CALLS]") != std::string::npos
        // search for the error message and patch it
        && default_template_src.find("if (message['content'] is none or") != std::string::npos) {
        string_replace_all(default_template_src,
            "{%- if (message['content'] is none or message['content'] == '' or "
            "message['content']|length == 0) and (message['tool_calls'] is not defined or "
            "message['tool_calls'] is none or message['tool_calls']|length == 0) %}",
            "{%- if false %}");
    }

    std::string token_bos = bos_token_override;
    std::string token_eos = eos_token_override;
    bool add_bos = false;
    bool add_eos = false;
    if (model) {
        const auto * vocab = llama_model_get_vocab(model);
        const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
            if (token == LLAMA_TOKEN_NULL) {
                if (default_template_src.find(jinja_variable_name) != std::string::npos ||
                    template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
                    LOG_WRN(
                        "common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't "
                        "work as intended.\n",
                        name);
                }
                return std::string();
            }
            return common_token_to_piece(vocab, token, true);
        };
        token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
        token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
        add_bos = llama_vocab_get_add_bos(vocab);
        add_eos = llama_vocab_get_add_eos(vocab);
    }
    common_chat_templates_ptr tmpls(new common_chat_templates());
    tmpls->has_explicit_template = has_explicit_template;
    tmpls->add_bos = add_bos;
    tmpls->add_eos = add_eos;
    try {
        tmpls->template_default = std::make_unique<common_chat_template>(default_template_src, token_bos, token_eos);
    } catch (const std::exception & e) {
        LOG_ERR("%s: error: %s\n", __func__, e.what());
        LOG_ERR("%s: failed to initialize chat template\n", __func__);
        LOG_ERR("%s: please consider disabling jinja via --no-jinja, or using another chat template\n", __func__);
        throw e;
    }
    if (!template_tool_use_src.empty()) {
        try {
            tmpls->template_tool_use = std::make_unique<common_chat_template>(template_tool_use_src, token_bos, token_eos);
        } catch (const std::exception & e) {
            LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
        }
    }
    return tmpls;
}

const char * common_chat_format_name(common_chat_format format) {
    switch (format) {
        case COMMON_CHAT_FORMAT_CONTENT_ONLY:
            return "Content-only";
        case COMMON_CHAT_FORMAT_PEG_SIMPLE:
            return "peg-simple";
        case COMMON_CHAT_FORMAT_PEG_NATIVE:
            return "peg-native";
        default:
            throw std::runtime_error("Unknown chat format");
    }
}

const char * common_reasoning_format_name(common_reasoning_format format) {
    switch (format) {
        case COMMON_REASONING_FORMAT_NONE:
            return "none";
        case COMMON_REASONING_FORMAT_AUTO:
            return "auto";
        case COMMON_REASONING_FORMAT_DEEPSEEK:
            return "deepseek";
        case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY:
            return "deepseek-legacy";
        default:
            throw std::runtime_error("Unknown reasoning format");
    }
}

common_reasoning_format common_reasoning_format_from_name(const std::string & format) {
    if (format == "none") {
        return COMMON_REASONING_FORMAT_NONE;
    }
    if (format == "auto") {
        return COMMON_REASONING_FORMAT_AUTO;
    }
    if (format == "deepseek") {
        return COMMON_REASONING_FORMAT_DEEPSEEK;
    }
    if (format == "deepseek-legacy") {
        return COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY;
    }
    throw std::runtime_error("Unknown reasoning format: " + format);
}

static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
    for (const auto & tool : tools) {
        if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
            LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
            continue;
        }
        fn(tool);
    }
}

static void foreach_parameter(const json & function,
                              const std::function<void(const std::string &, const json &, bool)> & fn) {
    if (!function.contains("parameters") || !function.at("parameters").is_object()) {
        return;
    }
    const auto & params = function.at("parameters");
    if (!params.contains("properties") || !params.at("properties").is_object()) {
        return;
    }
    const auto & props = params.at("properties");
    std::set<std::string> required;
    if (params.contains("required") && params.at("required").is_array()) {
        params.at("required").get_to(required);
    }
    for (const auto & [name, prop] : props.items()) {
        bool is_required = (required.find(name) != required.end());
        fn(name, prop, is_required);
    }
}

std::string common_chat_template_direct_apply(
        const common_chat_template & tmpl,
        const struct templates_params & inputs,
        const std::optional<json> & messages_override,
        const std::optional<json> & tools_override,
        const std::optional<json> & additional_context) {
    jinja::context ctx(tmpl.source());

    nlohmann::ordered_json inp = nlohmann::ordered_json{
        {"messages", messages_override.has_value() ? *messages_override : inputs.messages},
        {"bos_token", tmpl.bos_token()},
        {"eos_token", tmpl.eos_token()},
        {"enable_thinking", inputs.enable_thinking},
    };
    if (tools_override.has_value() || !inputs.tools.empty()) {
        inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools;
    }
    if (inputs.extra_context.is_object()) {
        // TODO: do we need to merge, or replacing is fine?
        for (const auto & [k, v] : inputs.extra_context.items()) {
            inp[k] = v;
        }
    }
    if (additional_context.has_value()) {
        // TODO: merge properly instead of overwriting (matching old behavior)
        for (const auto & [k, v] : additional_context->items()) {
            inp[k] = v;
        }
    }
    if (inputs.add_generation_prompt) {
        inp["add_generation_prompt"] = true;
    }

    jinja::global_from_json(ctx, inp, inputs.mark_input);

    // render
    jinja::runtime runtime(ctx);
    const jinja::value results = runtime.execute(tmpl.prog);
    auto parts = jinja::runtime::gather_string_parts(results);

    std::string result = parts->as_string().str();

    // TODO: improve this later
    if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) {
        result = result.substr(tmpl.bos_token().size());
    }
    if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) {
        result = result.substr(0, result.size() - tmpl.eos_token().size());
    }
    return result;
}

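// Ministral / Mistral Large 3 style output. The parser built below expects,
// roughly (illustrative sketch, see the template linked inside the function):
//   [THINK]...reasoning...[/THINK]optional content[TOOL_CALLS]tool_name[ARGS]{"arg": 1}
// Reasoning is optional, and tool calls only appear when tools are provided.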
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
                                                              const struct templates_params & inputs) {
    common_chat_params data;

    // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja
    auto adjusted_messages = json::array();
    for (const auto & msg : inputs.messages) {
        auto role = msg.value("role", "");
        if (role != "system" && role != "assistant") {
            // Only adjust system and assistant messages. Interestingly, the system message may contain thinking.
            adjusted_messages.push_back(msg);
            continue;
        }

        auto content = json::array();

        // If message contains `reasoning_content`, add it as a block of type `thinking`
        if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
            content.push_back({
                { "type", "thinking" },
                { "thinking", msg.at("reasoning_content").get<std::string>() },
            });
        }

        // If message contains `content`, add it as a block of type `text`
        if (msg.contains("content")) {
            if (msg.at("content").is_string()) {
                content.push_back({
                    { "type", "text" },
                    { "text", msg.at("content").get<std::string>() },
                });
            } else if (msg.at("content").is_array()) {
                auto blocks = msg.at("content");
                content.insert(content.end(), blocks.begin(), blocks.end());
            }
        }

        auto adjusted = msg;
        adjusted["content"] = content;
        adjusted.erase("reasoning_content");
        adjusted_messages.push_back(adjusted);
    }

    auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
    auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
    auto include_grammar = true;

    data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.preserved_tokens = {
        "[THINK]",
        "[/THINK]",
        "[TOOL_CALLS]",
        "[ARGS]",
    };

    auto parser = build_chat_peg_unified_parser([&](common_chat_peg_unified_builder & p) {
        auto reasoning =
            extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps();

        // Response format parser
        if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
            // Ministral wants to emit json surrounded by code fences
            return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema))
                             << "```";
        }

        // Tool call parser
        if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
            auto tool_choice = p.choice();
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                std::string name = function.at("name");
                const auto & schema = function.at("parameters");

                tool_choice |=
                    p.rule("tool-" + name, p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]") +
                                               p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
            });

            auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
            auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
            auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls));

            return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls;
        }

        // Content only parser
        include_grammar = false;
        return reasoning << p.content(p.rest());
    });

    data.parser = parser.save();

    if (include_grammar) {
        data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;

        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                auto schema = function.at("parameters");
                builder.resolve_refs(schema);
            });
            parser.build_grammar(builder, data.grammar_lazy);
        });

        data.grammar_triggers = {
            { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]" }
        };
    }

    return data;
}

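// GPT-OSS uses a channel-based layout. The parser built below expects segments
// shaped roughly like (illustrative sketch):
//   <|channel|>analysis<|message|>...reasoning...<|end|>
//   <|start|>assistant<|channel|>final<|message|>...answer...<|end|>
// with tool calls addressed as " to=functions.NAME", either in the role header or
// in the channel header, followed by <|message|> and the JSON arguments.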
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl,
                                                          const struct templates_params & inputs) {
    common_chat_params data;

    // Copy reasoning to the "thinking" field as expected by the gpt-oss template
    auto adjusted_messages = json::array();
    for (const auto & msg : inputs.messages) {
        auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
        auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();

        if (has_reasoning_content && has_tool_calls) {
            auto adjusted_message = msg;
            adjusted_message["thinking"] = msg.at("reasoning_content");
            adjusted_messages.push_back(adjusted_message);
        } else {
            adjusted_messages.push_back(msg);
        }
    }

    auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);

    // Check if we need to replace the return token with end token during
    // inference and without generation prompt. For more details see:
    // https://github.com/ggml-org/llama.cpp/issues/15417
    if (inputs.is_inference && !inputs.add_generation_prompt) {
        static constexpr std::string_view return_token = "<|return|>";
        static constexpr std::string_view end_token = "<|end|>";
        if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) {
            prompt.replace(pos, return_token.length(), end_token);
        }
    }

    data.prompt = prompt;
    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;

    // These special tokens are required to parse properly, so we include them
    // even if parse_tool_calls is false.
    data.preserved_tokens = {
        "<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>",
    };

    auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
    auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
    auto include_grammar = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && has_tools;

    auto parser = build_chat_peg_unified_parser([&](common_chat_peg_unified_builder & p) {
        const std::string END = "<|end|>";
        const std::string START = "<|start|>";
        const std::string MESSAGE = "<|message|>";
        const std::string CHANNEL = "<|channel|>";
        const std::string CONSTRAIN = "<|constrain|>";
        const std::string START_ASSISTANT = START + "assistant";
        const std::string CHANNEL_ANALYSIS = CHANNEL + "analysis";
        const std::string CHANNEL_COMMENTARY = CHANNEL + "commentary";
        const std::string CHANNEL_FINAL = CHANNEL + "final";

        auto the_end = END | p.end();

        const std::string analysis_header = CHANNEL_ANALYSIS + MESSAGE;
        auto segment_content = p.until(END);
        auto analysis_segment = extract_reasoning ?
            p.literal(analysis_header) + p.reasoning(segment_content) + p.until(END) + the_end :
            p.content(analysis_header + p.until(END) + the_end);

        auto channel_header_content = p.until_one_of({ " to=functions.", MESSAGE });
        auto content_header = p.choice({ p.literal(CHANNEL_COMMENTARY), p.literal(CHANNEL_FINAL) });
        auto content_segment = p.rule("content-segment", content_header + channel_header_content + MESSAGE +
                                                             p.content(segment_content) + the_end);

        if (!inputs.json_schema.is_null()) {
            auto final_header = p.literal(CHANNEL_FINAL);
            auto constraint = p.optional(p.space() + p.literal(CONSTRAIN) + channel_header_content);
            return p.optional(analysis_segment) + final_header + constraint + MESSAGE +
                   p.content(p.schema(p.json(), "response-format", inputs.json_schema));
        }

        auto segment = p.optional(START_ASSISTANT + p.space()) + p.choice({ content_segment, analysis_segment });
        auto contents = p.optional(segment + p.repeat(p.optional(p.space()) + segment, 0, -1)) + p.end();

        // Tool call parser
        if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
            auto tool_choice = p.choice();

            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                std::string name = function.at("name");
                const auto & params = function.at("parameters");

                // Tool call can appear as:
                // 1. In role header: " to=functions.NAME<|channel|>..."
                // 2. In channel: "<|channel|>(analysis|commentary) to=functions.NAME..."
                auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name));

                auto channel = p.literal(CHANNEL_COMMENTARY) | p.literal(CHANNEL_ANALYSIS);
                auto constraint = p.space() + p.optional(p.literal(CONSTRAIN) + channel_header_content);
                auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params));

                // Pattern 1: recipient in role header
                // " to=functions.NAME<|channel|>(analysis|commentary)[constraint]<|message|>ARGS"
                auto tool_in_role = p.tool(p.tool_open(func_name + channel) + constraint + MESSAGE + args);

                // Pattern 2: recipient in channel header
                // "<|channel|>(analysis|commentary) to=functions.NAME[constraint]<|message|>ARGS"

                auto tool_in_channel = p.tool(channel + p.tool_open(func_name + constraint + MESSAGE) + args);

                tool_choice |= p.trigger_rule("tool-" + name, tool_in_role | tool_in_channel);
            });

            auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
            auto max_calls = inputs.parallel_tool_calls ? -1 : 1;

            auto role_start = p.optional(p.space() + p.literal(START_ASSISTANT));
            auto tool_call = p.rule("tool-call", p.repeat(role_start + tool_choice, min_calls, max_calls) + p.end());

            return p.choice({ tool_call, p.one_or_more(segment) + tool_call });
        }

        return contents;
    });

    data.parser = parser.save();

    if (include_grammar) {
        data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                auto schema = function.at("parameters");
                builder.resolve_refs(schema);
            });
            parser.build_grammar(builder, data.grammar_lazy);
        });

        data.grammar_triggers = {
            { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" },
            { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "(?:<\\|end\\|>)(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" },
            { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
              "(?:<\\|start\\|>assistant\\s*)?(<\\|channel\\|>(?:commentary|analysis)\\s+to=functions)" }
        };
    }

    return data;
}

namespace workaround {

// if first message is system and template does not support it, merge it with next message
static void system_message_not_supported(json & messages) {
    if (!messages.empty() && messages.front().at("role") == "system") {
        if (messages.size() > 1) {
            LOG_DBG("Merging system prompt into next message\n");
            auto & first_msg = messages.front();
            auto & second_msg = messages[1];
            second_msg["content"] = first_msg.at("content").get<std::string>()
                + "\n" + second_msg.at("content").get<std::string>();
            messages.erase(messages.begin());
        } else {
            LOG_WRN("Removing system prompt due to template not supporting system role\n");
            messages.erase(messages.begin());
        }
    }
}

static void requires_non_null_content(json & messages) {
    GGML_ASSERT(messages.is_array());
    for (auto & message : messages) {
        if (message.contains("tool_calls") && !message.contains("content")) {
            message["content"] = "";
        }
    }
}

static void func_args_not_string(json & messages) {
    GGML_ASSERT(messages.is_array());
    for (auto & message : messages) {
        if (message.contains("tool_calls")) {
            for (auto & tool_call : message["tool_calls"]) {
                if (tool_call.contains("function") && tool_call["function"].contains("arguments")) {
                    auto & args = tool_call["function"]["arguments"];
                    if (args.is_string()) {
                        try {
                            args = json::parse(args.get<std::string>());
                        } catch (const std::exception & e) {
                            throw std::runtime_error("Failed to parse tool call arguments as JSON: " + std::string(e.what()));
                        }
                    }
                }
            }
        }
    }
}

} // namespace workaround

static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls,
                                                            const struct common_chat_templates_inputs & inputs) {
    templates_params params;
    params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
    const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
        ? *tmpls->template_tool_use
        : *tmpls->template_default;
    const auto & src = tmpl.source();
    const auto & caps = tmpl.original_caps();
    params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
    params.add_generation_prompt = inputs.add_generation_prompt;
    params.tool_choice = inputs.tool_choice;
    params.reasoning_format = inputs.reasoning_format;
    params.enable_thinking = inputs.enable_thinking;
    params.grammar = inputs.grammar;
    params.now = inputs.now;
    params.add_bos = tmpls->add_bos;
    params.add_eos = tmpls->add_eos;

    if (!tmpl.original_caps().supports_system_role) {
        workaround::system_message_not_supported(params.messages);
    }

    if (!tmpl.original_caps().requires_non_null_content) {
        workaround::requires_non_null_content(params.messages);
    }

    params.extra_context = json::object();
    for (auto el : inputs.chat_template_kwargs) {
        params.extra_context[el.first] = json::parse(el.second);
    }

    if (!inputs.json_schema.empty()) {
        params.json_schema = json::parse(inputs.json_schema);
    }

    // if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
    //     LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
    //     params.parallel_tool_calls = false;
    // } else {
    params.parallel_tool_calls = inputs.parallel_tool_calls;
    //}

    if (params.tools.is_array()) {
        if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
            throw std::runtime_error("Cannot specify grammar with tools");
        }
        if (caps.supports_tool_calls && !caps.supports_tools) {
            LOG_WRN(
                "Template supports tool calls but does not natively describe tools. The fallback behaviour used may "
                "produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
        }
    }

    // Ministral/Mistral Large 3 - uses special reasoning structure fixes, so it can't use the autoparser
    // Note: Mistral Small 3.2 uses [CALL_ID], which Ministral doesn't have, so we can distinguish them
    if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos &&
        src.find("[ARGS]") != std::string::npos && src.find("[CALL_ID]") == std::string::npos) {
        LOG_INF("Using specialized template: Ministral/Magistral Large 3\n");
        return common_chat_params_init_ministral_3(tmpl, params);
    }

    // GPT-OSS - has a unique channel-based structure that needs a dedicated handler
    if (src.find("<|channel|>") != std::string::npos) {
        LOG_INF("Using specialized template: GPT-OSS\n");
        return common_chat_params_init_gpt_oss(tmpl, params);
    }

    try {
        LOG_INF("Using autoparser for template analysis\n");
        template_analysis_result analysis = template_analyzer::analyze_template(tmpl);
        auto auto_params = universal_peg_generator::generate_parser(analysis, tmpl, params);
        return auto_params;
    } catch (const std::exception & e) {
        LOG_WRN("Automatic parser generation failed: %s\n", e.what());
    }

    GGML_ABORT("Unable to generate parser for this template.");
}

// Legacy template route (ad hoc C++ implementation of known templates); forwards to llama_chat_apply_template.
static common_chat_params common_chat_templates_apply_legacy(const struct common_chat_templates * tmpls,
                                                             const struct common_chat_templates_inputs & inputs) {
    size_t alloc_size = 0;
    std::vector<llama_chat_message> chat;
    std::vector<std::string> contents;

    for (const auto & msg : inputs.messages) {
        auto content = msg.content;
        for (const auto & part : msg.content_parts) {
            if (part.type != "text") {
                LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
                continue;
            }
            if (!content.empty()) {
                content += "\n";
            }
            content += part.text;
        }
        contents.emplace_back(std::move(content));
    }
    for (size_t i = 0; i < contents.size(); ++i) {
        const auto & msg = inputs.messages[i];
        const auto & content = contents[i];
        chat.push_back({ msg.role.c_str(), content.c_str() });
        size_t msg_size = msg.role.size() + content.size();
        alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops
    }

    std::vector<char> buf(alloc_size);

    // run the first time to get the total output length
    const auto & src = tmpls->template_default->source();
    int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt,
                                            buf.data(), buf.size());

    // error: chat template is not supported
    if (res < 0) {
        // if the custom "tmpl" is not supported, we throw an error
        // this is a bit redundant (on purpose), since we're not sure whether the user validated the custom template with llama_chat_verify_template()
        throw std::runtime_error("this custom template is not supported, try using --jinja");
    }

    // if it turns out that our buffer is too small, we resize it
    if ((size_t) res > buf.size()) {
        buf.resize(res);
        res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(),
                                        buf.size());
    }

    // for safety, we check the result again
    if (res < 0 || (size_t) res > buf.size()) {
        throw std::runtime_error("failed to apply chat template, try using --jinja");
    }

    common_chat_params params;
    params.prompt = std::string(buf.data(), res);
    if (!inputs.json_schema.empty()) {
        params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
    } else {
        params.grammar = inputs.grammar;
    }
    return params;
}

common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
                                               const struct common_chat_templates_inputs & inputs) {
    GGML_ASSERT(tmpls != nullptr);
    return inputs.use_jinja ? common_chat_templates_apply_jinja(tmpls, inputs) :
                              common_chat_templates_apply_legacy(tmpls, inputs);
}

common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
    return common_chat_peg_parse(syntax.parser, input, is_partial, syntax);
}

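// Runs the PEG parser over (possibly partial) model output and maps the resulting
// AST onto a common_chat_msg. For partial input, whatever has been recognized so
// far (reasoning, content, in-progress tool calls) is returned so that streaming
// diffs can be computed; a failed parse of complete input is an error.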
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser,
                                      const std::string & input,
                                      bool is_partial,
                                      const common_chat_syntax & syntax) {
    if (parser.empty()) {
        throw std::runtime_error("Failed to parse due to missing parser definition.");
    }

    LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str());

    common_peg_parse_context ctx(input, is_partial);
    ctx.debug = syntax.debug;
    auto result = parser.parse(ctx);

    if (result.fail()) {
        // During partial parsing, return partial results if any AST nodes were captured
        // This allows streaming to work correctly for formats like FUNC_MARKDOWN_CODE_BLOCK
        if (is_partial && result.end > 0) {
            // Try to extract any partial results from what was successfully parsed
            common_chat_msg msg;
            msg.role = "assistant";
            if (syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE) {
                auto mapper = common_chat_peg_unified_mapper(msg);
                mapper.from_ast(ctx.ast, result);
            } else {
                auto mapper = common_chat_peg_mapper(msg);
                mapper.from_ast(ctx.ast, result);
            }
            if (ctx.debug) {
                fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str());
                fflush(stderr);
            }
            return msg;
        }
        throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " +
                                 input.substr(result.end));
    }

    common_chat_msg msg;
    msg.role = "assistant";

    if (syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE) {
        auto mapper = common_chat_peg_unified_mapper(msg);
        mapper.from_ast(ctx.ast, result);
    } else {
        // Generic mapper
        auto mapper = common_chat_peg_mapper(msg);
        mapper.from_ast(ctx.ast, result);
    }
    if (ctx.debug) {
        fprintf(stderr, "\nAST for %s parse:\n%s\n", is_partial ? "partial" : "full", ctx.ast.dump().c_str());
        fflush(stderr);
    }

    if (!is_partial) {
        LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({ msg }).at(0).dump().c_str());
    }
    return msg;
}

std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
    GGML_ASSERT(chat_templates != nullptr);
    GGML_ASSERT(chat_templates->template_default != nullptr);
    return chat_templates->template_default->caps.to_map();
}