chat : avoid including json in chat.h (#21306)
This commit is contained in:
parent
39b27f0da0
commit
57ace0d612
|
|
@ -1,7 +1,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "chat-auto-parser.h"
|
#include "chat-auto-parser.h"
|
||||||
#include "peg-parser.h"
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "jinja/caps.h"
|
#include "jinja/caps.h"
|
||||||
#include "peg-parser.h"
|
#include "peg-parser.h"
|
||||||
|
#include "nlohmann/json.hpp"
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@
|
||||||
#include "jinja/caps.h"
|
#include "jinja/caps.h"
|
||||||
#include "peg-parser.h"
|
#include "peg-parser.h"
|
||||||
|
|
||||||
|
#include "nlohmann/json.hpp"
|
||||||
|
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
|
|
@ -762,12 +764,12 @@ static void foreach_parameter(const json &
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string common_chat_template_direct_apply(
|
static std::string common_chat_template_direct_apply_impl(
|
||||||
const common_chat_template & tmpl,
|
const common_chat_template & tmpl,
|
||||||
const autoparser::generation_params & inputs,
|
const autoparser::generation_params & inputs,
|
||||||
const std::optional<json> & messages_override,
|
const std::optional<json> & messages_override = std::nullopt,
|
||||||
const std::optional<json> & tools_override,
|
const std::optional<json> & tools_override = std::nullopt,
|
||||||
const std::optional<json> & additional_context) {
|
const std::optional<json> & additional_context = std::nullopt) {
|
||||||
jinja::context ctx(tmpl.source());
|
jinja::context ctx(tmpl.source());
|
||||||
|
|
||||||
nlohmann::ordered_json inp = nlohmann::ordered_json{
|
nlohmann::ordered_json inp = nlohmann::ordered_json{
|
||||||
|
|
@ -814,6 +816,12 @@ std::string common_chat_template_direct_apply(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string common_chat_template_direct_apply(
|
||||||
|
const common_chat_template & tmpl,
|
||||||
|
const autoparser::generation_params & inputs) {
|
||||||
|
return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
|
||||||
|
}
|
||||||
|
|
||||||
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
||||||
const autoparser::generation_params & inputs) {
|
const autoparser::generation_params & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
|
@ -864,7 +872,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
||||||
data.supports_thinking = true;
|
data.supports_thinking = true;
|
||||||
data.thinking_start_tag = "[THINK]";
|
data.thinking_start_tag = "[THINK]";
|
||||||
data.thinking_end_tag = "[/THINK]";
|
data.thinking_end_tag = "[/THINK]";
|
||||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||||
data.preserved_tokens = {
|
data.preserved_tokens = {
|
||||||
"[THINK]",
|
"[THINK]",
|
||||||
|
|
@ -947,7 +955,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||||
adjusted_messages.push_back(msg);
|
adjusted_messages.push_back(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
auto prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||||
|
|
||||||
// Check if we need to replace the return token with end token during
|
// Check if we need to replace the return token with end token during
|
||||||
// inference and without generation prompt. For more details see:
|
// inference and without generation prompt. For more details see:
|
||||||
|
|
@ -1074,7 +1082,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
||||||
const autoparser::generation_params & inputs) {
|
const autoparser::generation_params & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
|
||||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||||
data.preserved_tokens = {
|
data.preserved_tokens = {
|
||||||
">>>all",
|
">>>all",
|
||||||
|
|
@ -1168,7 +1176,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
||||||
const autoparser::generation_params & inputs) {
|
const autoparser::generation_params & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
|
||||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||||
data.supports_thinking = true;
|
data.supports_thinking = true;
|
||||||
data.preserved_tokens = {
|
data.preserved_tokens = {
|
||||||
|
|
@ -1291,7 +1299,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||||
const autoparser::generation_params & inputs) {
|
const autoparser::generation_params & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
|
||||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||||
data.supports_thinking = true;
|
data.supports_thinking = true;
|
||||||
data.preserved_tokens = {
|
data.preserved_tokens = {
|
||||||
|
|
@ -1370,7 +1378,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
|
||||||
const autoparser::generation_params & inputs) {
|
const autoparser::generation_params & inputs) {
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
|
||||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||||
data.supports_thinking = true;
|
data.supports_thinking = true;
|
||||||
data.preserved_tokens = {
|
data.preserved_tokens = {
|
||||||
|
|
@ -1441,7 +1449,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
||||||
|
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
|
|
||||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||||
data.supports_thinking = false;
|
data.supports_thinking = false;
|
||||||
data.preserved_tokens = {
|
data.preserved_tokens = {
|
||||||
|
|
@ -1724,9 +1732,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||||
}
|
}
|
||||||
|
|
||||||
params.add_generation_prompt = false;
|
params.add_generation_prompt = false;
|
||||||
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||||
params.add_generation_prompt = true;
|
params.add_generation_prompt = true;
|
||||||
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||||
params.generation_prompt = diff.right;
|
params.generation_prompt = diff.right;
|
||||||
|
|
||||||
|
|
@ -1760,7 +1768,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||||
common_chat_params data;
|
common_chat_params data;
|
||||||
auto params_copy = params;
|
auto params_copy = params;
|
||||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||||
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
|
data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy);
|
||||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||||
data.generation_prompt = params.generation_prompt;
|
data.generation_prompt = params.generation_prompt;
|
||||||
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
|
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,12 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "jinja/parser.h"
|
|
||||||
#include "nlohmann/json_fwd.hpp"
|
|
||||||
#include "peg-parser.h"
|
#include "peg-parser.h"
|
||||||
|
#include "jinja/parser.h"
|
||||||
#include "jinja/runtime.h"
|
#include "jinja/runtime.h"
|
||||||
#include "jinja/caps.h"
|
#include "jinja/caps.h"
|
||||||
#include "nlohmann/json.hpp"
|
|
||||||
|
#include "nlohmann/json_fwd.hpp"
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
@ -19,8 +19,6 @@
|
||||||
using chat_template_caps = jinja::caps;
|
using chat_template_caps = jinja::caps;
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
#include <nlohmann/json_fwd.hpp>
|
|
||||||
|
|
||||||
struct common_chat_templates;
|
struct common_chat_templates;
|
||||||
|
|
||||||
namespace autoparser {
|
namespace autoparser {
|
||||||
|
|
@ -75,41 +73,9 @@ struct common_chat_template {
|
||||||
const std::string & bos_token() const { return bos_tok; }
|
const std::string & bos_token() const { return bos_tok; }
|
||||||
const std::string & eos_token() const { return eos_tok; }
|
const std::string & eos_token() const { return eos_tok; }
|
||||||
|
|
||||||
// TODO: this is ugly, refactor it somehow
|
|
||||||
json add_system(const json & messages, const std::string & system_prompt) const {
|
|
||||||
GGML_ASSERT(messages.is_array());
|
|
||||||
auto msgs_copy = messages;
|
|
||||||
if (!caps.supports_system_role) {
|
|
||||||
if (msgs_copy.empty()) {
|
|
||||||
msgs_copy.insert(msgs_copy.begin(), json{
|
|
||||||
{"role", "user"},
|
|
||||||
{"content", system_prompt}
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
auto & first_msg = msgs_copy[0];
|
|
||||||
if (!first_msg.contains("content")) {
|
|
||||||
first_msg["content"] = "";
|
|
||||||
}
|
|
||||||
first_msg["content"] = system_prompt + "\n\n"
|
|
||||||
+ first_msg["content"].get<std::string>();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
|
|
||||||
msgs_copy.insert(msgs_copy.begin(), json{
|
|
||||||
{"role", "system"},
|
|
||||||
{"content", system_prompt}
|
|
||||||
});
|
|
||||||
} else if (msgs_copy[0].at("role") == "system") {
|
|
||||||
msgs_copy[0]["content"] = system_prompt;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return msgs_copy;
|
|
||||||
}
|
|
||||||
|
|
||||||
chat_template_caps original_caps() const {
|
chat_template_caps original_caps() const {
|
||||||
return caps;
|
return caps;
|
||||||
}
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_chat_msg {
|
struct common_chat_msg {
|
||||||
|
|
@ -257,8 +223,8 @@ common_chat_templates_ptr common_chat_templates_init(const struct llama_model *
|
||||||
const std::string & bos_token_override = "",
|
const std::string & bos_token_override = "",
|
||||||
const std::string & eos_token_override = "");
|
const std::string & eos_token_override = "");
|
||||||
|
|
||||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||||
|
|
||||||
struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
|
struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
|
||||||
const struct common_chat_templates_inputs & inputs);
|
const struct common_chat_templates_inputs & inputs);
|
||||||
|
|
@ -275,9 +241,9 @@ std::string common_chat_format_example(const struct common_chat_templates *
|
||||||
bool use_jinja,
|
bool use_jinja,
|
||||||
const std::map<std::string, std::string> & chat_template_kwargs);
|
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||||
|
|
||||||
const char * common_chat_format_name(common_chat_format format);
|
const char * common_chat_format_name(common_chat_format format);
|
||||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||||
|
|
||||||
// used by arg and server
|
// used by arg and server
|
||||||
const char * common_reasoning_format_name(common_reasoning_format format);
|
const char * common_reasoning_format_name(common_reasoning_format format);
|
||||||
|
|
@ -303,7 +269,4 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
|
||||||
|
|
||||||
std::string common_chat_template_direct_apply(
|
std::string common_chat_template_direct_apply(
|
||||||
const common_chat_template & tmpl,
|
const common_chat_template & tmpl,
|
||||||
const autoparser::generation_params & inputs,
|
const autoparser::generation_params & inputs);
|
||||||
const std::optional<json> & messages_override = std::nullopt,
|
|
||||||
const std::optional<json> & tools_override = std::nullopt,
|
|
||||||
const std::optional<json> & additional_context = std::nullopt);
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue