Merge branch 'ggml-org:master' into i8mm-ci
This commit is contained in:
commit
65fbb935d7
|
|
@ -1390,14 +1390,10 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
cann:
|
||||
- '8.3.rc1.alpha001-910b-openeuler22.03-py3.11'
|
||||
chip_type:
|
||||
- '910b'
|
||||
build:
|
||||
- 'Release'
|
||||
chip_type: ['910b', '310p']
|
||||
build: ['Release']
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.cann }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
|
|
|||
|
|
@ -698,10 +698,9 @@ jobs:
|
|||
matrix:
|
||||
arch: [x86, aarch64]
|
||||
chip_type: ['910b', '310p']
|
||||
build:
|
||||
- 'Release'
|
||||
build: ['Release']
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.3.rc1.alpha001-310p-openeuler22.03-py3.11' }}
|
||||
container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
|
@ -737,7 +736,7 @@ jobs:
|
|||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
name: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip
|
||||
|
||||
release:
|
||||
if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }}
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ add_library(${TARGET} STATIC
|
|||
base64.hpp
|
||||
chat-parser.cpp
|
||||
chat-parser.h
|
||||
chat-parser-xml-toolcall.h
|
||||
chat-parser-xml-toolcall.cpp
|
||||
chat.cpp
|
||||
chat.h
|
||||
common.cpp
|
||||
|
|
|
|||
|
|
@ -0,0 +1,861 @@
|
|||
#include "chat.h"
|
||||
#include "chat-parser.h"
|
||||
#include "common.h"
|
||||
#include "json-partial.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
class xml_toolcall_syntax_exception : public std::runtime_error {
|
||||
public:
|
||||
xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
inline void sort_uniq(std::vector<T> &vec) {
|
||||
std::sort(vec.begin(), vec.end());
|
||||
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline bool all_space(const T &str) {
|
||||
return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); });
|
||||
}
|
||||
|
||||
static size_t utf8_truncate_safe(const std::string_view s) {
|
||||
size_t len = s.size();
|
||||
if (len == 0) return 0;
|
||||
size_t i = len;
|
||||
for (size_t back = 0; back < 4 && i > 0; ++back) {
|
||||
--i;
|
||||
unsigned char c = s[i];
|
||||
if ((c & 0x80) == 0) {
|
||||
return len;
|
||||
} else if ((c & 0xC0) == 0xC0) {
|
||||
size_t expected_len = 0;
|
||||
if ((c & 0xE0) == 0xC0) expected_len = 2;
|
||||
else if ((c & 0xF0) == 0xE0) expected_len = 3;
|
||||
else if ((c & 0xF8) == 0xF0) expected_len = 4;
|
||||
else return i;
|
||||
if (len - i >= expected_len) {
|
||||
return len;
|
||||
} else {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return len - std::min(len, size_t(3));
|
||||
}
|
||||
|
||||
inline void utf8_truncate_safe_resize(std::string &s) {
|
||||
s.resize(utf8_truncate_safe(s));
|
||||
}
|
||||
|
||||
inline std::string_view utf8_truncate_safe_view(const std::string_view s) {
|
||||
return s.substr(0, utf8_truncate_safe(s));
|
||||
}
|
||||
|
||||
static std::optional<common_chat_msg_parser::find_regex_result> try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) {
|
||||
if (literal1.size() == 0) return builder.try_find_literal(literal2);
|
||||
const auto saved_pos = builder.pos();
|
||||
while (auto res = builder.try_find_literal(literal1)) {
|
||||
builder.consume_spaces();
|
||||
const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos());
|
||||
if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) {
|
||||
if (res->prelude.size() != res->groups[0].begin - saved_pos) {
|
||||
res->prelude = builder.str({saved_pos, res->groups[0].begin});
|
||||
}
|
||||
builder.move_to(builder.pos() + match_len);
|
||||
res->groups[0].end = builder.pos();
|
||||
GGML_ASSERT(res->groups[0].begin != res->groups[0].end);
|
||||
return res;
|
||||
}
|
||||
builder.move_to(res->groups[0].begin + 1);
|
||||
}
|
||||
builder.move_to(saved_pos);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/**
|
||||
* make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||
*/
|
||||
std::string make_gbnf_excluding(std::vector<std::string> forbids) {
|
||||
constexpr auto charclass_escape = [](unsigned char c) -> std::string {
|
||||
if (c == '\\' || c == ']' || c == '^' || c == '-') {
|
||||
std::string s = "\\";
|
||||
s.push_back((char)c);
|
||||
return s;
|
||||
}
|
||||
if (isprint(c)) {
|
||||
return std::string(1, (char)c);
|
||||
}
|
||||
char buf[16];
|
||||
snprintf(buf, 15, "\\x%02X", c);
|
||||
return std::string(buf);
|
||||
};
|
||||
constexpr auto build_expr = [charclass_escape](auto self, const std::vector<std::string>& forbids, int l, int r, int depth) -> std::string {
|
||||
std::vector<std::pair<unsigned char, std::pair<int,int>>> children;
|
||||
int i = l;
|
||||
while (i < r) {
|
||||
const std::string &s = forbids[i];
|
||||
if ((int)s.size() == depth) {
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
unsigned char c = (unsigned char)s[depth];
|
||||
int j = i;
|
||||
while (j < r && (int)forbids[j].size() > depth &&
|
||||
(unsigned char)forbids[j][depth] == c) {
|
||||
++j;
|
||||
}
|
||||
children.push_back({c, {i, j}});
|
||||
i = j;
|
||||
}
|
||||
std::vector<std::string> alts;
|
||||
if (!children.empty()) {
|
||||
std::string cls;
|
||||
for (auto &ch : children) cls += charclass_escape(ch.first);
|
||||
alts.push_back(std::string("[^") + cls + "]");
|
||||
}
|
||||
for (auto &ch : children) {
|
||||
std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1);
|
||||
if (!childExpr.empty()) {
|
||||
std::string quoted_ch = "\"";
|
||||
if (ch.first == '\\') quoted_ch += "\\\\";
|
||||
else if (ch.first == '"') quoted_ch += "\\\"";
|
||||
else if (isprint(ch.first)) quoted_ch.push_back(ch.first);
|
||||
else {
|
||||
char buf[16];
|
||||
snprintf(buf, 15, "\\x%02X", ch.first);
|
||||
quoted_ch += buf;
|
||||
}
|
||||
quoted_ch += "\"";
|
||||
std::string branch = quoted_ch + std::string(" ") + childExpr;
|
||||
alts.push_back(branch);
|
||||
}
|
||||
}
|
||||
if (alts.empty()) return "";
|
||||
std::ostringstream oss;
|
||||
oss << "( ";
|
||||
for (size_t k = 0; k < alts.size(); ++k) {
|
||||
if (k) oss << " | ";
|
||||
oss << alts[k];
|
||||
}
|
||||
oss << " )";
|
||||
return oss.str();
|
||||
};
|
||||
if (forbids.empty()) return "( . )*";
|
||||
sort(forbids.begin(), forbids.end());
|
||||
std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0);
|
||||
if (expr.empty()) {
|
||||
std::string cls;
|
||||
for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]);
|
||||
expr = std::string("( [^") + cls + "] )";
|
||||
}
|
||||
if (forbids.size() == 1)
|
||||
return expr + "*";
|
||||
else
|
||||
return std::string("( ") + expr + " )*";
|
||||
}
|
||||
|
||||
/**
|
||||
* Build grammar for xml-style tool call
|
||||
* form.scope_start and form.scope_end can be empty.
|
||||
* Requires data.format for model-specific hacks.
|
||||
*/
|
||||
void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) {
|
||||
GGML_ASSERT(!form.tool_start.empty());
|
||||
GGML_ASSERT(!form.tool_sep.empty());
|
||||
GGML_ASSERT(!form.key_start.empty());
|
||||
GGML_ASSERT(!form.val_end.empty());
|
||||
GGML_ASSERT(!form.tool_end.empty());
|
||||
|
||||
std::string key_val_sep = form.key_val_sep;
|
||||
if (form.key_val_sep2) {
|
||||
key_val_sep += "\n";
|
||||
key_val_sep += *form.key_val_sep2;
|
||||
}
|
||||
GGML_ASSERT(!key_val_sep.empty());
|
||||
|
||||
if (tools.is_array() && !tools.empty()) {
|
||||
data.grammar = build_grammar([&](const common_grammar_builder &builder) {
|
||||
auto string_arg_val = form.last_val_end ?
|
||||
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) :
|
||||
builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end}));
|
||||
|
||||
std::vector<std::string> tool_rules;
|
||||
for (const auto & tool : tools) {
|
||||
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
||||
LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool.at("function");
|
||||
if (!function.contains("name") || !function.at("name").is_string()) {
|
||||
LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
|
||||
LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
}
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
|
||||
struct parameter_rule {
|
||||
std::string symbol_name;
|
||||
bool is_required;
|
||||
};
|
||||
std::vector<parameter_rule> arg_rules;
|
||||
if (!parameters.contains("properties") || !parameters.at("properties").is_object()) {
|
||||
LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str());
|
||||
continue;
|
||||
} else {
|
||||
std::vector<std::string> requiredParameters;
|
||||
if (parameters.contains("required")) {
|
||||
try { parameters.at("required").get_to(requiredParameters); }
|
||||
catch (const std::runtime_error&) {
|
||||
LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str());
|
||||
}
|
||||
}
|
||||
sort_uniq(requiredParameters);
|
||||
for (const auto & [key, value] : parameters.at("properties").items()) {
|
||||
std::string quoted_key = key;
|
||||
bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key);
|
||||
if (form.key_start.back() == '"' && key_val_sep[0] == '"') {
|
||||
quoted_key = gbnf_format_literal(key);
|
||||
quoted_key = quoted_key.substr(1, quoted_key.size() - 2);
|
||||
}
|
||||
arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key,
|
||||
gbnf_format_literal(form.key_start) + " " +
|
||||
gbnf_format_literal(quoted_key) + " " +
|
||||
gbnf_format_literal(key_val_sep) + " " +
|
||||
((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ?
|
||||
(form.raw_argval ?
|
||||
string_arg_val :
|
||||
"( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )"
|
||||
) :
|
||||
builder.add_schema(name + "-arg-" + key, value)
|
||||
)
|
||||
), required});
|
||||
}
|
||||
}
|
||||
|
||||
auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end));
|
||||
decltype(next_arg_with_sep) next_arg = "\"\"";
|
||||
for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) {
|
||||
std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep;
|
||||
next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ?
|
||||
include_this_arg : "( " + include_this_arg + " ) | " + next_arg
|
||||
);
|
||||
include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg;
|
||||
next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ?
|
||||
include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep
|
||||
);
|
||||
}
|
||||
|
||||
std::string quoted_name = name;
|
||||
if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') {
|
||||
quoted_name = gbnf_format_literal(name);
|
||||
quoted_name = quoted_name.substr(1, quoted_name.size() - 2);
|
||||
}
|
||||
quoted_name = gbnf_format_literal(quoted_name);
|
||||
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||
if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||
quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+";
|
||||
}
|
||||
tool_rules.push_back(builder.add_rule(name + "-call",
|
||||
gbnf_format_literal(form.tool_start) + " " +
|
||||
quoted_name + " " +
|
||||
gbnf_format_literal(form.tool_sep) + " " +
|
||||
next_arg
|
||||
));
|
||||
}
|
||||
|
||||
auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | "));
|
||||
auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once);
|
||||
auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end));
|
||||
auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end);
|
||||
builder.add_rule("root",
|
||||
(form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") +
|
||||
tool_call_multiple_with_end + "?" +
|
||||
(form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end))
|
||||
);
|
||||
});
|
||||
|
||||
// grammar trigger for tool call
|
||||
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) {
|
||||
GGML_ASSERT(!form.tool_start.empty());
|
||||
GGML_ASSERT(!form.key_start.empty());
|
||||
GGML_ASSERT(!form.key_val_sep.empty());
|
||||
GGML_ASSERT(!form.val_end.empty());
|
||||
GGML_ASSERT(!form.tool_end.empty());
|
||||
|
||||
// Helper to choose return false or throw error
|
||||
constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str());
|
||||
if (recovery) {
|
||||
builder.move_to(start_pos);
|
||||
return false;
|
||||
} else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output.");
|
||||
};
|
||||
// Drop substring from needle to end from a JSON
|
||||
constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") {
|
||||
auto pos = json_str.rfind(needle);
|
||||
if (pos == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
for (auto i = pos + needle.size(); i < json_str.size(); ++i) {
|
||||
unsigned char ch = static_cast<unsigned char>(json_str[i]);
|
||||
if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (pos != 0 && json_str[pos - 1] == '"') {
|
||||
--pos;
|
||||
}
|
||||
json_str.resize(pos);
|
||||
return true;
|
||||
};
|
||||
// Helper to generate a partial argument JSON
|
||||
constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) {
|
||||
auto rest = builder.consume_rest();
|
||||
utf8_truncate_safe_resize(rest);
|
||||
set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG");
|
||||
auto tool_str = arguments.dump();
|
||||
if (partial_json(tool_str)) {
|
||||
if (builder.add_tool_call(function_name, "", tool_str)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str());
|
||||
};
|
||||
// Helper to find a close (because there may be form.last_val_end or form.last_tool_end)
|
||||
constexpr auto try_find_close = [](
|
||||
common_chat_msg_parser & builder,
|
||||
const std::string & end,
|
||||
const std::optional<std::string> & alt_end,
|
||||
const std::string & end_next,
|
||||
const std::optional<std::string> & alt_end_next
|
||||
) {
|
||||
auto saved_pos = builder.pos();
|
||||
auto tc = builder.try_find_literal(end);
|
||||
auto val_end_size = end.size();
|
||||
if (alt_end) {
|
||||
auto pos_1 = builder.pos();
|
||||
builder.move_to(saved_pos);
|
||||
auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next);
|
||||
if (alt_end_next) {
|
||||
builder.move_to(saved_pos);
|
||||
auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next);
|
||||
if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) {
|
||||
tc2 = tc3;
|
||||
}
|
||||
}
|
||||
if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) {
|
||||
tc = tc2;
|
||||
tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size());
|
||||
builder.move_to(tc->groups[0].end);
|
||||
val_end_size = alt_end->size();
|
||||
} else {
|
||||
builder.move_to(pos_1);
|
||||
}
|
||||
}
|
||||
return std::make_pair(val_end_size, tc);
|
||||
};
|
||||
// Helper to find a val_end or last_val_end, returns matched pattern size
|
||||
const auto try_find_val_end = [try_find_close, &builder, &form]() {
|
||||
return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end);
|
||||
};
|
||||
// Helper to find a tool_end or last_tool_end, returns matched pattern size
|
||||
const auto try_find_tool_end = [try_find_close, &builder, &form]() {
|
||||
return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt);
|
||||
};
|
||||
|
||||
bool recovery = true;
|
||||
const auto start_pos = builder.pos();
|
||||
if (!all_space(form.scope_start)) {
|
||||
if (auto tc = builder.try_find_literal(form.scope_start)) {
|
||||
if (all_space(tc->prelude)) {
|
||||
if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin)
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start));
|
||||
} else {
|
||||
builder.move_to(start_pos);
|
||||
return false;
|
||||
}
|
||||
} else return false;
|
||||
}
|
||||
while (auto tc = builder.try_find_literal(form.tool_start)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||
gbnf_format_literal(form.tool_start).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||
break;
|
||||
}
|
||||
|
||||
// Find tool name
|
||||
auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep);
|
||||
if (!func_name) {
|
||||
auto [sz, tc] = try_find_tool_end();
|
||||
func_name = tc;
|
||||
}
|
||||
if (!func_name) {
|
||||
// Partial tool name not supported
|
||||
throw common_chat_msg_partial_exception("incomplete tool_call");
|
||||
}
|
||||
// If the model generate multiple tool call and the first tool call has no argument
|
||||
if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) {
|
||||
builder.move_to(func_name->groups[0].begin - func_name->prelude.size());
|
||||
auto [sz, tc] = try_find_tool_end();
|
||||
func_name = tc;
|
||||
}
|
||||
|
||||
// Parse tool name
|
||||
builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end);
|
||||
std::string function_name = string_strip(func_name->prelude);
|
||||
// Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name
|
||||
if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) {
|
||||
if (string_starts_with(function_name, "functions.")) {
|
||||
static const std::regex re(":\\d+$");
|
||||
if (std::regex_search(function_name, re)) {
|
||||
function_name = function_name.substr(10, function_name.rfind(":") - 10);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Argument JSON
|
||||
json arguments = json::object();
|
||||
|
||||
// Helper to generate a partial argument JSON
|
||||
const auto gen_partial_args = [&](auto set_partial_arg) {
|
||||
gen_partial_json(set_partial_arg, arguments, builder, function_name);
|
||||
};
|
||||
|
||||
// Parse all arg_key/arg_value pairs
|
||||
while (auto tc = builder.try_find_literal(form.key_start)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n",
|
||||
gbnf_format_literal(form.key_start).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
builder.move_to(tc->groups[0].begin - tc->prelude.size());
|
||||
break;
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) {
|
||||
auto tool_call_arg = arguments.dump();
|
||||
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start));
|
||||
}
|
||||
|
||||
// Parse arg_key
|
||||
auto key_res = builder.try_find_literal(form.key_val_sep);
|
||||
if (!key_res) {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";});
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start));
|
||||
}
|
||||
if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";});
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep));
|
||||
}
|
||||
auto &key = key_res->prelude;
|
||||
recovery = false;
|
||||
|
||||
// Parse arg_value
|
||||
if (form.key_val_sep2) {
|
||||
if (auto tc = builder.try_find_literal(*form.key_val_sep2)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n",
|
||||
gbnf_format_literal(tc->prelude).c_str(),
|
||||
gbnf_format_literal(form.key_val_sep).c_str(),
|
||||
gbnf_format_literal(*form.key_val_sep2).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, false);
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2));
|
||||
}
|
||||
} else {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep));
|
||||
}
|
||||
}
|
||||
auto val_start = builder.pos();
|
||||
|
||||
// Test if arg_val is a partial JSON
|
||||
std::optional<common_json> value_json = std::nullopt;
|
||||
if (!form.raw_argval || !*form.raw_argval) {
|
||||
try { value_json = builder.try_consume_json(); }
|
||||
catch (const std::runtime_error&) { builder.move_to(val_start); }
|
||||
// TODO: Delete this when json_partial adds top-level support for null/true/false
|
||||
if (builder.pos() == val_start) {
|
||||
const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)");
|
||||
builder.consume_spaces();
|
||||
std::string_view sv = utf8_truncate_safe_view(builder.input());
|
||||
sv.remove_prefix(builder.pos());
|
||||
std::string rest = "a";
|
||||
if (sv.size() < 6) rest = sv;
|
||||
if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) {
|
||||
value_json = {123, {"123", "123"}};
|
||||
builder.consume_rest();
|
||||
} else {
|
||||
builder.move_to(val_start);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If it is a JSON and followed by </arg_value>, parse as json
|
||||
// cannot support streaming because it may be a plain text starting with JSON
|
||||
if (value_json) {
|
||||
auto json_end = builder.pos();
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() == builder.input().size()) {
|
||||
if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) {
|
||||
arguments[key] = value_json->json;
|
||||
auto json_str = arguments.dump();
|
||||
if (!value_json->healing_marker.json_dump_marker.empty()) {
|
||||
GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||
json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker));
|
||||
} else {
|
||||
GGML_ASSERT(json_str.back() == '}');
|
||||
json_str.resize(json_str.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", json_str);
|
||||
} else {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
}
|
||||
LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||
throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations.");
|
||||
}
|
||||
builder.move_to(json_end);
|
||||
auto [val_end_size, tc] = try_find_val_end();
|
||||
if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) {
|
||||
if (tc->groups[0].end - tc->groups[0].begin != val_end_size) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;});
|
||||
LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str());
|
||||
throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : ""));
|
||||
} else arguments[key] = value_json->json;
|
||||
} else builder.move_to(val_start);
|
||||
}
|
||||
|
||||
// If not, parse as plain text
|
||||
if (val_start == builder.pos()) {
|
||||
if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) {
|
||||
auto &value_str = value_plain->prelude;
|
||||
if (form.trim_raw_argval) value_str = string_strip(value_str);
|
||||
if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) {
|
||||
gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;});
|
||||
throw common_chat_msg_partial_exception(
|
||||
"Expected " + gbnf_format_literal(form.val_end) +
|
||||
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||
);
|
||||
}
|
||||
arguments[key] = value_str;
|
||||
} else {
|
||||
if (form.trim_raw_argval) {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;});
|
||||
} else {
|
||||
gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;});
|
||||
}
|
||||
throw common_chat_msg_partial_exception(
|
||||
"Expected " + gbnf_format_literal(form.val_end) +
|
||||
" after " + gbnf_format_literal(form.key_val_sep) +
|
||||
(form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Consume closing tag
|
||||
if (auto [tool_end_size, tc] = try_find_tool_end(); tc) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.tool_end).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) {
|
||||
// Add the parsed tool call
|
||||
if (!builder.add_tool_call(function_name, "", arguments.dump())) {
|
||||
throw common_chat_msg_partial_exception("Failed to add XML-Style tool call");
|
||||
}
|
||||
recovery = false;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto tool_call_arg = arguments.dump();
|
||||
if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') {
|
||||
tool_call_arg.resize(tool_call_arg.size() - 1);
|
||||
}
|
||||
builder.add_tool_call(function_name, "", tool_call_arg);
|
||||
throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end));
|
||||
}
|
||||
if (auto tc = builder.try_find_literal(form.scope_end)) {
|
||||
if (!all_space(tc->prelude)) {
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.scope_end).c_str(),
|
||||
gbnf_format_literal(tc->prelude).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
} else {
|
||||
if (all_space(form.scope_end)) return true;
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() == builder.input().size())
|
||||
throw common_chat_msg_partial_exception("incomplete tool calls");
|
||||
LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n",
|
||||
gbnf_format_literal(form.scope_end).c_str(),
|
||||
gbnf_format_literal(builder.consume_rest()).c_str()
|
||||
);
|
||||
return return_error(builder, start_pos, recovery);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
|
||||
auto pos = pos_;
|
||||
auto tsize = result_.tool_calls.size();
|
||||
try { return parse_xml_tool_calls(*this, form); }
|
||||
catch (const xml_toolcall_syntax_exception&) {}
|
||||
move_to(pos);
|
||||
result_.tool_calls.resize(tsize);
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse content uses reasoning and XML-Style tool call
|
||||
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||
*/
|
||||
inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>") {
|
||||
constexpr auto rstrip = [](std::string &s) {
|
||||
s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
|
||||
};
|
||||
// Erase substring from l to r, along with additional spaces nearby
|
||||
constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
|
||||
while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast<unsigned char>(str[l])));
|
||||
++l;
|
||||
while (++r < str.size() && std::isspace(static_cast<unsigned char>(str[r])));
|
||||
if (l < r) str[l] = '\n';
|
||||
if (l + 1 < r) str[l + 1] = '\n';
|
||||
if (l != 0) l += 2;
|
||||
str.erase(l, r - l);
|
||||
return l;
|
||||
};
|
||||
constexpr auto trim_suffix = [](std::string &content, std::initializer_list<std::string_view> list) {
|
||||
auto best_match = content.size();
|
||||
for (auto pattern: list) {
|
||||
if (pattern.size() == 0) continue;
|
||||
for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
|
||||
auto match_len = content.size() - match_idx;
|
||||
if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
|
||||
best_match = match_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (content.size() > best_match) {
|
||||
content.erase(best_match);
|
||||
}
|
||||
};
|
||||
const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
|
||||
return trim_suffix(content, {
|
||||
start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
|
||||
form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
|
||||
form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
|
||||
form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "",
|
||||
form.scope_end
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
// Trim leading spaces without affecting keyword matching
|
||||
static const common_regex spaces_regex("\\s*");
|
||||
{
|
||||
auto tc = builder.consume_regex(spaces_regex);
|
||||
auto spaces = builder.str(tc.groups[0]);
|
||||
auto s1 = spaces.size();
|
||||
trim_potential_partial_word(spaces);
|
||||
auto s2 = spaces.size();
|
||||
builder.move_to(builder.pos() - (s1 - s2));
|
||||
}
|
||||
|
||||
// Parse content
|
||||
bool reasoning_unclosed = builder.syntax().thinking_forced_open;
|
||||
std::string unclosed_reasoning_content("");
|
||||
for (;;) {
|
||||
auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start);
|
||||
std::string content;
|
||||
std::string tool_call_start;
|
||||
|
||||
if (tc) {
|
||||
content = std::move(tc->prelude);
|
||||
tool_call_start = builder.str(tc->groups[0]);
|
||||
LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str());
|
||||
} else {
|
||||
content = builder.consume_rest();
|
||||
utf8_truncate_safe_resize(content);
|
||||
}
|
||||
|
||||
// Handle unclosed think block
|
||||
if (reasoning_unclosed) {
|
||||
if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) {
|
||||
unclosed_reasoning_content += content;
|
||||
if (form.allow_toolcall_in_think) {
|
||||
builder.move_to(tc->groups[0].begin);
|
||||
if (!builder.try_consume_xml_tool_calls(form)) {
|
||||
unclosed_reasoning_content += tool_call_start;
|
||||
builder.move_to(tc->groups[0].end);
|
||||
}
|
||||
} else {
|
||||
unclosed_reasoning_content += tool_call_start;
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
reasoning_unclosed = false;
|
||||
std::string reasoning_content;
|
||||
if (pos == std::string::npos) {
|
||||
reasoning_content = std::move(content);
|
||||
} else {
|
||||
reasoning_content = content.substr(0, pos);
|
||||
content.erase(0, pos + end_think.size());
|
||||
}
|
||||
if (builder.pos() == builder.input().size() && all_space(content)) {
|
||||
rstrip(reasoning_content);
|
||||
trim_potential_partial_word(reasoning_content);
|
||||
rstrip(reasoning_content);
|
||||
if (reasoning_content.empty()) {
|
||||
rstrip(unclosed_reasoning_content);
|
||||
trim_potential_partial_word(unclosed_reasoning_content);
|
||||
rstrip(unclosed_reasoning_content);
|
||||
if (unclosed_reasoning_content.empty()) continue;
|
||||
}
|
||||
}
|
||||
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
|
||||
builder.add_content(start_think);
|
||||
builder.add_content(unclosed_reasoning_content);
|
||||
builder.add_content(reasoning_content);
|
||||
if (builder.pos() != builder.input().size() || !all_space(content))
|
||||
builder.add_content(end_think);
|
||||
} else {
|
||||
builder.add_reasoning_content(unclosed_reasoning_content);
|
||||
builder.add_reasoning_content(reasoning_content);
|
||||
}
|
||||
unclosed_reasoning_content.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Handle multiple think block
|
||||
bool toolcall_in_think = false;
|
||||
for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) {
|
||||
if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) {
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size());
|
||||
builder.add_reasoning_content(reasoning_content);
|
||||
think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1);
|
||||
} else {
|
||||
think_start = think_end + end_think.size() - 1;
|
||||
}
|
||||
} else {
|
||||
// This <tool_call> start is in thinking block, skip this tool call
|
||||
auto pos = think_start + start_think.size();
|
||||
unclosed_reasoning_content = content.substr(pos) + tool_call_start;
|
||||
reasoning_unclosed = true;
|
||||
content.resize(think_start);
|
||||
toolcall_in_think = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) {
|
||||
rstrip(content);
|
||||
// Handle unclosed </think> token from content: delete all </think> token
|
||||
if (auto pos = content.rfind(end_think); pos != std::string::npos) {
|
||||
while (pos != std::string::npos) {
|
||||
pos = erase_spaces(content, pos, pos + end_think.size() - 1);
|
||||
pos = content.rfind(end_think, pos);
|
||||
}
|
||||
}
|
||||
// Strip if needed
|
||||
if (content.size() > 0 && std::isspace(static_cast<unsigned char>(content[0]))) {
|
||||
content = string_strip(content);
|
||||
}
|
||||
}
|
||||
|
||||
// remove potential partial suffix
|
||||
if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) {
|
||||
rstrip(content);
|
||||
trim_potential_partial_word(content);
|
||||
rstrip(content);
|
||||
}
|
||||
|
||||
// Add content
|
||||
if (content.size() != 0) {
|
||||
// If there are multiple content blocks
|
||||
if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) {
|
||||
builder.add_content("\n\n");
|
||||
}
|
||||
builder.add_content(content);
|
||||
}
|
||||
|
||||
// This <tool_call> start is in thinking block, skip this tool call
|
||||
if (toolcall_in_think && !form.allow_toolcall_in_think) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// There is no tool call and all content is parsed
|
||||
if (!tc) {
|
||||
GGML_ASSERT(builder.pos() == builder.input().size());
|
||||
GGML_ASSERT(unclosed_reasoning_content.empty());
|
||||
GGML_ASSERT(!reasoning_unclosed);
|
||||
break;
|
||||
}
|
||||
|
||||
builder.move_to(tc->groups[0].begin);
|
||||
if (builder.try_consume_xml_tool_calls(form)) {
|
||||
auto end_of_tool = builder.pos();
|
||||
builder.consume_spaces();
|
||||
if (builder.pos() != builder.input().size()) {
|
||||
builder.move_to(end_of_tool);
|
||||
if (!builder.result().content.empty()) {
|
||||
builder.add_content("\n\n");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static const common_regex next_char_regex(".");
|
||||
auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]);
|
||||
rstrip(c);
|
||||
builder.add_content(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse content uses reasoning and XML-Style tool call
|
||||
* TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
|
||||
*/
|
||||
void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) {
|
||||
parse_msg_with_xml_tool_calls(*this, form, start_think, end_think);
|
||||
}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
||||
// Sample config:
|
||||
// MiniMax-M2 (left): <minimax:tool_call>\n<invoke name="tool-name">\n<parameter name="key">value</parameter>\n...</invoke>\n...</minimax:tool_call>
|
||||
// GLM 4.5 (right): <tool_call>function_name\n<arg_key>key</arg_key>\n<arg_value>value</arg_value>\n</tool_call>
|
||||
struct xml_tool_call_format {
|
||||
std::string scope_start; // <minimax:tool_call>\n // \n // can be empty
|
||||
std::string tool_start; // <invoke name=\" // <tool_call>
|
||||
std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls
|
||||
std::string key_start; // <parameter name=\" // <arg_key>
|
||||
std::string key_val_sep; // \"> // </arg_key>\n<arg_value>
|
||||
std::string val_end; // </parameter>\n // </arg_value>\n
|
||||
std::string tool_end; // </invoke>\n // </tool_call>\n
|
||||
std::string scope_end; // </minimax:tool_call> // // can be empty
|
||||
// Set this if there can be dynamic spaces inside key_val_sep.
|
||||
// e.g. key_val_sep=</arg_key> key_val_sep2=<arg_value> for GLM4.5
|
||||
std::optional<std::string> key_val_sep2 = std::nullopt;
|
||||
// Set true if argval should only be raw string. e.g. Hello "world" hi
|
||||
// Set false if argval should only be json string. e.g. "Hello \"world\" hi"
|
||||
// Defaults to std::nullopt, both will be allowed.
|
||||
std::optional<bool> raw_argval = std::nullopt;
|
||||
std::optional<std::string> last_val_end = std::nullopt;
|
||||
std::optional<std::string> last_tool_end = std::nullopt;
|
||||
bool trim_raw_argval = false;
|
||||
bool allow_toolcall_in_think = false; // TODO: UNTESTED!!!
|
||||
};
|
||||
|
||||
// make a GBNF that accept any strings except those containing any of the forbidden strings.
|
||||
std::string make_gbnf_excluding(std::vector<std::string> forbids);
|
||||
|
||||
/**
|
||||
* Build grammar for xml-style tool call
|
||||
* form.scope_start and form.scope_end can be empty.
|
||||
* Requires data.format for model-specific hacks.
|
||||
*/
|
||||
void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form);
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "chat-parser-xml-toolcall.h"
|
||||
#include "json-partial.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
|
|
@ -119,5 +120,14 @@ class common_chat_msg_parser {
|
|||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
|
||||
/**
|
||||
* Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched.
|
||||
* form.scope_start, form.tool_sep and form.scope_end can be empty.
|
||||
*/
|
||||
bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form);
|
||||
|
||||
// Parse content uses reasoning and XML-Style tool call
|
||||
void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "<think>", const std::string & end_think = "</think>");
|
||||
|
||||
void clear_tools();
|
||||
};
|
||||
|
|
|
|||
548
common/chat.cpp
548
common/chat.cpp
|
|
@ -643,6 +643,12 @@ const char * common_chat_format_name(common_chat_format format) {
|
|||
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
|
||||
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
|
||||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
|
|
@ -1807,6 +1813,278 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_MINIMAX_M2;
|
||||
|
||||
// Handle thinking tags based on prompt ending
|
||||
if (string_ends_with(data.prompt, "<think>\n")) {
|
||||
if (!params.enable_thinking) {
|
||||
// Close the thinking tag immediately if thinking is disabled
|
||||
data.prompt += "</think>\n\n";
|
||||
} else {
|
||||
// Mark thinking as forced open (template started with <think>)
|
||||
data.thinking_forced_open = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Preserve MiniMax-M2 special tokens
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<minimax:tool_call>",
|
||||
"</minimax:tool_call>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<minimax:tool_call>\n",
|
||||
/* form.tool_start = */ "<invoke name=\"",
|
||||
/* form.tool_sep = */ "\">\n",
|
||||
/* form.key_start = */ "<parameter name=\"",
|
||||
/* form.key_val_sep = */ "\">",
|
||||
/* form.val_end = */ "</parameter>\n",
|
||||
/* form.tool_end = */ "</invoke>\n",
|
||||
/* form.scope_end = */ "</minimax:tool_call>",
|
||||
};
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<minimax:tool_call>",
|
||||
/* form.tool_start = */ "<invoke name=\"",
|
||||
/* form.tool_sep = */ "\">",
|
||||
/* form.key_start = */ "<parameter name=\"",
|
||||
/* form.key_val_sep = */ "\">",
|
||||
/* form.val_end = */ "</parameter>",
|
||||
/* form.tool_end = */ "</invoke>",
|
||||
/* form.scope_end = */ "</minimax:tool_call>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<function=",
|
||||
"</function>",
|
||||
"<parameter=",
|
||||
"</parameter>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<tool_call>\n",
|
||||
/* form.tool_start = */ "<function=",
|
||||
/* form.tool_sep = */ ">\n",
|
||||
/* form.key_start = */ "<parameter=",
|
||||
/* form.key_val_sep = */ ">\n",
|
||||
/* form.val_end = */ "\n</parameter>\n",
|
||||
/* form.tool_end = */ "</function>\n",
|
||||
/* form.scope_end = */ "</tool_call>",
|
||||
};
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_call>";
|
||||
form.tool_start = "<function=";
|
||||
form.tool_sep = ">";
|
||||
form.key_start = "<parameter=";
|
||||
form.key_val_sep = ">";
|
||||
form.val_end = "</parameter>";
|
||||
form.tool_end = "</function>";
|
||||
form.scope_end = "</tool_call>";
|
||||
form.trim_raw_argval = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_KIMI_K2;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_begin|>",
|
||||
"<|tool_call_argument_begin|>",
|
||||
"<|tool_call_end|>",
|
||||
"<|tool_calls_section_end|>",
|
||||
"<|im_end|>",
|
||||
"<|im_system|>",
|
||||
"<|im_middle|>",
|
||||
};
|
||||
|
||||
data.additional_stops.insert(data.additional_stops.end(), {
|
||||
"<|im_end|>",
|
||||
"<|im_middle|>"
|
||||
});
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<|tool_calls_section_begin|>";
|
||||
form.tool_start = "<|tool_call_begin|>";
|
||||
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}<|tool_call_end|>";
|
||||
form.scope_end = "<|tool_calls_section_end|>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<|tool_calls_section_begin|>";
|
||||
form.tool_start = "<|tool_call_begin|>";
|
||||
form.tool_sep = "<|tool_call_argument_begin|>{";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}<|tool_call_end|>";
|
||||
form.scope_end = "<|tool_calls_section_end|>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_APRIEL_1_5;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<thinking>",
|
||||
"</thinking>",
|
||||
"<tool_calls>",
|
||||
"</tool_calls>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_calls>[";
|
||||
form.tool_start = "{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}, ";
|
||||
form.scope_end = "]</tool_calls>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
form.last_tool_end = "}";
|
||||
return form;
|
||||
})();
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_calls>[";
|
||||
form.tool_start = "{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}, ";
|
||||
form.scope_end = "]</tool_calls>";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
form.last_tool_end = "}";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_XIAOMI_MIMO;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "\n";
|
||||
form.tool_start = "<tool_call>\n{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}\n</tool_call>";
|
||||
form.scope_end = "";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "";
|
||||
form.tool_start = "<tool_call>\n{\"name\": \"";
|
||||
form.tool_sep = "\", \"arguments\": {";
|
||||
form.key_start = "\"";
|
||||
form.key_val_sep = "\": ";
|
||||
form.val_end = ", ";
|
||||
form.tool_end = "}\n</tool_call>";
|
||||
form.scope_end = "";
|
||||
form.raw_argval = false;
|
||||
form.last_val_end = "";
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
|
|
@ -2041,6 +2319,100 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
|||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = inputs.tools.is_array() && !inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
std::string prompt = apply(tmpl, inputs);
|
||||
|
||||
// match the existing trimming behavior
|
||||
if (inputs.add_bos && string_starts_with(prompt, tmpl.bos_token())) {
|
||||
prompt.erase(0, tmpl.bos_token().size());
|
||||
}
|
||||
if (inputs.add_eos && string_ends_with(prompt, tmpl.eos_token())) {
|
||||
prompt.erase(prompt.size() - tmpl.eos_token().size());
|
||||
}
|
||||
if (string_ends_with(prompt, "<think>")) {
|
||||
if (!inputs.enable_thinking) {
|
||||
prompt += "</think>";
|
||||
} else {
|
||||
data.thinking_forced_open = true;
|
||||
}
|
||||
}
|
||||
|
||||
// add GLM preserved tokens
|
||||
data.preserved_tokens = {
|
||||
"<|endoftext|>",
|
||||
"[MASK]",
|
||||
"[gMASK]",
|
||||
"[sMASK]",
|
||||
"<sop>",
|
||||
"<eop>",
|
||||
"<|system|>",
|
||||
"<|user|>",
|
||||
"<|assistant|>",
|
||||
"<|observation|>",
|
||||
"<|begin_of_image|>",
|
||||
"<|end_of_image|>",
|
||||
"<|begin_of_video|>",
|
||||
"<|end_of_video|>",
|
||||
"<|begin_of_audio|>",
|
||||
"<|end_of_audio|>",
|
||||
"<|begin_of_transcription|>",
|
||||
"<|end_of_transcription|>",
|
||||
"<|code_prefix|>",
|
||||
"<|code_middle|>",
|
||||
"<|code_suffix|>",
|
||||
"/nothink",
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<arg_key>",
|
||||
"</arg_key>",
|
||||
"<arg_value>",
|
||||
"</arg_value>"
|
||||
};
|
||||
|
||||
// extra GLM 4.5 stop word
|
||||
data.additional_stops.insert(data.additional_stops.end(), {
|
||||
"<|user|>",
|
||||
"<|observation|>"
|
||||
});
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "",
|
||||
/* form.tool_start = */ "\n<tool_call>",
|
||||
/* form.tool_sep = */ "\n",
|
||||
/* form.key_start = */ "<arg_key>",
|
||||
/* form.key_val_sep = */ "</arg_key>\n<arg_value>",
|
||||
/* form.val_end = */ "</arg_value>\n",
|
||||
/* form.tool_end = */ "</tool_call>\n",
|
||||
/* form.scope_end = */ "",
|
||||
};
|
||||
build_grammar_xml_tool_call(data, inputs.tools, form);
|
||||
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_GLM_4_5;
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "",
|
||||
/* form.tool_start = */ "<tool_call>",
|
||||
/* form.tool_sep = */ "",
|
||||
/* form.key_start = */ "<arg_key>",
|
||||
/* form.key_val_sep = */ "</arg_key>",
|
||||
/* form.val_end = */ "</arg_value>",
|
||||
/* form.tool_end = */ "</tool_call>",
|
||||
/* form.scope_end = */ "",
|
||||
/* form.key_val_sep2 = */ "<arg_value>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
|
|
@ -2704,91 +3076,17 @@ static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
|
|||
}
|
||||
|
||||
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags first - this handles the main reasoning content
|
||||
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse tool calls - Seed-OSS uses <seed:tool_call> format
|
||||
static const common_regex tool_call_begin_regex("<seed:tool_call>");
|
||||
static const common_regex tool_call_end_regex("</seed:tool_call>");
|
||||
static const common_regex function_regex("<function=([^>]+)>");
|
||||
static const common_regex param_regex("<parameter=([^>]+)>");
|
||||
|
||||
while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) {
|
||||
builder.consume_spaces(); // Consume whitespace after <seed:tool_call>
|
||||
|
||||
// Look for function call inside tool call, ignore any content before it
|
||||
if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) {
|
||||
auto function_name = builder.str(func_res->groups[1]);
|
||||
|
||||
// Parse Seed-OSS parameters <parameter=name>value</parameter>
|
||||
json args = json::object();
|
||||
// Parse all parameters
|
||||
while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) {
|
||||
// again, ignore noise around parameters
|
||||
auto param_name = builder.str(param_res->groups[1]);
|
||||
builder.move_to(param_res->groups[0].end);
|
||||
builder.consume_spaces(); // Consume whitespace after parameter
|
||||
auto savedPos = builder.pos();
|
||||
if (auto param_parse = builder.try_find_literal("</parameter>")) {
|
||||
auto param = param_parse->prelude;
|
||||
builder.move_to(savedPos);
|
||||
try {
|
||||
if (auto param_res = builder.try_consume_json()) {
|
||||
args[param_name] = param_res->json;
|
||||
} else {
|
||||
args[param_name] = param;
|
||||
}
|
||||
} catch (json::exception &) {
|
||||
args[param_name] = param;
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool parameter");
|
||||
}
|
||||
}
|
||||
// Look for closing function tag
|
||||
auto end_func = builder.try_find_literal("</function>");
|
||||
if (end_func) {
|
||||
builder.move_to(end_func->groups[0].end);
|
||||
builder.consume_spaces(); // Consume whitespace after </function>
|
||||
|
||||
// Add the tool call with parsed arguments, but only if we REALLY got the literal
|
||||
auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end);
|
||||
auto funlen = std::string("</function>").length();
|
||||
if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("</function>")) {
|
||||
if (!builder.add_tool_call(function_name, "", args.dump())) {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
// Look for closing tool call tag
|
||||
if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) {
|
||||
builder.move_to(end_tool->groups[0].end);
|
||||
builder.consume_spaces(); // Consume trailing whitespace after tool call
|
||||
} else {
|
||||
throw common_chat_msg_partial_exception("Incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
// No function found - don't consume content here, let it be handled at the end
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Consume any remaining whitespace after all tool call processing
|
||||
builder.consume_spaces();
|
||||
auto remaining = builder.consume_rest();
|
||||
// If there's any non-whitespace content remaining, add it as content
|
||||
if (!string_strip(remaining).empty()) {
|
||||
builder.add_content(remaining);
|
||||
}
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<seed:tool_call>",
|
||||
/* form.tool_start = */ "<function=",
|
||||
/* form.tool_sep = */ ">",
|
||||
/* form.key_start = */ "<parameter=",
|
||||
/* form.key_val_sep = */ ">",
|
||||
/* form.val_end = */ "</parameter>",
|
||||
/* form.tool_end = */ "</function>",
|
||||
/* form.scope_end = */ "</seed:tool_call>",
|
||||
};
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
|
|
@ -2927,6 +3225,35 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
return common_chat_params_init_granite(tmpl, params);
|
||||
}
|
||||
|
||||
// GLM 4.5: detect by <arg_key> and <arg_value> tags (check before Hermes since both use <tool_call>)
|
||||
if (src.find("[gMASK]<sop>") != std::string::npos &&
|
||||
src.find("<arg_key>") != std::string::npos &&
|
||||
src.find("<arg_value>") != std::string::npos &&
|
||||
params.json_schema.is_null()) {
|
||||
return common_chat_params_init_glm_4_5(tmpl, params);
|
||||
}
|
||||
|
||||
// Qwen3-Coder XML format detection (must come before Hermes 2 Pro)
|
||||
// Detect via explicit XML markers unique to Qwen3-Coder to avoid false positives in other templates.
|
||||
// Require presence of <tool_call>, <function=...>, and <parameter=...> blocks.
|
||||
if (src.find("<tool_call>") != std::string::npos &&
|
||||
src.find("<function>") != std::string::npos &&
|
||||
src.find("<function=") != std::string::npos &&
|
||||
src.find("<parameters>") != std::string::npos &&
|
||||
src.find("<parameter=") != std::string::npos) {
|
||||
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
|
||||
}
|
||||
|
||||
// Xiaomi MiMo format detection (must come before Hermes 2 Pro)
|
||||
if (src.find("<tools>") != std::string::npos &&
|
||||
src.find("# Tools") != std::string::npos &&
|
||||
src.find("</tools>") != std::string::npos &&
|
||||
src.find("<tool_calls>") != std::string::npos &&
|
||||
src.find("</tool_calls>") != std::string::npos &&
|
||||
src.find("<tool_response>") != std::string::npos) {
|
||||
return common_chat_params_init_xiaomi_mimo(tmpl, params);
|
||||
}
|
||||
|
||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||
|
|
@ -2958,6 +3285,29 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
return common_chat_params_init_lfm2(tmpl, params);
|
||||
}
|
||||
|
||||
// MiniMax-M2 format detection
|
||||
if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) {
|
||||
return common_chat_params_init_minimax_m2(tmpl, params);
|
||||
}
|
||||
|
||||
// Kimi K2 format detection
|
||||
if (src.find("<|im_system|>tool_declare<|im_middle|>") != std::string::npos &&
|
||||
src.find("<|tool_calls_section_begin|>") != std::string::npos &&
|
||||
src.find("## Return of") != std::string::npos) {
|
||||
return common_chat_params_init_kimi_k2(tmpl, params);
|
||||
}
|
||||
|
||||
// Apriel 1.5 format detection
|
||||
if (src.find("<thinking>") != std::string::npos &&
|
||||
src.find("</thinking>") != std::string::npos &&
|
||||
src.find("<available_tools>") != std::string::npos &&
|
||||
src.find("<|assistant|>") != std::string::npos &&
|
||||
src.find("<|tool_result|>") != std::string::npos &&
|
||||
src.find("<tool_calls>[") != std::string::npos &&
|
||||
src.find("]</tool_calls>") != std::string::npos) {
|
||||
return common_chat_params_init_apriel_1_5(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
|
|
@ -3009,7 +3359,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
|||
const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
int alloc_size = 0;
|
||||
size_t alloc_size = 0;
|
||||
std::vector<llama_chat_message> chat;
|
||||
std::vector<std::string> contents;
|
||||
|
||||
|
|
@ -3031,7 +3381,8 @@ static common_chat_params common_chat_templates_apply_legacy(
|
|||
const auto & msg = inputs.messages[i];
|
||||
const auto & content = contents[i];
|
||||
chat.push_back({msg.role.c_str(), content.c_str()});
|
||||
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
||||
size_t msg_size = msg.role.size() + content.size();
|
||||
alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops
|
||||
}
|
||||
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
|
@ -3053,6 +3404,11 @@ static common_chat_params common_chat_templates_apply_legacy(
|
|||
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
// for safety, we check the result again
|
||||
if (res < 0 || (size_t) res > buf.size()) {
|
||||
throw std::runtime_error("failed to apply chat template, try using --jinja");
|
||||
}
|
||||
|
||||
common_chat_params params;
|
||||
params.prompt = std::string(buf.data(), res);
|
||||
if (!inputs.json_schema.empty()) {
|
||||
|
|
@ -3139,6 +3495,24 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
|||
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
|
||||
common_chat_parse_lfm2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2:
|
||||
common_chat_parse_minimax_m2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5:
|
||||
common_chat_parse_glm_4_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||
common_chat_parse_kimi_k2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||
common_chat_parse_qwen3_coder_xml(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||
common_chat_parse_apriel_1_5(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
|
||||
common_chat_parse_xiaomi_mimo(builder);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -117,6 +117,12 @@ enum common_chat_format {
|
|||
COMMON_CHAT_FORMAT_NEMOTRON_V2,
|
||||
COMMON_CHAT_FORMAT_APERTUS,
|
||||
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
|
||||
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
|
|
|||
|
|
@ -297,8 +297,25 @@ bool common_json_parse(
|
|||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
// handle unclosed top-level primitive
|
||||
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;
|
||||
if (can_parse(str + "\"")) {
|
||||
// Was inside an string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
|
||||
// Was inside an string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
|
||||
} else {
|
||||
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(it, end);
|
||||
|
|
|
|||
|
|
@ -303,6 +303,8 @@ static std::string format_literal(const std::string & literal) {
|
|||
return "\"" + escaped + "\"";
|
||||
}
|
||||
|
||||
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
|
||||
|
||||
class SchemaConverter {
|
||||
private:
|
||||
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
||||
|
|
|
|||
|
|
@ -18,4 +18,6 @@ struct common_grammar_options {
|
|||
bool dotall = false;
|
||||
};
|
||||
|
||||
std::string gbnf_format_literal(const std::string & literal);
|
||||
|
||||
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
||||
|
|
|
|||
|
|
@ -1673,11 +1673,9 @@ class GPTNeoXModel(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.GPTNEOX
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(
|
||||
int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
|
||||
|
|
@ -1735,7 +1733,7 @@ class BloomModel(TextModel):
|
|||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||
self.gguf_writer.add_embedding_length(n_embed)
|
||||
self.gguf_writer.add_feed_forward_length(4 * n_embed)
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -1798,10 +1796,9 @@ class MPTModel(TextModel):
|
|||
self.gguf_writer.add_unk_token_id(0)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layers"]
|
||||
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"])
|
||||
self.gguf_writer.add_head_count(self.hparams["n_heads"])
|
||||
if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"):
|
||||
|
|
@ -1834,7 +1831,6 @@ class OrionModel(TextModel):
|
|||
self._set_vocab_sentencepiece()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
|
||||
|
|
@ -1852,7 +1848,7 @@ class OrionModel(TextModel):
|
|||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
|
|
@ -1869,7 +1865,6 @@ class BaichuanModel(TextModel):
|
|||
self._set_vocab_sentencepiece()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
|
||||
|
|
@ -1886,7 +1881,7 @@ class BaichuanModel(TextModel):
|
|||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
|
|
@ -1993,7 +1988,6 @@ class XverseModel(TextModel):
|
|||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
|
||||
|
|
@ -2010,7 +2004,7 @@ class XverseModel(TextModel):
|
|||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
|
|
@ -2053,10 +2047,6 @@ class FalconModel(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.FALCON
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams.get("num_hidden_layers")
|
||||
if block_count is None:
|
||||
block_count = self.hparams["n_layer"] # old name
|
||||
|
||||
n_head = self.hparams.get("num_attention_heads")
|
||||
if n_head is None:
|
||||
n_head = self.hparams["n_head"] # old name
|
||||
|
|
@ -2069,7 +2059,7 @@ class FalconModel(TextModel):
|
|||
self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -2107,12 +2097,10 @@ class StarCoderModel(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.STARCODER
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||
self.gguf_writer.add_head_count_kv(1)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -2142,14 +2130,12 @@ class RefactModel(TextModel):
|
|||
multiple_of = 256
|
||||
ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
# refact uses Alibi. So this is from config.json which might be used by training.
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
|
||||
self.gguf_writer.add_feed_forward_length(ff_dim)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||
self.gguf_writer.add_head_count_kv(1)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -2196,11 +2182,10 @@ class StableLMModel(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||
|
|
@ -3151,7 +3136,7 @@ class DbrxModel(TextModel):
|
|||
def set_gguf_parameters(self):
|
||||
ffn_config = self.hparams["ffn_config"]
|
||||
attn_config = self.hparams["attn_config"]
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
|
|
@ -3353,7 +3338,7 @@ class QwenModel(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
|
||||
|
|
@ -4384,7 +4369,7 @@ class GPT2Model(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.GPT2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_context_length(self.hparams["n_ctx"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||
|
|
@ -4416,8 +4401,6 @@ class Phi2Model(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.PHI2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
||||
|
||||
rot_pct = self.find_hparam(["partial_rotary_factor"])
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
|
|
@ -4426,7 +4409,7 @@ class Phi2Model(TextModel):
|
|||
|
||||
self.gguf_writer.add_embedding_length(n_embd)
|
||||
self.gguf_writer.add_feed_forward_length(4 * n_embd)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head)
|
||||
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]))
|
||||
|
|
@ -4544,8 +4527,6 @@ class Phi3MiniModel(TextModel):
|
|||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
|
||||
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
|
|
@ -4559,7 +4540,7 @@ class Phi3MiniModel(TextModel):
|
|||
self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
|
||||
self.gguf_writer.add_embedding_length(n_embd)
|
||||
self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
|
||||
|
|
@ -4679,12 +4660,11 @@ class PlamoModel(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(4096) # not in config.json
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||
|
|
@ -4807,7 +4787,6 @@ class Plamo2Model(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
||||
|
||||
# Which layers are Mamba layers
|
||||
|
|
@ -4819,10 +4798,10 @@ class Plamo2Model(TextModel):
|
|||
num_attention_heads = []
|
||||
|
||||
if mamba_enabled:
|
||||
for i in range(block_count):
|
||||
if block_count <= (mamba_step // 2):
|
||||
for i in range(self.block_count):
|
||||
if self.block_count <= (mamba_step // 2):
|
||||
# use attention in last layer
|
||||
is_mamba = (i != block_count - 1)
|
||||
is_mamba = (i != self.block_count - 1)
|
||||
else:
|
||||
is_mamba = (i % mamba_step) != (mamba_step // 2)
|
||||
if is_mamba:
|
||||
|
|
@ -4840,7 +4819,7 @@ class Plamo2Model(TextModel):
|
|||
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
|
||||
self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
|
||||
self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
|
||||
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
|
||||
|
||||
|
|
@ -4897,12 +4876,10 @@ class CodeShellModel(TextModel):
|
|||
model_arch = gguf.MODEL_ARCH.CODESHELL
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["n_layer"]
|
||||
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"])
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||
|
|
@ -5044,7 +5021,7 @@ class InternLM2Model(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
||||
|
|
@ -5665,11 +5642,10 @@ class GemmaModel(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
||||
|
|
@ -5705,11 +5681,10 @@ class Gemma2Model(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
|
||||
|
|
@ -5753,12 +5728,11 @@ class Gemma3Model(TextModel):
|
|||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
# some default values are not specified in the hparams
|
||||
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
|
||||
|
|
@ -6034,7 +6008,6 @@ class Rwkv6Model(TextModel):
|
|||
self._set_vocab_rwkv_world()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_size = self.hparams["head_size"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||
|
|
@ -6046,7 +6019,7 @@ class Rwkv6Model(TextModel):
|
|||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
|
|
@ -6110,7 +6083,6 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
|||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
num_attention_heads = self.hparams["num_attention_heads"]
|
||||
num_key_value_heads = self.hparams["num_key_value_heads"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
|
|
@ -6123,7 +6095,7 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
|||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
|
||||
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
|
||||
|
|
@ -6164,7 +6136,6 @@ class Rwkv7Model(TextModel):
|
|||
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
try:
|
||||
head_size = self.hparams["head_size"]
|
||||
layer_norm_eps = self.hparams["layer_norm_epsilon"]
|
||||
|
|
@ -6189,7 +6160,7 @@ class Rwkv7Model(TextModel):
|
|||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
|
|
@ -6283,7 +6254,6 @@ class ARwkv7Model(Rwkv7Model):
|
|||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
head_size = self.hparams["head_size"]
|
||||
rms_norm_eps = self.hparams["rms_norm_eps"]
|
||||
|
|
@ -6300,7 +6270,7 @@ class ARwkv7Model(Rwkv7Model):
|
|||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
|
||||
|
|
@ -7524,7 +7494,7 @@ class T5Model(TextModel):
|
|||
self.gguf_writer.add_context_length(n_ctx)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
|
||||
self.gguf_writer.add_decoder_block_count(dec_n_layer)
|
||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||
|
|
@ -7663,7 +7633,7 @@ class T5EncoderModel(TextModel):
|
|||
self.gguf_writer.add_context_length(n_ctx)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
|
||||
self.gguf_writer.add_block_count(self.hparams["num_layers"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(self.hparams["num_heads"])
|
||||
self.gguf_writer.add_key_length(self.hparams["d_kv"])
|
||||
self.gguf_writer.add_value_length(self.hparams["d_kv"])
|
||||
|
|
@ -7726,7 +7696,7 @@ class JaisModel(TextModel):
|
|||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
|
||||
|
|
@ -8068,7 +8038,7 @@ class ChatGLMModel(TextModel):
|
|||
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
|
||||
self.gguf_writer.add_embedding_length(n_embed)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
|
||||
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_head_count(n_head)
|
||||
self.gguf_writer.add_head_count_kv(n_head_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
|
||||
|
|
@ -8150,7 +8120,6 @@ class ExaoneModel(TextModel):
|
|||
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
|
||||
layer_norm_eps = hparams["layer_norm_epsilon"]
|
||||
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
|
||||
num_layers = hparams["num_layers"]
|
||||
# ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0
|
||||
# attention_dropout_rate = hparams["attention_dropout"]
|
||||
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
|
||||
|
|
@ -8161,7 +8130,7 @@ class ExaoneModel(TextModel):
|
|||
self.gguf_writer.add_context_length(max_position_embeddings)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_block_count(num_layers)
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
if (rope_theta := self.hparams.get("rope_theta")) is not None:
|
||||
|
|
|
|||
|
|
@ -392,9 +392,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||
|
||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10)
|
||||
elseif (EXTRACTED_NUMBER EQUAL 9)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
||||
else()
|
||||
|
|
|
|||
|
|
@ -384,7 +384,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;
|
||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
|
||||
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
|
||||
|
||||
if (src0->type == src1->type && contiguous_srcs) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
|
|
|
|||
|
|
@ -3001,6 +3001,10 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
|||
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
const ggml_tensor * view,
|
||||
const ggml_tensor * set_rows) {
|
||||
|
||||
if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
|
||||
return false;
|
||||
}
|
||||
// ne3 not tested
|
||||
if (rope->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -2901,15 +2901,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
|
|
|
|||
|
|
@ -0,0 +1,106 @@
|
|||
[gMASK]<sop>
|
||||
{%- if tools -%}
|
||||
<|system|>
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{% for tool in tools %}
|
||||
{{ tool | tojson(ensure_ascii=False) }}
|
||||
{% endfor %}
|
||||
</tools>
|
||||
|
||||
For each function call, output the function name and arguments within the following XML format:
|
||||
<tool_call>{function-name}
|
||||
<arg_key>{arg-key-1}</arg_key>
|
||||
<arg_value>{arg-value-1}</arg_value>
|
||||
<arg_key>{arg-key-2}</arg_key>
|
||||
<arg_value>{arg-value-2}</arg_value>
|
||||
...
|
||||
</tool_call>{%- endif -%}
|
||||
{%- macro visible_text(content) -%}
|
||||
{%- if content is string -%}
|
||||
{{- content }}
|
||||
{%- elif content is iterable and content is not mapping -%}
|
||||
{%- for item in content -%}
|
||||
{%- if item is mapping and item.type == 'text' -%}
|
||||
{{- item.text }}
|
||||
{%- elif item is string -%}
|
||||
{{- item }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{- content }}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{%- set ns = namespace(last_user_index=-1) %}
|
||||
{%- for m in messages %}
|
||||
{%- if m.role == 'user' %}
|
||||
{% set ns.last_user_index = loop.index0 -%}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{% for m in messages %}
|
||||
{%- if m.role == 'user' -%}<|user|>
|
||||
{{ visible_text(m.content) }}
|
||||
{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}
|
||||
{%- elif m.role == 'assistant' -%}
|
||||
<|assistant|>
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- set content = visible_text(m.content) %}
|
||||
{%- if m.reasoning_content is string %}
|
||||
{%- set reasoning_content = m.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if loop.index0 > ns.last_user_index and reasoning_content -%}
|
||||
{{ '\n<think>' + reasoning_content.strip() + '</think>'}}
|
||||
{%- else -%}
|
||||
{{ '\n<think></think>' }}
|
||||
{%- endif -%}
|
||||
{%- if content.strip() -%}
|
||||
{{ '\n' + content.strip() }}
|
||||
{%- endif -%}
|
||||
{% if m.tool_calls %}
|
||||
{% for tc in m.tool_calls %}
|
||||
{%- if tc.function %}
|
||||
{%- set tc = tc.function %}
|
||||
{%- endif %}
|
||||
{{ '\n<tool_call>' + tc.name }}
|
||||
{% set _args = tc.arguments or {} %}
|
||||
{% if _args is not mapping %}
|
||||
{{ raise_exception("Invalid tool call arguments passed: " + _args | string) }}
|
||||
{% endif %}
|
||||
{% for k, v in _args.items() %}
|
||||
<arg_key>{{ k }}</arg_key>
|
||||
<arg_value>{{ v | tojson(ensure_ascii=False) if v is not string else v }}</arg_value>
|
||||
{% endfor %}
|
||||
</tool_call>{% endfor %}
|
||||
{% endif %}
|
||||
{%- elif m.role == 'tool' -%}
|
||||
{%- if m.content is string -%}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|observation|>' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- m.content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- else -%}
|
||||
<|observation|>{% for tr in m.content %}
|
||||
|
||||
<tool_response>
|
||||
{{ tr.output if tr.output is defined else tr }}
|
||||
</tool_response>{% endfor -%}
|
||||
{% endif -%}
|
||||
{%- elif m.role == 'system' -%}
|
||||
<|system|>
|
||||
{{ visible_text(m.content) }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
<|assistant|>{{- '\n<think></think>' if (enable_thinking is defined and not enable_thinking) else '' -}}
|
||||
{%- endif -%}
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
{% macro render_content(msg) -%}
|
||||
{%- set c = msg.get('content') -%}
|
||||
{%- if c is string -%}
|
||||
{{ c }}
|
||||
{%- elif c is not none -%}
|
||||
{% for content in c -%}
|
||||
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
|
||||
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
|
||||
{% else -%}
|
||||
{{ content['text'] }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- endmacro %}
|
||||
|
||||
{%- set tool_response_queue = namespace(ids=[]) -%}
|
||||
{%- set tool_call_counter = namespace(value=1) -%}
|
||||
|
||||
{%- if tools -%}
|
||||
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|>
|
||||
{%- endif -%}
|
||||
{% for message in messages %}
|
||||
{%- if loop.first and messages[0]['role'] != 'system' -%}
|
||||
<|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>
|
||||
{% endif %}
|
||||
|
||||
{%- set role_name = message.get('name') or message['role'] -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
<|im_user|>{{role_name}}<|im_middle|>
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
<|im_assistant|>{{role_name}}<|im_middle|>
|
||||
{%- else -%}
|
||||
<|im_system|>{{role_name}}<|im_middle|>
|
||||
{%- endif -%}
|
||||
|
||||
{%- if message['role'] == 'assistant' and message.get('tool_calls') -%}
|
||||
{{render_content(message)}}<|tool_calls_section_begin|>
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- if tool_call['id'] is defined -%}
|
||||
{%- set formatted_id = tool_call['id'] -%}
|
||||
{%- else -%}
|
||||
{%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%}
|
||||
{%- set tool_call_counter.value = tool_call_counter.value + 1 -%}
|
||||
{%- endif -%}
|
||||
{%- set _ = tool_response_queue.ids.append(formatted_id) -%}
|
||||
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
|
||||
{%- endfor -%}
|
||||
<|tool_calls_section_end|>
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
{%- if tool_response_queue.ids -%}
|
||||
{%- set tool_call_id = tool_response_queue.ids.pop(0) -%}
|
||||
{%- else -%}
|
||||
{%- set tool_call_id = 'functions.' + message.get('name', 'unknown') + ':' + (tool_call_counter.value | string) -%}
|
||||
{%- endif -%}
|
||||
## Return of {{ tool_call_id }}
|
||||
{{render_content(message)}}
|
||||
{%- elif message['content'] is not none -%}
|
||||
{{render_content(message)}}
|
||||
{%- endif -%}
|
||||
<|im_end|>
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
<|im_assistant|>assistant<|im_middle|>
|
||||
{%- endif -%}
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
{%- macro render_content(msg) -%}
|
||||
{%- set c = msg.get('content') -%}
|
||||
{%- if c is string -%}
|
||||
{{ c }}
|
||||
{%- elif c is not none -%}
|
||||
{% for content in c -%}
|
||||
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%}
|
||||
<|media_start|>image<|media_content|><|media_pad|><|media_end|>
|
||||
{% else -%}
|
||||
{{ content['text'] }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
|
||||
{% macro set_roles(message) -%}
|
||||
{%- set role_name = message.get('name') or message['role'] -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
<|im_user|>{{role_name}}<|im_middle|>
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
<|im_assistant|>{{role_name}}<|im_middle|>
|
||||
{%- else -%}
|
||||
<|im_system|>{{role_name}}<|im_middle|>
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
|
||||
{%- set tool_response_queue = namespace(ids=[]) -%}
|
||||
{%- set tool_call_counter = namespace(value=1) -%}
|
||||
|
||||
{%- macro render_toolcalls(message) -%}
|
||||
<|tool_calls_section_begin|>
|
||||
{%- for tool_call in message['tool_calls'] -%}
|
||||
{%- if tool_call['id'] is defined -%}
|
||||
{%- set formatted_id = tool_call['id'] -%}
|
||||
{%- else -%}
|
||||
{%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%}
|
||||
{%- set tool_call_counter.value = tool_call_counter.value + 1 -%}
|
||||
{%- endif -%}
|
||||
{%- set _ = tool_response_queue.ids.append(formatted_id) -%}
|
||||
<|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|>
|
||||
{%- endfor -%}
|
||||
<|tool_calls_section_end|>
|
||||
{%- endmacro -%}
|
||||
|
||||
|
||||
{# Find last non-tool-call assisitant message #}
|
||||
{%- set ns = namespace(last_non_tool_call_assistant_msg=-1) -%}
|
||||
{%- for idx in range(messages|length-1, -1, -1) -%}
|
||||
{%- if messages[idx]['role'] == 'assistant' and not messages[idx].get('tool_calls') -%}
|
||||
{%- set ns.last_non_tool_call_assistant_msg = idx -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{# split all messages into history & suffix, reasoning_content in suffix should be reserved.#}
|
||||
{%- set hist_msgs = messages[:ns.last_non_tool_call_assistant_msg+1] -%}
|
||||
{%- set suffix_msgs = messages[ns.last_non_tool_call_assistant_msg+1:] -%}
|
||||
|
||||
{%- if tools -%}
|
||||
<|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|>
|
||||
{%- endif -%}
|
||||
|
||||
{%- if messages|length == 0 or messages[0]['role'] != 'system' -%}
|
||||
<|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|>
|
||||
{%- endif -%}
|
||||
|
||||
{%- for message in hist_msgs -%}
|
||||
{{set_roles(message)}}
|
||||
{%- if message['role'] == 'assistant' -%}
|
||||
<think></think>{{render_content(message)}}
|
||||
{%- if message.get('tool_calls') -%}
|
||||
{{render_toolcalls(message)}}
|
||||
{%- endif -%}
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
{%- if tool_response_queue.ids -%}
|
||||
{%- set tool_call_id = tool_response_queue.ids.pop(0) -%}
|
||||
{%- else -%}
|
||||
{%- set tool_call_id = 'functions.' + message.get('name', 'unknown') + ':' + (tool_call_counter.value | string) -%}
|
||||
{%- endif -%}
|
||||
## Return of {{ tool_call_id }}
|
||||
{{render_content(message)}}
|
||||
{%- elif message['content'] is not none -%}
|
||||
{{render_content(message)}}
|
||||
{%- endif -%}
|
||||
<|im_end|>
|
||||
{%- endfor -%}
|
||||
|
||||
{%- for message in suffix_msgs -%}
|
||||
{{set_roles(message)}}
|
||||
{%- if message['role'] == 'assistant' -%}
|
||||
{%- set rc = message.get('reasoning_content', '') -%}
|
||||
<think>{{rc}}</think>{{render_content(message)}}
|
||||
{%- if message.get('tool_calls') -%}
|
||||
{{render_toolcalls(message)}}
|
||||
{%- endif -%}
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
{%- if tool_response_queue.ids -%}
|
||||
{%- set tool_call_id = tool_response_queue.ids.pop(0) -%}
|
||||
{%- else -%}
|
||||
{%- set tool_call_id = 'functions.' + message.get('name', 'unknown') + ':' + (tool_call_counter.value | string) -%}
|
||||
{%- endif -%}
|
||||
## Return of {{ tool_call_id }}
|
||||
{{render_content(message)}}
|
||||
{%- elif message['content'] is not none -%}
|
||||
{{render_content(message)}}
|
||||
{%- endif -%}
|
||||
<|im_end|>
|
||||
{%- endfor -%}
|
||||
|
||||
|
||||
{%- if add_generation_prompt -%}
|
||||
<|im_assistant|>assistant<|im_middle|>
|
||||
{%- endif -%}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
{%- if tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{{- messages[0]['content'] }}
|
||||
{%- else %}
|
||||
{{- 'You are MiMo, an AI assistant developed by Xiaomi.' }}
|
||||
{%- endif %}
|
||||
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>system\nYou are MiMo, an AI assistant developed by Xiaomi.<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{{- '<|im_start|>' + message.role }}
|
||||
{%- if message.content %}
|
||||
{{- '\n' + message.content }}
|
||||
{%- endif %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
{# ----------‑‑‑ special token variables ‑‑‑---------- #}
|
||||
{%- set toolcall_begin_token = '<minimax:tool_call>' -%}
|
||||
{%- set toolcall_end_token = '</minimax:tool_call>' -%}
|
||||
{#- Tool Rendering Functions ============================================== -#}
|
||||
{%- macro render_tool_namespace(namespace_name, tool_list) -%}
|
||||
{%- for tool in tool_list -%}
|
||||
<tool>{{ tool.function | tojson(ensure_ascii=False) }}</tool>
|
||||
{% endfor -%}
|
||||
{%- endmacro -%}
|
||||
{%- macro visible_text(content) -%}
|
||||
{%- if content is string -%}
|
||||
{{ content }}
|
||||
{%- elif content is iterable and content is not mapping -%}
|
||||
{%- for item in content -%}
|
||||
{%- if item is mapping and item.type == 'text' -%}
|
||||
{{- item.text }}
|
||||
{%- elif item is string -%}
|
||||
{{- item }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{- content }}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{#- System Message Construction ============================================ -#}
|
||||
{%- macro build_system_message(system_message) -%}
|
||||
{%- if system_message and system_message.content -%}
|
||||
{{- visible_text(system_message.content) }}
|
||||
{%- else -%}
|
||||
{%- if model_identity is not defined -%}
|
||||
{%- set model_identity = "You are a helpful assistant." -%}
|
||||
{%- endif -%}
|
||||
{{- model_identity }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Handle current_date -#}
|
||||
{%- if system_message and system_message.current_date -%}
|
||||
{{- '\n' ~ 'Current date: ' + system_message.current_date }}
|
||||
{%- endif -%}
|
||||
{#- Handle current_location -#}
|
||||
{%- if system_message and system_message.current_location -%}
|
||||
{{- '\n' ~ 'Current location: ' + system_message.current_location }}
|
||||
{%- endif -%}
|
||||
{%- endmacro -%}
|
||||
{#- Main Template Logic ================================================= -#}
|
||||
{#- Extract system message (only first message if it's system) -#}
|
||||
{%- set system_message = none -%}
|
||||
{%- set conversation_messages = messages -%}
|
||||
{%- if messages and messages[0].role == "system" -%}
|
||||
{%- set system_message = messages[0] -%}
|
||||
{%- set conversation_messages = messages[1:] -%}
|
||||
{%- endif -%}
|
||||
{#- Get the last user message turn, for interleved thinking -#}
|
||||
{%- set ns = namespace(last_user_index=-1) %}
|
||||
{% for m in conversation_messages %}
|
||||
{%- if m.role == 'user' %}
|
||||
{% set ns.last_user_index = loop.index0 -%}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{#- Render system message -#}
|
||||
{{- ']~!b[' ~ ']~b]system' ~ '\n' }}
|
||||
{{- build_system_message(system_message) }}
|
||||
{#- Render tools if available -#}
|
||||
{%- if tools -%}
|
||||
{{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }}
|
||||
{{- '\n' ~ '<tools>' ~ '\n' }}
|
||||
{{- render_tool_namespace("functions", tools) }}
|
||||
{{- '</tools>' ~ '\n\n' }}
|
||||
{{- 'When making tool calls, use XML format to invoke tools and pass parameters:' ~ '\n' }}
|
||||
{{- '\n' ~ toolcall_begin_token }}
|
||||
<invoke name="tool-name-1">
|
||||
<parameter name="param-key-1">param-value-1</parameter>
|
||||
<parameter name="param-key-2">param-value-2</parameter>
|
||||
...
|
||||
</invoke>
|
||||
{{- '\n' ~ toolcall_end_token }}
|
||||
{%- endif -%}
|
||||
{{- '[e~[\n' }}
|
||||
|
||||
{#- Render messages -#}
|
||||
{%- set last_tool_call = namespace(name=none) -%}
|
||||
{%- for message in conversation_messages -%}
|
||||
{%- if message.role == 'assistant' -%}
|
||||
{#- Only render reasoning_content if no user message follows -#}
|
||||
{{- ']~b]ai' ~ '\n' }}
|
||||
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- set content = visible_text(message.content) %}
|
||||
{%- if message.reasoning_content is string %}
|
||||
{%- set reasoning_content = message.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set reasoning_content = content.split('</think>')[0].strip('\n').split('<think>')[-1].strip('\n') %}
|
||||
{%- set content = content.split('</think>')[-1].strip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if reasoning_content and loop.index0 > ns.last_user_index -%}
|
||||
{{- '<think>' ~ '\n' ~ reasoning_content ~ '\n' ~ '</think>' ~ '\n\n' }}
|
||||
{%- endif -%}
|
||||
{%- if content -%}
|
||||
{{- content }}
|
||||
{%- endif -%}
|
||||
{%- if message.tool_calls -%}
|
||||
{{- '\n' ~ toolcall_begin_token ~ '\n' }}
|
||||
|
||||
{%- for tool_call in message.tool_calls -%}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<invoke name="' + tool_call.name + '">' }}
|
||||
{% set _args = tool_call.arguments %}
|
||||
{%- for k, v in _args.items() %}
|
||||
{{- '<parameter name="' + k + '">' }}
|
||||
{{- v | tojson(ensure_ascii=False) if v is not string else v }}
|
||||
{{- '</parameter>' }}
|
||||
{% endfor %}
|
||||
{{- '</invoke>' ~ '\n' }}
|
||||
{%- endfor -%}
|
||||
|
||||
{{- toolcall_end_token}}
|
||||
{%- set last_tool_call.name = message.tool_calls[-1].function.name -%}
|
||||
{%- else -%}
|
||||
{%- set last_tool_call.name = none -%}
|
||||
{%- endif -%}
|
||||
{{- '[e~[' ~ '\n' }}
|
||||
|
||||
{%- elif message.role == 'tool' -%}
|
||||
{%- if last_tool_call.name is none -%}
|
||||
{{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
|
||||
{%- endif -%}
|
||||
{%- if loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%}
|
||||
{{- ']~b]tool' }}
|
||||
{%- endif -%}
|
||||
{%- if message.content is string -%}
|
||||
{{- '\n<response>' }}
|
||||
{{- message.content }}
|
||||
{{- '</response>' }}
|
||||
{%- else -%}
|
||||
{%- for tr in message.content -%}
|
||||
{{- '\n<response>' }}
|
||||
{{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }}
|
||||
{{- '\n</response>' }}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%}
|
||||
{{- '[e~[\n' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{%- elif message.role == 'user' -%}
|
||||
{{- ']~b]user' ~ '\n' }}
|
||||
{{- visible_text(message.content) }}
|
||||
{{- '[e~[' ~ '\n' }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Generation prompt -#}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- ']~b]ai' ~ '\n' ~ '<think>' ~ '\n' }}
|
||||
{%- endif -%}
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
{% macro render_extra_keys(json_dict, handled_keys) %}
|
||||
{%- if json_dict is mapping %}
|
||||
{%- for json_key in json_dict if json_key not in handled_keys %}
|
||||
{%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
|
||||
{{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
|
||||
{%- else %}
|
||||
{{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{% endmacro %}
|
||||
|
||||
{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = [] %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if system_message is defined %}
|
||||
{{- "<|im_start|>system\n" + system_message }}
|
||||
{%- else %}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }}
|
||||
{{- "<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{%- if tool.function is defined %}
|
||||
{%- set tool = tool.function %}
|
||||
{%- endif %}
|
||||
{{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
|
||||
{%- if tool.description is defined %}
|
||||
{{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
|
||||
{%- endif %}
|
||||
{{- '\n<parameters>' }}
|
||||
{%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
|
||||
{%- for param_name, param_fields in tool.parameters.properties|items %}
|
||||
{{- '\n<parameter>' }}
|
||||
{{- '\n<name>' ~ param_name ~ '</name>' }}
|
||||
{%- if param_fields.type is defined %}
|
||||
{{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
|
||||
{%- endif %}
|
||||
{%- if param_fields.description is defined %}
|
||||
{{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
|
||||
{%- endif %}
|
||||
{%- set handled_keys = ['name', 'type', 'description'] %}
|
||||
{{- render_extra_keys(param_fields, handled_keys) }}
|
||||
{{- '\n</parameter>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{% set handled_keys = ['type', 'properties'] %}
|
||||
{{- render_extra_keys(tool.parameters, handled_keys) }}
|
||||
{{- '\n</parameters>' }}
|
||||
{%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
|
||||
{{- render_extra_keys(tool, handled_keys) }}
|
||||
{{- '\n</function>' }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>" }}
|
||||
{{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
|
||||
{%- endif %}
|
||||
{%- if system_message is defined %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- else %}
|
||||
{%- if tools is iterable and tools | length > 0 %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
|
||||
{{- '<|im_start|>' + message.role }}
|
||||
{%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
|
||||
{{- '\n' + message.content | trim + '\n' }}
|
||||
{%- endif %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
|
||||
{%- if tool_call.arguments is defined %}
|
||||
{%- for args_name, args_value in tool_call.arguments|items %}
|
||||
{{- '<parameter=' + args_name + '>\n' }}
|
||||
{%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
|
||||
{{- args_value }}
|
||||
{{- '\n</parameter>\n' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '</function>\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.previtem and loop.previtem.role != "tool" %}
|
||||
{{- '<|im_start|>user\n' }}
|
||||
{%- endif %}
|
||||
{{- '<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{{- '\n</tool_response>\n' }}
|
||||
{%- if not loop.last and loop.nextitem.role != "tool" %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif loop.last %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
{# Unsloth template fixes #}
|
||||
{%- set available_tools_string = '' -%}
|
||||
{%- set thought_instructions = '' -%}
|
||||
{%- set add_tool_id = true -%}
|
||||
{%- set tool_output_format = "default" -%}
|
||||
{%- if tools is not none and tools|length > 0 -%}
|
||||
{%- set available_tools_string -%}
|
||||
You are provided with function signatures within <available_tools></available_tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about the arguments. You should infer the argument values from previous user responses and the system message. Here are the available tools:
|
||||
<available_tools>
|
||||
{% for tool in tools %}
|
||||
{{ tool|string }}
|
||||
{% endfor %}
|
||||
</available_tools>
|
||||
{%- endset -%}
|
||||
{%- endif -%}
|
||||
{%- if tool_output_format is none or tool_output_format == "default" -%}
|
||||
{%- set tool_output_instructions -%}
|
||||
Return all function calls as a list of json objects within <tool_call></tool_call> XML tags. Each json object should contain a function name and arguments as follows:
|
||||
<tool_calls>[{"name": <function-name-1>, "arguments": <args-dict-1>}, {"name": <function-name-2>, "arguments": <args-dict-2>},...]</tool_calls>
|
||||
{%- endset -%}
|
||||
{%- elif tool_output_format == "yaml" -%}
|
||||
{%- set tool_output_instructions -%}
|
||||
Return all function calls as a list of yaml objects within <tool_call></tool_call> XML tags. Each yaml object should contain a function name and arguments as follows:
|
||||
<tool_calls>
|
||||
- name: <function-name-1>
|
||||
arguments: <args-dict-1>
|
||||
- name: <function-name-2>
|
||||
arguments: <args-dict-2>
|
||||
...
|
||||
</tool_calls>
|
||||
{%- endset -%}
|
||||
{%- endif -%}
|
||||
{%- if add_thoughts -%}
|
||||
{%- set thought_instructions -%}
|
||||
Prior to generating the function calls, you should generate the reasoning for why you're calling the function. Please generate these reasoning thoughts between <thinking> and </thinking> XML tags.
|
||||
{%- endset -%}
|
||||
{%- endif -%}
|
||||
{{- bos_token -}}
|
||||
{%- set reasoning_prompt='You are a thoughtful and systematic AI assistant built by ServiceNow Language Models (SLAM) lab. Before providing an answer, analyze the problem carefully and present your reasoning step by step. After explaining your thought process, provide the final solution in the following format: [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE].' -%}
|
||||
{%- if messages[0]['role'] != 'system' and tools is not none and tools|length > 0 -%}
|
||||
{{- '<|system|>\n' + reasoning_prompt + available_tools_string + "\n" + tool_output_instructions + '\n<|end|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- if messages|selectattr('role', 'equalto', 'system')|list|length == 0 -%}
|
||||
{{- '<|system|>\n' + reasoning_prompt + '\n<|end|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- '<|user|>\n' }}
|
||||
{%- if message['content'] is not string %}
|
||||
{%- for chunk in message['content'] %}
|
||||
{%- if chunk['type'] == 'text' %}
|
||||
{{- chunk['text'] }}
|
||||
{%- elif chunk['type'] == 'image' or chunk['type'] == 'image_url'%}
|
||||
{{- '[IMG]' }}
|
||||
{%- else %}
|
||||
{{- raise_exception('Unrecognized content type!') }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- else %}
|
||||
{{- message['content'] }}
|
||||
{%- endif %}
|
||||
{{- '\n<|end|>\n' }}
|
||||
{%- elif message['role'] == 'content' -%}
|
||||
{%- if message['content'] is not string %}
|
||||
{{- '<|content|>\n' + message['content'][0]['text'] + '\n<|end|>\n' -}}
|
||||
{%- else %}
|
||||
{{- '<|content|>\n' + message['content'] + '\n<|end|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['role'] == 'system' -%}
|
||||
{%- if message['content'] is not none and message['content']|length > 0 %}
|
||||
{%- if message['content'] is string %}
|
||||
{%- set system_message = message['content'] %}
|
||||
{%- else %}
|
||||
{%- set system_message = message['content'][0]['text'] %}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{%- set system_message = '' %}
|
||||
{%- endif %}
|
||||
{%- if tools is not none and tools|length > 0 -%}
|
||||
{{- '<|system|>\n' + reasoning_prompt + system_message + '\n' + available_tools_string + '\n<|end|>\n' -}}
|
||||
{%- else -%}
|
||||
{{- '<|system|>\n' + reasoning_prompt + system_message + '\n<|end|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{%- if loop.last -%}
|
||||
{%- set add_tool_id = false -%}
|
||||
{%- endif -%}
|
||||
{{- '<|assistant|>\n' -}}
|
||||
{%- if message['content'] is not none and message['content']|length > 0 -%}
|
||||
{%- if message['content'] is not string and message['content'][0]['text'] is not none %}
|
||||
{{- message['content'][0]['text'] }}
|
||||
{%- else %}
|
||||
{{- message['content'] -}}
|
||||
{%- endif -%}
|
||||
{%- elif message['chosen'] is not none and message['chosen']|length > 0 -%}
|
||||
{{- message['chosen'][0] -}}
|
||||
{%- endif -%}
|
||||
{%- if add_thoughts and 'thought' in message and message['thought'] is not none -%}
|
||||
{{- '<thinking>' + message['thought'] + '</thinking>' -}}
|
||||
{%- endif -%}
|
||||
{%- if message['tool_calls'] is not none and message['tool_calls']|length > 0 -%}
|
||||
{{- '\n<tool_calls>[' -}}
|
||||
{%- for tool_call in message["tool_calls"] -%}
|
||||
{{- '{"name": "' + tool_call['function']['name'] + '", "arguments": ' + tool_call['function']['arguments']|string -}}
|
||||
{%- if add_tool_id == true -%}
|
||||
{{- ', "id": "' + tool_call['id'] + '"' -}}
|
||||
{%- endif -%}
|
||||
{{- '}' -}}
|
||||
{%- if not loop.last -%}{{- ', ' -}}{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{{- ']</tool_calls>' -}}
|
||||
{%- endif -%}
|
||||
{{- '\n<|end|>\n' + eos_token -}}
|
||||
{%- elif message['role'] == 'tool' -%}
|
||||
{%- if message['content'] is string %}
|
||||
{%- set tool_message = message['content'] %}
|
||||
{%- else %}
|
||||
{%- set tool_message = message['content'][0]['text'] %}
|
||||
{%- endif -%}
|
||||
{{- '<|tool_result|>\n' + tool_message|string + '\n<|end|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- if loop.last and add_generation_prompt and message['role'] != 'assistant' -%}
|
||||
{{- '<|assistant|>\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{# Copyright 2025-present Unsloth. Apache 2.0 License. #}
|
||||
|
|
@ -1281,6 +1281,7 @@ struct llm_tokenizer_plamo2 : llm_tokenizer {
|
|||
|
||||
// Build suffix list in lexicographical order of reversed strings
|
||||
std::vector<std::string> suffixes;
|
||||
suffixes.reserve(suffix_to_score.size() + 1);
|
||||
for (const auto & pair : suffix_to_score) {
|
||||
suffixes.push_back(pair.first);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2776,24 +2776,34 @@ struct test_cpy : public test_case {
|
|||
struct test_cont : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
bool use_view_slice;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
return VARS_TO_STR3(type, ne, use_view_slice);
|
||||
}
|
||||
|
||||
test_cont(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 1})
|
||||
: type(type), ne(ne) {}
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 1},
|
||||
bool use_view_slice = false)
|
||||
: type(type), ne(ne), use_view_slice(use_view_slice) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_param(src);
|
||||
ggml_set_name(src, "src");
|
||||
|
||||
src = ggml_transpose(ctx, src);
|
||||
ggml_set_name(src, "src_transposed");
|
||||
|
||||
ggml_tensor * out = ggml_cont(ctx, src);
|
||||
ggml_tensor * dst;
|
||||
if (use_view_slice) {
|
||||
dst = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3],
|
||||
src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1));
|
||||
ggml_set_name(dst, "src_view_slice");
|
||||
} else {
|
||||
dst = ggml_transpose(ctx, src);
|
||||
ggml_set_name(dst, "src_transposed");
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_cont(ctx, dst);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
|
|
@ -6945,16 +6955,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
|
||||
|
||||
test_cases.emplace_back(new test_cont());
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 3 ,5}));
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 3, 5 ,7}));
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 1 ,1}));
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 3 ,5}));
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 3, 5 ,7}));
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 1 ,1}));
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5}));
|
||||
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
|
||||
for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
|
||||
for (bool use_view_slice : { true, false }) {
|
||||
for (std::array<int64_t, 4> ne : std::initializer_list<std::array<int64_t, 4>>{ {2, 1, 1, 1}, {2, 1, 3, 5},
|
||||
{2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) {
|
||||
if (use_view_slice && (type_dst == GGML_TYPE_F16 || type_dst == GGML_TYPE_BF16)) {
|
||||
continue; // TODO: add after WebGPU is fixed
|
||||
}
|
||||
test_cases.emplace_back(new test_cont(type_dst, ne, use_view_slice));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
|
||||
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
|
||||
|
|
|
|||
1052
tests/test-chat.cpp
1052
tests/test-chat.cpp
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
|
@ -25,3 +25,4 @@ vite.config.ts.timestamp-*
|
|||
|
||||
*storybook.log
|
||||
storybook-static
|
||||
*.code-workspace
|
||||
|
|
@ -2109,9 +2109,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@sveltejs/kit": {
|
||||
"version": "2.48.4",
|
||||
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.4.tgz",
|
||||
"integrity": "sha512-TGFX1pZUt9qqY20Cv5NyYvy0iLWHf2jXi8s+eCGsig7jQMdwZWKUFMR6TbvFNhfDSUpc1sH/Y5EHv20g3HHA3g==",
|
||||
"version": "2.48.5",
|
||||
"resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz",
|
||||
"integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
@ -5087,9 +5087,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/js-yaml": {
|
||||
"version": "4.1.0",
|
||||
"resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz",
|
||||
"integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==",
|
||||
"version": "4.1.1",
|
||||
"resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz",
|
||||
"integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
class?: string;
|
||||
message: DatabaseMessage;
|
||||
onCopy?: (message: DatabaseMessage) => void;
|
||||
onContinueAssistantMessage?: (message: DatabaseMessage) => void;
|
||||
onDelete?: (message: DatabaseMessage) => void;
|
||||
onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void;
|
||||
onEditWithReplacement?: (
|
||||
|
|
@ -17,6 +18,7 @@
|
|||
newContent: string,
|
||||
shouldBranch: boolean
|
||||
) => void;
|
||||
onEditUserMessagePreserveResponses?: (message: DatabaseMessage, newContent: string) => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
onRegenerateWithBranching?: (message: DatabaseMessage) => void;
|
||||
siblingInfo?: ChatMessageSiblingInfo | null;
|
||||
|
|
@ -26,9 +28,11 @@
|
|||
class: className = '',
|
||||
message,
|
||||
onCopy,
|
||||
onContinueAssistantMessage,
|
||||
onDelete,
|
||||
onEditWithBranching,
|
||||
onEditWithReplacement,
|
||||
onEditUserMessagePreserveResponses,
|
||||
onNavigateToSibling,
|
||||
onRegenerateWithBranching,
|
||||
siblingInfo = null
|
||||
|
|
@ -133,17 +137,33 @@
|
|||
onRegenerateWithBranching?.(message);
|
||||
}
|
||||
|
||||
function handleContinue() {
|
||||
onContinueAssistantMessage?.(message);
|
||||
}
|
||||
|
||||
function handleSaveEdit() {
|
||||
if (message.role === 'user') {
|
||||
// For user messages, trim to avoid accidental whitespace
|
||||
onEditWithBranching?.(message, editedContent.trim());
|
||||
} else {
|
||||
onEditWithReplacement?.(message, editedContent.trim(), shouldBranchAfterEdit);
|
||||
// For assistant messages, preserve exact content including trailing whitespace
|
||||
// This is important for the Continue feature to work properly
|
||||
onEditWithReplacement?.(message, editedContent, shouldBranchAfterEdit);
|
||||
}
|
||||
|
||||
isEditing = false;
|
||||
shouldBranchAfterEdit = false;
|
||||
}
|
||||
|
||||
function handleSaveEditOnly() {
|
||||
if (message.role === 'user') {
|
||||
// For user messages, trim to avoid accidental whitespace
|
||||
onEditUserMessagePreserveResponses?.(message, editedContent.trim());
|
||||
}
|
||||
|
||||
isEditing = false;
|
||||
}
|
||||
|
||||
function handleShowDeleteDialogChange(show: boolean) {
|
||||
showDeleteDialog = show;
|
||||
}
|
||||
|
|
@ -166,6 +186,7 @@
|
|||
onEditedContentChange={handleEditedContentChange}
|
||||
{onNavigateToSibling}
|
||||
onSaveEdit={handleSaveEdit}
|
||||
onSaveEditOnly={handleSaveEditOnly}
|
||||
onShowDeleteDialogChange={handleShowDeleteDialogChange}
|
||||
{showDeleteDialog}
|
||||
{siblingInfo}
|
||||
|
|
@ -181,6 +202,7 @@
|
|||
messageContent={message.content}
|
||||
onCancelEdit={handleCancelEdit}
|
||||
onConfirmDelete={handleConfirmDelete}
|
||||
onContinue={handleContinue}
|
||||
onCopy={handleCopy}
|
||||
onDelete={handleDelete}
|
||||
onEdit={handleEdit}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { Edit, Copy, RefreshCw, Trash2 } from '@lucide/svelte';
|
||||
import { Edit, Copy, RefreshCw, Trash2, ArrowRight } from '@lucide/svelte';
|
||||
import { ActionButton, ConfirmationDialog } from '$lib/components/app';
|
||||
import ChatMessageBranchingControls from './ChatMessageBranchingControls.svelte';
|
||||
|
||||
|
|
@ -18,6 +18,7 @@
|
|||
onCopy: () => void;
|
||||
onEdit?: () => void;
|
||||
onRegenerate?: () => void;
|
||||
onContinue?: () => void;
|
||||
onDelete: () => void;
|
||||
onConfirmDelete: () => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
|
|
@ -31,6 +32,7 @@
|
|||
onCopy,
|
||||
onEdit,
|
||||
onConfirmDelete,
|
||||
onContinue,
|
||||
onDelete,
|
||||
onNavigateToSibling,
|
||||
onShowDeleteDialogChange,
|
||||
|
|
@ -69,6 +71,10 @@
|
|||
<ActionButton icon={RefreshCw} tooltip="Regenerate" onclick={onRegenerate} />
|
||||
{/if}
|
||||
|
||||
{#if role === 'assistant' && onContinue}
|
||||
<ActionButton icon={ArrowRight} tooltip="Continue" onclick={onContinue} />
|
||||
{/if}
|
||||
|
||||
<ActionButton icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import { ChatMessageThinkingBlock, MarkdownContent } from '$lib/components/app';
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
import { isLoading } from '$lib/stores/chat.svelte';
|
||||
import autoResizeTextarea from '$lib/utils/autoresize-textarea';
|
||||
import { fade } from 'svelte/transition';
|
||||
import {
|
||||
Check,
|
||||
|
|
@ -39,6 +40,7 @@
|
|||
onCancelEdit?: () => void;
|
||||
onCopy: () => void;
|
||||
onConfirmDelete: () => void;
|
||||
onContinue?: () => void;
|
||||
onDelete: () => void;
|
||||
onEdit?: () => void;
|
||||
onEditKeydown?: (event: KeyboardEvent) => void;
|
||||
|
|
@ -65,6 +67,7 @@
|
|||
messageContent,
|
||||
onCancelEdit,
|
||||
onConfirmDelete,
|
||||
onContinue,
|
||||
onCopy,
|
||||
onDelete,
|
||||
onEdit,
|
||||
|
|
@ -107,6 +110,12 @@
|
|||
void copyToClipboard(model ?? '');
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (isEditing && textareaElement) {
|
||||
autoResizeTextarea(textareaElement);
|
||||
}
|
||||
});
|
||||
|
||||
function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) {
|
||||
const callNumber = index + 1;
|
||||
const functionName = toolCall.function?.name?.trim();
|
||||
|
|
@ -190,7 +199,10 @@
|
|||
bind:value={editedContent}
|
||||
class="min-h-[50vh] w-full resize-y rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
|
||||
onkeydown={onEditKeydown}
|
||||
oninput={(e) => onEditedContentChange?.(e.currentTarget.value)}
|
||||
oninput={(e) => {
|
||||
autoResizeTextarea(e.currentTarget);
|
||||
onEditedContentChange?.(e.currentTarget.value);
|
||||
}}
|
||||
placeholder="Edit assistant message..."
|
||||
></textarea>
|
||||
|
||||
|
|
@ -335,6 +347,9 @@
|
|||
{onCopy}
|
||||
{onEdit}
|
||||
{onRegenerate}
|
||||
onContinue={currentConfig.enableContinueGeneration && !thinkingContent
|
||||
? onContinue
|
||||
: undefined}
|
||||
{onDelete}
|
||||
{onConfirmDelete}
|
||||
{onNavigateToSibling}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
<script lang="ts">
|
||||
import { Check, X } from '@lucide/svelte';
|
||||
import { Check, X, Send } from '@lucide/svelte';
|
||||
import { Card } from '$lib/components/ui/card';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { ChatAttachmentsList, MarkdownContent } from '$lib/components/app';
|
||||
import { INPUT_CLASSES } from '$lib/constants/input-classes';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import autoResizeTextarea from '$lib/utils/autoresize-textarea';
|
||||
import ChatMessageActions from './ChatMessageActions.svelte';
|
||||
|
||||
interface Props {
|
||||
|
|
@ -22,6 +23,7 @@
|
|||
} | null;
|
||||
onCancelEdit: () => void;
|
||||
onSaveEdit: () => void;
|
||||
onSaveEditOnly?: () => void;
|
||||
onEditKeydown: (event: KeyboardEvent) => void;
|
||||
onEditedContentChange: (content: string) => void;
|
||||
onCopy: () => void;
|
||||
|
|
@ -43,6 +45,7 @@
|
|||
deletionInfo,
|
||||
onCancelEdit,
|
||||
onSaveEdit,
|
||||
onSaveEditOnly,
|
||||
onEditKeydown,
|
||||
onEditedContentChange,
|
||||
onCopy,
|
||||
|
|
@ -58,6 +61,12 @@
|
|||
let messageElement: HTMLElement | undefined = $state();
|
||||
const currentConfig = config();
|
||||
|
||||
$effect(() => {
|
||||
if (isEditing && textareaElement) {
|
||||
autoResizeTextarea(textareaElement);
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (!messageElement || !message.content.trim()) return;
|
||||
|
||||
|
|
@ -95,20 +104,34 @@
|
|||
bind:value={editedContent}
|
||||
class="min-h-[60px] w-full resize-none rounded-2xl px-3 py-2 text-sm {INPUT_CLASSES}"
|
||||
onkeydown={onEditKeydown}
|
||||
oninput={(e) => onEditedContentChange(e.currentTarget.value)}
|
||||
oninput={(e) => {
|
||||
autoResizeTextarea(e.currentTarget);
|
||||
onEditedContentChange(e.currentTarget.value);
|
||||
}}
|
||||
placeholder="Edit your message..."
|
||||
></textarea>
|
||||
|
||||
<div class="mt-2 flex justify-end gap-2">
|
||||
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="outline">
|
||||
<Button class="h-8 px-3" onclick={onCancelEdit} size="sm" variant="ghost">
|
||||
<X class="mr-1 h-3 w-3" />
|
||||
|
||||
Cancel
|
||||
</Button>
|
||||
|
||||
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
|
||||
<Check class="mr-1 h-3 w-3" />
|
||||
{#if onSaveEditOnly}
|
||||
<Button
|
||||
class="h-8 px-3"
|
||||
onclick={onSaveEditOnly}
|
||||
disabled={!editedContent.trim()}
|
||||
size="sm"
|
||||
variant="outline"
|
||||
>
|
||||
<Check class="mr-1 h-3 w-3" />
|
||||
Save
|
||||
</Button>
|
||||
{/if}
|
||||
|
||||
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
|
||||
<Send class="mr-1 h-3 w-3" />
|
||||
Send
|
||||
</Button>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@
|
|||
import { DatabaseStore } from '$lib/stores/database';
|
||||
import {
|
||||
activeConversation,
|
||||
continueAssistantMessage,
|
||||
deleteMessage,
|
||||
navigateToSibling,
|
||||
editMessageWithBranching,
|
||||
editAssistantMessage,
|
||||
editMessageWithBranching,
|
||||
editUserMessagePreserveResponses,
|
||||
navigateToSibling,
|
||||
regenerateMessageWithBranching
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import { getMessageSiblings } from '$lib/utils/branching';
|
||||
|
|
@ -93,6 +95,26 @@
|
|||
|
||||
refreshAllMessages();
|
||||
}
|
||||
|
||||
async function handleContinueAssistantMessage(message: DatabaseMessage) {
|
||||
onUserAction?.();
|
||||
|
||||
await continueAssistantMessage(message.id);
|
||||
|
||||
refreshAllMessages();
|
||||
}
|
||||
|
||||
async function handleEditUserMessagePreserveResponses(
|
||||
message: DatabaseMessage,
|
||||
newContent: string
|
||||
) {
|
||||
onUserAction?.();
|
||||
|
||||
await editUserMessagePreserveResponses(message.id, newContent);
|
||||
|
||||
refreshAllMessages();
|
||||
}
|
||||
|
||||
async function handleDeleteMessage(message: DatabaseMessage) {
|
||||
await deleteMessage(message.id);
|
||||
|
||||
|
|
@ -110,7 +132,9 @@
|
|||
onNavigateToSibling={handleNavigateToSibling}
|
||||
onEditWithBranching={handleEditWithBranching}
|
||||
onEditWithReplacement={handleEditWithReplacement}
|
||||
onEditUserMessagePreserveResponses={handleEditUserMessagePreserveResponses}
|
||||
onRegenerateWithBranching={handleRegenerateWithBranching}
|
||||
onContinueAssistantMessage={handleContinueAssistantMessage}
|
||||
/>
|
||||
{/each}
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -52,6 +52,11 @@
|
|||
{ value: 'dark', label: 'Dark', icon: Moon }
|
||||
]
|
||||
},
|
||||
{
|
||||
key: 'pasteLongTextToFileLen',
|
||||
label: 'Paste long text to file length',
|
||||
type: 'input'
|
||||
},
|
||||
{
|
||||
key: 'showMessageStats',
|
||||
label: 'Show message generation statistics',
|
||||
|
|
@ -68,14 +73,15 @@
|
|||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'askForTitleConfirmation',
|
||||
label: 'Ask for confirmation before changing conversation title',
|
||||
key: 'showModelInfo',
|
||||
label: 'Show model information',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'pasteLongTextToFileLen',
|
||||
label: 'Paste long text to file length',
|
||||
type: 'input'
|
||||
key: 'enableContinueGeneration',
|
||||
label: 'Enable "Continue" button',
|
||||
type: 'checkbox',
|
||||
isExperimental: true
|
||||
},
|
||||
{
|
||||
key: 'pdfAsImage',
|
||||
|
|
@ -83,13 +89,13 @@
|
|||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'showModelInfo',
|
||||
label: 'Show model information',
|
||||
key: 'renderUserContentAsMarkdown',
|
||||
label: 'Render user content as Markdown',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'renderUserContentAsMarkdown',
|
||||
label: 'Render user content as Markdown',
|
||||
key: 'askForTitleConfirmation',
|
||||
label: 'Ask for confirmation before changing conversation title',
|
||||
type: 'checkbox'
|
||||
}
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
<script lang="ts">
|
||||
import { RotateCcw } from '@lucide/svelte';
|
||||
import { RotateCcw, FlaskConical } from '@lucide/svelte';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
|
|
@ -55,8 +55,12 @@
|
|||
})()}
|
||||
|
||||
<div class="flex items-center gap-2">
|
||||
<Label for={field.key} class="text-sm font-medium">
|
||||
<Label for={field.key} class="flex items-center gap-1.5 text-sm font-medium">
|
||||
{field.label}
|
||||
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</Label>
|
||||
{#if isCustomRealTime}
|
||||
<ParameterSourceIndicator />
|
||||
|
|
@ -97,8 +101,12 @@
|
|||
</p>
|
||||
{/if}
|
||||
{:else if field.type === 'textarea'}
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
<Label for={field.key} class="block flex items-center gap-1.5 text-sm font-medium">
|
||||
{field.label}
|
||||
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</Label>
|
||||
|
||||
<Textarea
|
||||
|
|
@ -129,8 +137,12 @@
|
|||
})()}
|
||||
|
||||
<div class="flex items-center gap-2">
|
||||
<Label for={field.key} class="text-sm font-medium">
|
||||
<Label for={field.key} class="flex items-center gap-1.5 text-sm font-medium">
|
||||
{field.label}
|
||||
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</Label>
|
||||
{#if isCustomRealTime}
|
||||
<ParameterSourceIndicator />
|
||||
|
|
@ -214,9 +226,13 @@
|
|||
for={field.key}
|
||||
class="cursor-pointer text-sm leading-none font-medium {isDisabled
|
||||
? 'text-muted-foreground'
|
||||
: ''}"
|
||||
: ''} flex items-center gap-1.5"
|
||||
>
|
||||
{field.label}
|
||||
|
||||
{#if field.isExperimental}
|
||||
<FlaskConical class="h-3.5 w-3.5 text-muted-foreground" />
|
||||
{/if}
|
||||
</label>
|
||||
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
|||
max_tokens: -1,
|
||||
custom: '', // custom json-stringified object
|
||||
// experimental features
|
||||
pyInterpreterEnabled: false
|
||||
pyInterpreterEnabled: false,
|
||||
enableContinueGeneration: false
|
||||
};
|
||||
|
||||
export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
|
|
@ -96,5 +97,7 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
|||
modelSelectorEnabled:
|
||||
'Enable the model selector in the chat input to choose the inference model. Sends the associated model field in API requests.',
|
||||
pyInterpreterEnabled:
|
||||
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.'
|
||||
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.',
|
||||
enableContinueGeneration:
|
||||
'Enable "Continue" button for assistant messages. Currently works only with non-reasoning models.'
|
||||
};
|
||||
|
|
|
|||
|
|
@ -312,7 +312,6 @@ export class ChatService {
|
|||
let aggregatedContent = '';
|
||||
let fullReasoningContent = '';
|
||||
let aggregatedToolCalls: ApiChatCompletionToolCall[] = [];
|
||||
let hasReceivedData = false;
|
||||
let lastTimings: ChatMessageTimings | undefined;
|
||||
let streamFinished = false;
|
||||
let modelEmitted = false;
|
||||
|
|
@ -352,8 +351,6 @@ export class ChatService {
|
|||
return;
|
||||
}
|
||||
|
||||
hasReceivedData = true;
|
||||
|
||||
if (!abortSignal?.aborted) {
|
||||
onToolCallChunk?.(serializedToolCalls);
|
||||
}
|
||||
|
|
@ -415,7 +412,6 @@ export class ChatService {
|
|||
|
||||
if (content) {
|
||||
finalizeOpenToolCallBatch();
|
||||
hasReceivedData = true;
|
||||
aggregatedContent += content;
|
||||
if (!abortSignal?.aborted) {
|
||||
onChunk?.(content);
|
||||
|
|
@ -424,7 +420,6 @@ export class ChatService {
|
|||
|
||||
if (reasoningContent) {
|
||||
finalizeOpenToolCallBatch();
|
||||
hasReceivedData = true;
|
||||
fullReasoningContent += reasoningContent;
|
||||
if (!abortSignal?.aborted) {
|
||||
onReasoningChunk?.(reasoningContent);
|
||||
|
|
@ -446,15 +441,6 @@ export class ChatService {
|
|||
if (streamFinished) {
|
||||
finalizeOpenToolCallBatch();
|
||||
|
||||
if (
|
||||
!hasReceivedData &&
|
||||
aggregatedContent.length === 0 &&
|
||||
aggregatedToolCalls.length === 0
|
||||
) {
|
||||
const noResponseError = new Error('No response received from server. Please try again.');
|
||||
throw noResponseError;
|
||||
}
|
||||
|
||||
const finalToolCalls =
|
||||
aggregatedToolCalls.length > 0 ? JSON.stringify(aggregatedToolCalls) : undefined;
|
||||
|
||||
|
|
|
|||
|
|
@ -1486,6 +1486,10 @@ class ChatStore {
|
|||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
// Ensure currNode points to the edited message to maintain correct path
|
||||
await DatabaseStore.updateCurrentNode(this.activeConversation.id, messageToEdit.id);
|
||||
this.activeConversation.currNode = messageToEdit.id;
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, {
|
||||
content: newContent,
|
||||
timestamp: Date.now()
|
||||
|
|
@ -1499,6 +1503,69 @@ class ChatStore {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Edits a user message and preserves all responses below
|
||||
* Updates the message content in-place without deleting or regenerating responses
|
||||
*
|
||||
* **Use Case**: When you want to fix a typo or rephrase a question without losing the assistant's response
|
||||
*
|
||||
* **Important Behavior:**
|
||||
* - Does NOT create a branch (unlike editMessageWithBranching)
|
||||
* - Does NOT regenerate assistant responses
|
||||
* - Only updates the user message content in the database
|
||||
* - Preserves the entire conversation tree below the edited message
|
||||
* - Updates conversation title if this is the first user message
|
||||
*
|
||||
* @param messageId - The ID of the user message to edit
|
||||
* @param newContent - The new content for the message
|
||||
*/
|
||||
async editUserMessagePreserveResponses(messageId: string, newContent: string): Promise<void> {
|
||||
if (!this.activeConversation) return;
|
||||
|
||||
try {
|
||||
const messageIndex = this.findMessageIndex(messageId);
|
||||
if (messageIndex === -1) {
|
||||
console.error('Message not found for editing');
|
||||
return;
|
||||
}
|
||||
|
||||
const messageToEdit = this.activeMessages[messageIndex];
|
||||
if (messageToEdit.role !== 'user') {
|
||||
console.error('Only user messages can be edited with this method');
|
||||
return;
|
||||
}
|
||||
|
||||
// Simply update the message content in-place
|
||||
await DatabaseStore.updateMessage(messageId, {
|
||||
content: newContent,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, {
|
||||
content: newContent,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
// Check if first user message for title update
|
||||
const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id);
|
||||
const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null);
|
||||
const isFirstUserMessage =
|
||||
rootMessage && messageToEdit.parent === rootMessage.id && messageToEdit.role === 'user';
|
||||
|
||||
if (isFirstUserMessage && newContent.trim()) {
|
||||
await this.updateConversationTitleWithConfirmation(
|
||||
this.activeConversation.id,
|
||||
newContent.trim(),
|
||||
this.titleUpdateConfirmationCallback
|
||||
);
|
||||
}
|
||||
|
||||
this.updateConversationTimestamp();
|
||||
} catch (error) {
|
||||
console.error('Failed to edit user message:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Edits a message by creating a new branch with the edited content
|
||||
* @param messageId - The ID of the message to edit
|
||||
|
|
@ -1696,6 +1763,200 @@ class ChatStore {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Continues generation for an existing assistant message
|
||||
* @param messageId - The ID of the assistant message to continue
|
||||
*/
|
||||
async continueAssistantMessage(messageId: string): Promise<void> {
|
||||
if (!this.activeConversation || this.isLoading) return;
|
||||
|
||||
try {
|
||||
const messageIndex = this.findMessageIndex(messageId);
|
||||
if (messageIndex === -1) {
|
||||
console.error('Message not found for continuation');
|
||||
return;
|
||||
}
|
||||
|
||||
const messageToContinue = this.activeMessages[messageIndex];
|
||||
if (messageToContinue.role !== 'assistant') {
|
||||
console.error('Only assistant messages can be continued');
|
||||
return;
|
||||
}
|
||||
|
||||
// Race condition protection: Check if this specific conversation is already loading
|
||||
// This prevents multiple rapid clicks on "Continue" from creating concurrent operations
|
||||
if (this.isConversationLoading(this.activeConversation.id)) {
|
||||
console.warn('Continuation already in progress for this conversation');
|
||||
return;
|
||||
}
|
||||
|
||||
this.errorDialogState = null;
|
||||
this.setConversationLoading(this.activeConversation.id, true);
|
||||
this.clearConversationStreaming(this.activeConversation.id);
|
||||
|
||||
// IMPORTANT: Fetch the latest content from the database to ensure we have
|
||||
// the most up-to-date content, especially after a stopped generation
|
||||
// This prevents issues where the in-memory state might be stale
|
||||
const allMessages = await DatabaseStore.getConversationMessages(this.activeConversation.id);
|
||||
const dbMessage = allMessages.find((m) => m.id === messageId);
|
||||
|
||||
if (!dbMessage) {
|
||||
console.error('Message not found in database for continuation');
|
||||
this.setConversationLoading(this.activeConversation.id, false);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Use content from database as the source of truth
|
||||
const originalContent = dbMessage.content;
|
||||
const originalThinking = dbMessage.thinking || '';
|
||||
|
||||
// Get conversation context up to (but not including) the message to continue
|
||||
const conversationContext = this.activeMessages.slice(0, messageIndex);
|
||||
|
||||
const contextWithContinue = [
|
||||
...conversationContext.map((msg) => {
|
||||
if ('id' in msg && 'convId' in msg && 'timestamp' in msg) {
|
||||
return msg as DatabaseMessage & { extra?: DatabaseMessageExtra[] };
|
||||
}
|
||||
return msg as ApiChatMessageData;
|
||||
}),
|
||||
{
|
||||
role: 'assistant' as const,
|
||||
content: originalContent
|
||||
}
|
||||
];
|
||||
|
||||
let appendedContent = '';
|
||||
let appendedThinking = '';
|
||||
let hasReceivedContent = false;
|
||||
|
||||
await chatService.sendMessage(
|
||||
contextWithContinue,
|
||||
{
|
||||
...this.getApiOptions(),
|
||||
|
||||
onChunk: (chunk: string) => {
|
||||
hasReceivedContent = true;
|
||||
appendedContent += chunk;
|
||||
// Preserve originalContent exactly as-is, including any trailing whitespace
|
||||
// The concatenation naturally preserves any whitespace at the end of originalContent
|
||||
const fullContent = originalContent + appendedContent;
|
||||
|
||||
this.setConversationStreaming(
|
||||
messageToContinue.convId,
|
||||
fullContent,
|
||||
messageToContinue.id
|
||||
);
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, {
|
||||
content: fullContent
|
||||
});
|
||||
},
|
||||
|
||||
onReasoningChunk: (reasoningChunk: string) => {
|
||||
hasReceivedContent = true;
|
||||
appendedThinking += reasoningChunk;
|
||||
|
||||
const fullThinking = originalThinking + appendedThinking;
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, {
|
||||
thinking: fullThinking
|
||||
});
|
||||
},
|
||||
|
||||
onComplete: async (
|
||||
finalContent?: string,
|
||||
reasoningContent?: string,
|
||||
timings?: ChatMessageTimings
|
||||
) => {
|
||||
const fullContent = originalContent + (finalContent || appendedContent);
|
||||
const fullThinking = originalThinking + (reasoningContent || appendedThinking);
|
||||
|
||||
const updateData: {
|
||||
content: string;
|
||||
thinking: string;
|
||||
timestamp: number;
|
||||
timings?: ChatMessageTimings;
|
||||
} = {
|
||||
content: fullContent,
|
||||
thinking: fullThinking,
|
||||
timestamp: Date.now(),
|
||||
timings: timings
|
||||
};
|
||||
|
||||
await DatabaseStore.updateMessage(messageToContinue.id, updateData);
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, updateData);
|
||||
|
||||
this.updateConversationTimestamp();
|
||||
|
||||
this.setConversationLoading(messageToContinue.convId, false);
|
||||
this.clearConversationStreaming(messageToContinue.convId);
|
||||
slotsService.clearConversationState(messageToContinue.convId);
|
||||
},
|
||||
|
||||
onError: async (error: Error) => {
|
||||
if (this.isAbortError(error)) {
|
||||
// User cancelled - save partial continuation if any content was received
|
||||
if (hasReceivedContent && appendedContent) {
|
||||
const partialContent = originalContent + appendedContent;
|
||||
const partialThinking = originalThinking + appendedThinking;
|
||||
|
||||
await DatabaseStore.updateMessage(messageToContinue.id, {
|
||||
content: partialContent,
|
||||
thinking: partialThinking,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
|
||||
this.updateMessageAtIndex(messageIndex, {
|
||||
content: partialContent,
|
||||
thinking: partialThinking,
|
||||
timestamp: Date.now()
|
||||
});
|
||||
}
|
||||
|
||||
this.setConversationLoading(messageToContinue.convId, false);
|
||||
this.clearConversationStreaming(messageToContinue.convId);
|
||||
slotsService.clearConversationState(messageToContinue.convId);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Non-abort error - rollback to original content
|
||||
console.error('Continue generation error:', error);
|
||||
|
||||
// Rollback: Restore original content in UI
|
||||
this.updateMessageAtIndex(messageIndex, {
|
||||
content: originalContent,
|
||||
thinking: originalThinking
|
||||
});
|
||||
|
||||
// Ensure database has original content (in case of partial writes)
|
||||
await DatabaseStore.updateMessage(messageToContinue.id, {
|
||||
content: originalContent,
|
||||
thinking: originalThinking
|
||||
});
|
||||
|
||||
this.setConversationLoading(messageToContinue.convId, false);
|
||||
this.clearConversationStreaming(messageToContinue.convId);
|
||||
slotsService.clearConversationState(messageToContinue.convId);
|
||||
|
||||
const dialogType = error.name === 'TimeoutError' ? 'timeout' : 'server';
|
||||
this.showErrorDialog(dialogType, error.message);
|
||||
}
|
||||
},
|
||||
messageToContinue.convId
|
||||
);
|
||||
} catch (error) {
|
||||
if (this.isAbortError(error)) return;
|
||||
console.error('Failed to continue message:', error);
|
||||
if (this.activeConversation) {
|
||||
this.setConversationLoading(this.activeConversation.id, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Public methods for accessing per-conversation states
|
||||
*/
|
||||
|
|
@ -1743,8 +2004,11 @@ export const refreshActiveMessages = chatStore.refreshActiveMessages.bind(chatSt
|
|||
export const navigateToSibling = chatStore.navigateToSibling.bind(chatStore);
|
||||
export const editAssistantMessage = chatStore.editAssistantMessage.bind(chatStore);
|
||||
export const editMessageWithBranching = chatStore.editMessageWithBranching.bind(chatStore);
|
||||
export const editUserMessagePreserveResponses =
|
||||
chatStore.editUserMessagePreserveResponses.bind(chatStore);
|
||||
export const regenerateMessageWithBranching =
|
||||
chatStore.regenerateMessageWithBranching.bind(chatStore);
|
||||
export const continueAssistantMessage = chatStore.continueAssistantMessage.bind(chatStore);
|
||||
export const deleteMessage = chatStore.deleteMessage.bind(chatStore);
|
||||
export const getDeletionInfo = chatStore.getDeletionInfo.bind(chatStore);
|
||||
export const updateConversationName = chatStore.updateConversationName.bind(chatStore);
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ export interface SettingsFieldConfig {
|
|||
key: string;
|
||||
label: string;
|
||||
type: 'input' | 'textarea' | 'checkbox' | 'select';
|
||||
isExperimental?: boolean;
|
||||
help?: string;
|
||||
options?: Array<{ value: string; label: string; icon?: typeof import('@lucide/svelte').Icon }>;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue