diff --git a/common/chat-auto-parser-helpers.h b/common/chat-auto-parser-helpers.h index 7cd031c4d6..b8804ac191 100644 --- a/common/chat-auto-parser-helpers.h +++ b/common/chat-auto-parser-helpers.h @@ -1,7 +1,7 @@ #pragma once #include "chat-auto-parser.h" -#include "peg-parser.h" + #include #include #include diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h index 9d7d4e69e6..8886c330dd 100644 --- a/common/chat-auto-parser.h +++ b/common/chat-auto-parser.h @@ -4,6 +4,7 @@ #include "common.h" #include "jinja/caps.h" #include "peg-parser.h" +#include "nlohmann/json.hpp" #include #include diff --git a/common/chat.cpp b/common/chat.cpp index 9cd2dd7076..c0670496b3 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -13,6 +13,8 @@ #include "jinja/caps.h" #include "peg-parser.h" +#include "nlohmann/json.hpp" + #include #include #include @@ -762,12 +764,12 @@ static void foreach_parameter(const json & } } -std::string common_chat_template_direct_apply( +static std::string common_chat_template_direct_apply_impl( const common_chat_template & tmpl, const autoparser::generation_params & inputs, - const std::optional & messages_override, - const std::optional & tools_override, - const std::optional & additional_context) { + const std::optional & messages_override = std::nullopt, + const std::optional & tools_override = std::nullopt, + const std::optional & additional_context = std::nullopt) { jinja::context ctx(tmpl.source()); nlohmann::ordered_json inp = nlohmann::ordered_json{ @@ -814,6 +816,12 @@ std::string common_chat_template_direct_apply( return result; } +std::string common_chat_template_direct_apply( + const common_chat_template & tmpl, + const autoparser::generation_params & inputs) { + return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt); +} + static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const autoparser::generation_params & inputs) { common_chat_params data; @@ -864,7 +872,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ data.supports_thinking = true; data.thinking_start_tag = "[THINK]"; data.thinking_end_tag = "[/THINK]"; - data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages); + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.preserved_tokens = { "[THINK]", @@ -947,7 +955,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp adjusted_messages.push_back(msg); } - auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages); + auto prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override= */ adjusted_messages); // Check if we need to replace the return token with end token during // inference and without generation prompt. For more details see: @@ -1074,7 +1082,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.preserved_tokens = { ">>>all", @@ -1168,7 +1176,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; data.preserved_tokens = { @@ -1291,7 +1299,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; data.preserved_tokens = { @@ -1370,7 +1378,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ const autoparser::generation_params & inputs) { common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; data.preserved_tokens = { @@ -1441,7 +1449,7 @@ static common_chat_params common_chat_params_init_gigachat_v3( common_chat_params data; - data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = false; data.preserved_tokens = { @@ -1724,9 +1732,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ } params.add_generation_prompt = false; - std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params); + std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); params.add_generation_prompt = true; - std::string gen_prompt = common_chat_template_direct_apply(tmpl, params); + std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params); auto diff = calculate_diff_split(no_gen_prompt, gen_prompt); params.generation_prompt = diff.right; @@ -1760,7 +1768,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ common_chat_params data; auto params_copy = params; params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE; - data.prompt = common_chat_template_direct_apply(tmpl, params_copy); + data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy); data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.generation_prompt = params.generation_prompt; auto parser = build_chat_peg_parser([¶ms](common_chat_peg_builder &p) { diff --git a/common/chat.h b/common/chat.h index 50c73d4817..a60a9228bd 100644 --- a/common/chat.h +++ b/common/chat.h @@ -3,12 +3,12 @@ #pragma once #include "common.h" -#include "jinja/parser.h" -#include "nlohmann/json_fwd.hpp" #include "peg-parser.h" +#include "jinja/parser.h" #include "jinja/runtime.h" #include "jinja/caps.h" -#include "nlohmann/json.hpp" + +#include "nlohmann/json_fwd.hpp" #include #include @@ -19,8 +19,6 @@ using chat_template_caps = jinja::caps; using json = nlohmann::ordered_json; -#include - struct common_chat_templates; namespace autoparser { @@ -75,41 +73,9 @@ struct common_chat_template { const std::string & bos_token() const { return bos_tok; } const std::string & eos_token() const { return eos_tok; } - // TODO: this is ugly, refactor it somehow - json add_system(const json & messages, const std::string & system_prompt) const { - GGML_ASSERT(messages.is_array()); - auto msgs_copy = messages; - if (!caps.supports_system_role) { - if (msgs_copy.empty()) { - msgs_copy.insert(msgs_copy.begin(), json{ - {"role", "user"}, - {"content", system_prompt} - }); - } else { - auto & first_msg = msgs_copy[0]; - if (!first_msg.contains("content")) { - first_msg["content"] = ""; - } - first_msg["content"] = system_prompt + "\n\n" - + first_msg["content"].get(); - } - } else { - if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") { - msgs_copy.insert(msgs_copy.begin(), json{ - {"role", "system"}, - {"content", system_prompt} - }); - } else if (msgs_copy[0].at("role") == "system") { - msgs_copy[0]["content"] = system_prompt; - } - } - return msgs_copy; - } - chat_template_caps original_caps() const { return caps; } - }; struct common_chat_msg { @@ -257,8 +223,8 @@ common_chat_templates_ptr common_chat_templates_init(const struct llama_model * const std::string & bos_token_override = "", const std::string & eos_token_override = ""); -bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); -std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = ""); +bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); +std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = ""); struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs); @@ -275,9 +241,9 @@ std::string common_chat_format_example(const struct common_chat_templates * bool use_jinja, const std::map & chat_template_kwargs); -const char * common_chat_format_name(common_chat_format format); -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params); -common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params); +const char * common_chat_format_name(common_chat_format format); +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params); +common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params); // used by arg and server const char * common_reasoning_format_name(common_reasoning_format format); @@ -303,7 +269,4 @@ std::map common_chat_templates_get_caps(const common_chat_tem std::string common_chat_template_direct_apply( const common_chat_template & tmpl, - const autoparser::generation_params & inputs, - const std::optional & messages_override = std::nullopt, - const std::optional & tools_override = std::nullopt, - const std::optional & additional_context = std::nullopt); + const autoparser::generation_params & inputs);