diff --git a/.devops/cann.Dockerfile b/.devops/cann.Dockerfile index 9d9fabf887..cd8f87b2ea 100644 --- a/.devops/cann.Dockerfile +++ b/.devops/cann.Dockerfile @@ -3,7 +3,8 @@ # ============================================================================== # Define the CANN base image for easier version updates later -ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10 +ARG CHIP_TYPE=910b +ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.3.rc1.alpha001-${CHIP_TYPE}-openeuler22.03-py3.11 # ============================================================================== # BUILD STAGE @@ -11,9 +12,6 @@ ARG CANN_BASE_IMAGE=quay.io/ascend/cann:8.1.rc1-910b-openeuler22.03-py3.10 # ============================================================================== FROM ${CANN_BASE_IMAGE} AS build -# Define the Ascend chip model for compilation. Default is Ascend910B3 -ARG ASCEND_SOC_TYPE=Ascend910B3 - # -- Install build dependencies -- RUN yum install -y gcc g++ cmake make git libcurl-devel python3 python3-pip && \ yum clean all && \ @@ -36,13 +34,14 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH # For brevity, only core variables are listed here. You can paste the original ENV list here. # -- Build llama.cpp -- -# Use the passed ASCEND_SOC_TYPE argument and add general build options +# Use the passed CHIP_TYPE argument and add general build options +ARG CHIP_TYPE RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh --force \ && \ cmake -B build \ -DGGML_CANN=ON \ -DCMAKE_BUILD_TYPE=Release \ - -DSOC_TYPE=${ASCEND_SOC_TYPE} \ + -DSOC_TYPE=ascend${CHIP_TYPE} \ . && \ cmake --build build --config Release -j$(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e3697ffaaa..5215cc3572 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1390,14 +1390,10 @@ jobs: strategy: matrix: arch: [x86, aarch64] - cann: - - '8.1.RC1.alpha001-910b-openeuler22.03-py3.10' - device: - - 'ascend910b3' - build: - - 'Release' + chip_type: ['910b', '310p'] + build: ['Release'] runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} - container: ascendai/cann:${{ matrix.cann }} + container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }} steps: - name: Checkout uses: actions/checkout@v4 @@ -1414,7 +1410,7 @@ jobs: cmake -S . 
-B build \ -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ -DGGML_CANN=on \ - -DSOC_TYPE=${{ matrix.device }} + -DSOC_TYPE=ascend${{ matrix.chip_type }} cmake --build build -j $(nproc) # TODO: simplify the following workflows using a matrix diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e72caa423b..0d5739c24b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -693,6 +693,51 @@ jobs: path: llama-${{ steps.tag.outputs.name }}-xcframework.zip name: llama-${{ steps.tag.outputs.name }}-xcframework + openEuler-cann: + strategy: + matrix: + arch: [x86, aarch64] + chip_type: ['910b', '310p'] + build: ['Release'] + runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} + container: ascendai/cann:${{ matrix.chip_type == '910b' && '8.3.rc1.alpha001-910b-openeuler22.03-py3.11' || '8.2.rc1-310p-openeuler22.03-py3.11' }} + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Dependencies + run: | + yum update -y + yum install -y git gcc gcc-c++ make cmake libcurl-devel + git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Build + run: | + export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH} + + cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_CANN=on \ + -DSOC_TYPE=ascend${{ matrix.chip_type }} + cmake --build build -j $(nproc) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + run: | + cp LICENSE ./build/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip ./build/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip + name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.zip + release: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} @@ -714,6 +759,7 @@ jobs: - macOS-arm64 - macOS-x64 - ios-xcode-build + - openEuler-cann steps: - name: Clone diff --git a/.gitignore b/.gitignore index c7d0009785..8575a141c4 100644 --- a/.gitignore +++ b/.gitignore @@ -20,52 +20,40 @@ *.so *.swp *.tmp +*.DS_Store # IDE / OS -.cache/ -.ccls-cache/ -.direnv/ -.DS_Store -.envrc -.idea/ -.swiftpm -.vs/ -.vscode/ -nppBackup +/.cache/ +/.ccls-cache/ +/.direnv/ +/.envrc +/.idea/ +/.swiftpm +/.vs/ +/.vscode/ +/nppBackup # Coverage -gcovr-report/ -lcov-report/ +/gcovr-report/ +/lcov-report/ # Build Artifacts -tags -.build/ -build* -release -debug -!build-info.cmake -!build-info.cpp.in -!build-info.sh -!build.zig -!docs/build.md +/tags +/.build/ +/build* +/release +/debug /libllama.so /llama-* /vulkan-shaders-gen -android-ndk-* -arm_neon.h -cmake-build-* -CMakeSettings.json -compile_commands.json -ggml-metal-embed.metal -llama-batched-swift /rpc-server -out/ -tmp/ -autogen-*.md +/out/ +/tmp/ +/autogen-*.md # Deprecated @@ -74,44 +62,38 @@ autogen-*.md # CI -!.github/workflows/*.yml +!/.github/workflows/*.yml # Models -models/* -models-mnt -!models/.editorconfig -!models/ggml-vocab-*.gguf* -!models/templates +/models/* +/models-mnt +!/models/.editorconfig +!/models/ggml-vocab-*.gguf* +!/models/templates # Zig -zig-out/ -zig-cache/ - -# Logs - -ppl-*.txt -qnt-*.txt -perf-*.txt +/zig-out/ +/zig-cache/ # Examples -examples/jeopardy/results.txt -tools/server/*.css.hpp 
-tools/server/*.html.hpp -tools/server/*.js.hpp -tools/server/*.mjs.hpp -tools/server/*.gz.hpp -!build_64.sh -!examples/*.bat -!examples/*/*.kts -!examples/*/*/*.kts -!examples/sycl/*.bat -!examples/sycl/*.sh +/examples/jeopardy/results.txt +/tools/server/*.css.hpp +/tools/server/*.html.hpp +/tools/server/*.js.hpp +/tools/server/*.mjs.hpp +/tools/server/*.gz.hpp +!/build_64.sh +!/examples/*.bat +!/examples/*/*.kts +!/examples/*/*/*.kts +!/examples/sycl/*.bat +!/examples/sycl/*.sh # Server Web UI temporary files -node_modules -tools/server/webui/dist +/tools/server/webui/node_modules +/tools/server/webui/dist # Python @@ -147,8 +129,8 @@ poetry.toml # Local scripts /run-vim.sh /run-chat.sh -.ccache/ +/.ccache/ # IDE -*.code-workspace -.windsurf/ +/*.code-workspace +/.windsurf/ diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 706fa32eed..bb168e8358 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -50,6 +50,8 @@ add_library(${TARGET} STATIC base64.hpp chat-parser.cpp chat-parser.h + chat-parser-xml-toolcall.h + chat-parser-xml-toolcall.cpp chat.cpp chat.h common.cpp diff --git a/common/chat-parser-xml-toolcall.cpp b/common/chat-parser-xml-toolcall.cpp new file mode 100644 index 0000000000..7349895550 --- /dev/null +++ b/common/chat-parser-xml-toolcall.cpp @@ -0,0 +1,861 @@ +#include "chat.h" +#include "chat-parser.h" +#include "common.h" +#include "json-partial.h" +#include "json-schema-to-grammar.h" +#include "log.h" +#include "regex-partial.h" + +using json = nlohmann::ordered_json; + +class xml_toolcall_syntax_exception : public std::runtime_error { + public: + xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {} +}; + +template +inline void sort_uniq(std::vector &vec) { + std::sort(vec.begin(), vec.end()); + vec.erase(std::unique(vec.begin(), vec.end()), vec.end()); +} + +template +inline bool all_space(const T &str) { + return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); }); +} + +static size_t utf8_truncate_safe(const std::string_view s) { + size_t len = s.size(); + if (len == 0) return 0; + size_t i = len; + for (size_t back = 0; back < 4 && i > 0; ++back) { + --i; + unsigned char c = s[i]; + if ((c & 0x80) == 0) { + return len; + } else if ((c & 0xC0) == 0xC0) { + size_t expected_len = 0; + if ((c & 0xE0) == 0xC0) expected_len = 2; + else if ((c & 0xF0) == 0xE0) expected_len = 3; + else if ((c & 0xF8) == 0xF0) expected_len = 4; + else return i; + if (len - i >= expected_len) { + return len; + } else { + return i; + } + } + } + return len - std::min(len, size_t(3)); +} + +inline void utf8_truncate_safe_resize(std::string &s) { + s.resize(utf8_truncate_safe(s)); +} + +inline std::string_view utf8_truncate_safe_view(const std::string_view s) { + return s.substr(0, utf8_truncate_safe(s)); +} + +static std::optional try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) { + if (literal1.size() == 0) return builder.try_find_literal(literal2); + const auto saved_pos = builder.pos(); + while (auto res = builder.try_find_literal(literal1)) { + builder.consume_spaces(); + const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos()); + if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) { + if (res->prelude.size() != res->groups[0].begin - saved_pos) { + res->prelude = builder.str({saved_pos, res->groups[0].begin}); + } + builder.move_to(builder.pos() + 
match_len); + res->groups[0].end = builder.pos(); + GGML_ASSERT(res->groups[0].begin != res->groups[0].end); + return res; + } + builder.move_to(res->groups[0].begin + 1); + } + builder.move_to(saved_pos); + return std::nullopt; +} + +/** + * make a GBNF that accept any strings except those containing any of the forbidden strings. + */ +std::string make_gbnf_excluding(std::vector forbids) { + constexpr auto charclass_escape = [](unsigned char c) -> std::string { + if (c == '\\' || c == ']' || c == '^' || c == '-') { + std::string s = "\\"; + s.push_back((char)c); + return s; + } + if (isprint(c)) { + return std::string(1, (char)c); + } + char buf[16]; + snprintf(buf, 15, "\\x%02X", c); + return std::string(buf); + }; + constexpr auto build_expr = [charclass_escape](auto self, const std::vector& forbids, int l, int r, int depth) -> std::string { + std::vector>> children; + int i = l; + while (i < r) { + const std::string &s = forbids[i]; + if ((int)s.size() == depth) { + ++i; + continue; + } + unsigned char c = (unsigned char)s[depth]; + int j = i; + while (j < r && (int)forbids[j].size() > depth && + (unsigned char)forbids[j][depth] == c) { + ++j; + } + children.push_back({c, {i, j}}); + i = j; + } + std::vector alts; + if (!children.empty()) { + std::string cls; + for (auto &ch : children) cls += charclass_escape(ch.first); + alts.push_back(std::string("[^") + cls + "]"); + } + for (auto &ch : children) { + std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1); + if (!childExpr.empty()) { + std::string quoted_ch = "\""; + if (ch.first == '\\') quoted_ch += "\\\\"; + else if (ch.first == '"') quoted_ch += "\\\""; + else if (isprint(ch.first)) quoted_ch.push_back(ch.first); + else { + char buf[16]; + snprintf(buf, 15, "\\x%02X", ch.first); + quoted_ch += buf; + } + quoted_ch += "\""; + std::string branch = quoted_ch + std::string(" ") + childExpr; + alts.push_back(branch); + } + } + if (alts.empty()) return ""; + std::ostringstream oss; + oss << "( "; + for (size_t k = 0; k < alts.size(); ++k) { + if (k) oss << " | "; + oss << alts[k]; + } + oss << " )"; + return oss.str(); + }; + if (forbids.empty()) return "( . )*"; + sort(forbids.begin(), forbids.end()); + std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0); + if (expr.empty()) { + std::string cls; + for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]); + expr = std::string("( [^") + cls + "] )"; + } + if (forbids.size() == 1) + return expr + "*"; + else + return std::string("( ") + expr + " )*"; +} + +/** + * Build grammar for xml-style tool call + * form.scope_start and form.scope_end can be empty. + * Requires data.format for model-specific hacks. + */ +void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) { + GGML_ASSERT(!form.tool_start.empty()); + GGML_ASSERT(!form.tool_sep.empty()); + GGML_ASSERT(!form.key_start.empty()); + GGML_ASSERT(!form.val_end.empty()); + GGML_ASSERT(!form.tool_end.empty()); + + std::string key_val_sep = form.key_val_sep; + if (form.key_val_sep2) { + key_val_sep += "\n"; + key_val_sep += *form.key_val_sep2; + } + GGML_ASSERT(!key_val_sep.empty()); + + if (tools.is_array() && !tools.empty()) { + data.grammar = build_grammar([&](const common_grammar_builder &builder) { + auto string_arg_val = form.last_val_end ? 
+ builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) : + builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end})); + + std::vector tool_rules; + for (const auto & tool : tools) { + if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { + LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str()); + continue; + } + const auto & function = tool.at("function"); + if (!function.contains("name") || !function.at("name").is_string()) { + LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str()); + continue; + } + if (!function.contains("parameters") || !function.at("parameters").is_object()) { + LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str()); + continue; + } + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + + struct parameter_rule { + std::string symbol_name; + bool is_required; + }; + std::vector arg_rules; + if (!parameters.contains("properties") || !parameters.at("properties").is_object()) { + LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str()); + continue; + } else { + std::vector requiredParameters; + if (parameters.contains("required")) { + try { parameters.at("required").get_to(requiredParameters); } + catch (const std::runtime_error&) { + LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str()); + } + } + sort_uniq(requiredParameters); + for (const auto & [key, value] : parameters.at("properties").items()) { + std::string quoted_key = key; + bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key); + if (form.key_start.back() == '"' && key_val_sep[0] == '"') { + quoted_key = gbnf_format_literal(key); + quoted_key = quoted_key.substr(1, quoted_key.size() - 2); + } + arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key, + gbnf_format_literal(form.key_start) + " " + + gbnf_format_literal(quoted_key) + " " + + gbnf_format_literal(key_val_sep) + " " + + ((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ? + (form.raw_argval ? + string_arg_val : + "( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )" + ) : + builder.add_schema(name + "-arg-" + key, value) + ) + ), required}); + } + } + + auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end)); + decltype(next_arg_with_sep) next_arg = "\"\""; + for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) { + std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep; + next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ? + include_this_arg : "( " + include_this_arg + " ) | " + next_arg + ); + include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg; + next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ? 
+ include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep + ); + } + + std::string quoted_name = name; + if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') { + quoted_name = gbnf_format_literal(name); + quoted_name = quoted_name.substr(1, quoted_name.size() - 2); + } + quoted_name = gbnf_format_literal(quoted_name); + // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name + if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) { + quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+"; + } + tool_rules.push_back(builder.add_rule(name + "-call", + gbnf_format_literal(form.tool_start) + " " + + quoted_name + " " + + gbnf_format_literal(form.tool_sep) + " " + + next_arg + )); + } + + auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | ")); + auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once); + auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end)); + auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end); + builder.add_rule("root", + (form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") + + tool_call_multiple_with_end + "?" + + (form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end)) + ); + }); + + // grammar trigger for tool call + data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start }); + } +} + +/** + * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched. + * Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser. + * form.scope_start, form.tool_sep and form.scope_end can be empty. + */ +inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) { + GGML_ASSERT(!form.tool_start.empty()); + GGML_ASSERT(!form.key_start.empty()); + GGML_ASSERT(!form.key_val_sep.empty()); + GGML_ASSERT(!form.val_end.empty()); + GGML_ASSERT(!form.tool_end.empty()); + + // Helper to choose return false or throw error + constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) { + LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str()); + if (recovery) { + builder.move_to(start_pos); + return false; + } else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. 
Try using a grammar to constrain the model’s output."); + }; + // Drop substring from needle to end from a JSON + constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") { + auto pos = json_str.rfind(needle); + if (pos == std::string::npos) { + return false; + } + for (auto i = pos + needle.size(); i < json_str.size(); ++i) { + unsigned char ch = static_cast(json_str[i]); + if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) { + return false; + } + } + if (pos != 0 && json_str[pos - 1] == '"') { + --pos; + } + json_str.resize(pos); + return true; + }; + // Helper to generate a partial argument JSON + constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) { + auto rest = builder.consume_rest(); + utf8_truncate_safe_resize(rest); + set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG"); + auto tool_str = arguments.dump(); + if (partial_json(tool_str)) { + if (builder.add_tool_call(function_name, "", tool_str)) { + return; + } + } + LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str()); + }; + // Helper to find a close (because there may be form.last_val_end or form.last_tool_end) + constexpr auto try_find_close = []( + common_chat_msg_parser & builder, + const std::string & end, + const std::optional & alt_end, + const std::string & end_next, + const std::optional & alt_end_next + ) { + auto saved_pos = builder.pos(); + auto tc = builder.try_find_literal(end); + auto val_end_size = end.size(); + if (alt_end) { + auto pos_1 = builder.pos(); + builder.move_to(saved_pos); + auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next); + if (alt_end_next) { + builder.move_to(saved_pos); + auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next); + if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) { + tc2 = tc3; + } + } + if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) { + tc = tc2; + tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size()); + builder.move_to(tc->groups[0].end); + val_end_size = alt_end->size(); + } else { + builder.move_to(pos_1); + } + } + return std::make_pair(val_end_size, tc); + }; + // Helper to find a val_end or last_val_end, returns matched pattern size + const auto try_find_val_end = [try_find_close, &builder, &form]() { + return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end); + }; + // Helper to find a tool_end or last_tool_end, returns matched pattern size + const auto try_find_tool_end = [try_find_close, &builder, &form]() { + return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt); + }; + + bool recovery = true; + const auto start_pos = builder.pos(); + if (!all_space(form.scope_start)) { + if (auto tc = builder.try_find_literal(form.scope_start)) { + if (all_space(tc->prelude)) { + if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin) + throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start)); + } else { + builder.move_to(start_pos); + return false; + } + } else return false; + } + while (auto tc = builder.try_find_literal(form.tool_start)) { + if (!all_space(tc->prelude)) { + LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n", + gbnf_format_literal(form.tool_start).c_str(), + 
gbnf_format_literal(tc->prelude).c_str() + ); + builder.move_to(tc->groups[0].begin - tc->prelude.size()); + break; + } + + // Find tool name + auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep); + if (!func_name) { + auto [sz, tc] = try_find_tool_end(); + func_name = tc; + } + if (!func_name) { + // Partial tool name not supported + throw common_chat_msg_partial_exception("incomplete tool_call"); + } + // If the model generate multiple tool call and the first tool call has no argument + if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) { + builder.move_to(func_name->groups[0].begin - func_name->prelude.size()); + auto [sz, tc] = try_find_tool_end(); + func_name = tc; + } + + // Parse tool name + builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end); + std::string function_name = string_strip(func_name->prelude); + // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name + if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) { + if (string_starts_with(function_name, "functions.")) { + static const std::regex re(":\\d+$"); + if (std::regex_search(function_name, re)) { + function_name = function_name.substr(10, function_name.rfind(":") - 10); + } + } + } + + // Argument JSON + json arguments = json::object(); + + // Helper to generate a partial argument JSON + const auto gen_partial_args = [&](auto set_partial_arg) { + gen_partial_json(set_partial_arg, arguments, builder, function_name); + }; + + // Parse all arg_key/arg_value pairs + while (auto tc = builder.try_find_literal(form.key_start)) { + if (!all_space(tc->prelude)) { + LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n", + gbnf_format_literal(form.key_start).c_str(), + gbnf_format_literal(tc->prelude).c_str() + ); + builder.move_to(tc->groups[0].begin - tc->prelude.size()); + break; + } + if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) { + auto tool_call_arg = arguments.dump(); + if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') { + tool_call_arg.resize(tool_call_arg.size() - 1); + } + builder.add_tool_call(function_name, "", tool_call_arg); + throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start)); + } + + // Parse arg_key + auto key_res = builder.try_find_literal(form.key_val_sep); + if (!key_res) { + gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";}); + throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start)); + } + if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) { + gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";}); + throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep)); + } + auto &key = key_res->prelude; + recovery = false; + + // Parse arg_value + if (form.key_val_sep2) { + if (auto tc = builder.try_find_literal(*form.key_val_sep2)) { + if (!all_space(tc->prelude)) { + LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n", + gbnf_format_literal(tc->prelude).c_str(), + gbnf_format_literal(form.key_val_sep).c_str(), + gbnf_format_literal(*form.key_val_sep2).c_str() + ); + return 
return_error(builder, start_pos, false); + } + if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) { + gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); + throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2)); + } + } else { + gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); + throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep)); + } + } + auto val_start = builder.pos(); + + // Test if arg_val is a partial JSON + std::optional value_json = std::nullopt; + if (!form.raw_argval || !*form.raw_argval) { + try { value_json = builder.try_consume_json(); } + catch (const std::runtime_error&) { builder.move_to(val_start); } + // TODO: Delete this when json_partial adds top-level support for null/true/false + if (builder.pos() == val_start) { + const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)"); + builder.consume_spaces(); + std::string_view sv = utf8_truncate_safe_view(builder.input()); + sv.remove_prefix(builder.pos()); + std::string rest = "a"; + if (sv.size() < 6) rest = sv; + if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) { + value_json = {123, {"123", "123"}}; + builder.consume_rest(); + } else { + builder.move_to(val_start); + } + } + } + + // If it is a JSON and followed by , parse as json + // cannot support streaming because it may be a plain text starting with JSON + if (value_json) { + auto json_end = builder.pos(); + builder.consume_spaces(); + if (builder.pos() == builder.input().size()) { + if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) { + arguments[key] = value_json->json; + auto json_str = arguments.dump(); + if (!value_json->healing_marker.json_dump_marker.empty()) { + GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker)); + json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker)); + } else { + GGML_ASSERT(json_str.back() == '}'); + json_str.resize(json_str.size() - 1); + } + builder.add_tool_call(function_name, "", json_str); + } else { + gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); + } + LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str()); + throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations."); + } + builder.move_to(json_end); + auto [val_end_size, tc] = try_find_val_end(); + if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) { + if (tc->groups[0].end - tc->groups[0].begin != val_end_size) { + gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); + LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str()); + throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? 
gbnf_format_literal(*form.last_val_end) : "")); + } else arguments[key] = value_json->json; + } else builder.move_to(val_start); + } + + // If not, parse as plain text + if (val_start == builder.pos()) { + if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) { + auto &value_str = value_plain->prelude; + if (form.trim_raw_argval) value_str = string_strip(value_str); + if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) { + gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;}); + throw common_chat_msg_partial_exception( + "Expected " + gbnf_format_literal(form.val_end) + + " after " + gbnf_format_literal(form.key_val_sep) + + (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "") + ); + } + arguments[key] = value_str; + } else { + if (form.trim_raw_argval) { + gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;}); + } else { + gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;}); + } + throw common_chat_msg_partial_exception( + "Expected " + gbnf_format_literal(form.val_end) + + " after " + gbnf_format_literal(form.key_val_sep) + + (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "") + ); + } + } + } + + // Consume closing tag + if (auto [tool_end_size, tc] = try_find_tool_end(); tc) { + if (!all_space(tc->prelude)) { + LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n", + gbnf_format_literal(form.tool_end).c_str(), + gbnf_format_literal(tc->prelude).c_str() + ); + return return_error(builder, start_pos, recovery); + } + if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) { + // Add the parsed tool call + if (!builder.add_tool_call(function_name, "", arguments.dump())) { + throw common_chat_msg_partial_exception("Failed to add XML-Style tool call"); + } + recovery = false; + continue; + } + } + + auto tool_call_arg = arguments.dump(); + if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') { + tool_call_arg.resize(tool_call_arg.size() - 1); + } + builder.add_tool_call(function_name, "", tool_call_arg); + throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end)); + } + if (auto tc = builder.try_find_literal(form.scope_end)) { + if (!all_space(tc->prelude)) { + LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n", + gbnf_format_literal(form.scope_end).c_str(), + gbnf_format_literal(tc->prelude).c_str() + ); + return return_error(builder, start_pos, recovery); + } + } else { + if (all_space(form.scope_end)) return true; + builder.consume_spaces(); + if (builder.pos() == builder.input().size()) + throw common_chat_msg_partial_exception("incomplete tool calls"); + LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n", + gbnf_format_literal(form.scope_end).c_str(), + gbnf_format_literal(builder.consume_rest()).c_str() + ); + return return_error(builder, start_pos, recovery); + } + + return true; +} + +/** + * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched. + * May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client. + * form.scope_start, form.tool_sep and form.scope_end can be empty. 
+ */
+bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) {
+ auto pos = pos_;
+ auto tsize = result_.tool_calls.size();
+ try { return parse_xml_tool_calls(*this, form); }
+ catch (const xml_toolcall_syntax_exception&) {}
+ move_to(pos);
+ result_.tool_calls.resize(tsize);
+ return false;
+}
+
+/**
+ * Parse content using reasoning and XML-style tool calls.
+ * TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed.
+ */
+inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "", const std::string & end_think = "") {
+ constexpr auto rstrip = [](std::string &s) {
+ s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base()));
+ };
+ // Erase the substring from l to r, along with additional spaces nearby
+ constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) {
+ while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast(str[l])));
+ ++l;
+ while (++r < str.size() && std::isspace(static_cast(str[r])));
+ if (l < r) str[l] = '\n';
+ if (l + 1 < r) str[l + 1] = '\n';
+ if (l != 0) l += 2;
+ str.erase(l, r - l);
+ return l;
+ };
+ constexpr auto trim_suffix = [](std::string &content, std::initializer_list list) {
+ auto best_match = content.size();
+ for (auto pattern: list) {
+ if (pattern.size() == 0) continue;
+ for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) {
+ auto match_len = content.size() - match_idx;
+ if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) {
+ best_match = match_idx;
+ }
+ }
+ }
+ if (content.size() > best_match) {
+ content.erase(best_match);
+ }
+ };
+ const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) {
+ return trim_suffix(content, {
+ start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start,
+ form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "",
+ form.val_end, form.last_val_end ? form.last_val_end->c_str() : "",
+ form.tool_end, form.last_tool_end ?
form.last_tool_end->c_str() : "", + form.scope_end + }); + }; + + + // Trim leading spaces without affecting keyword matching + static const common_regex spaces_regex("\\s*"); + { + auto tc = builder.consume_regex(spaces_regex); + auto spaces = builder.str(tc.groups[0]); + auto s1 = spaces.size(); + trim_potential_partial_word(spaces); + auto s2 = spaces.size(); + builder.move_to(builder.pos() - (s1 - s2)); + } + + // Parse content + bool reasoning_unclosed = builder.syntax().thinking_forced_open; + std::string unclosed_reasoning_content(""); + for (;;) { + auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start); + std::string content; + std::string tool_call_start; + + if (tc) { + content = std::move(tc->prelude); + tool_call_start = builder.str(tc->groups[0]); + LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str()); + } else { + content = builder.consume_rest(); + utf8_truncate_safe_resize(content); + } + + // Handle unclosed think block + if (reasoning_unclosed) { + if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) { + unclosed_reasoning_content += content; + if (form.allow_toolcall_in_think) { + builder.move_to(tc->groups[0].begin); + if (!builder.try_consume_xml_tool_calls(form)) { + unclosed_reasoning_content += tool_call_start; + builder.move_to(tc->groups[0].end); + } + } else { + unclosed_reasoning_content += tool_call_start; + } + continue; + } else { + reasoning_unclosed = false; + std::string reasoning_content; + if (pos == std::string::npos) { + reasoning_content = std::move(content); + } else { + reasoning_content = content.substr(0, pos); + content.erase(0, pos + end_think.size()); + } + if (builder.pos() == builder.input().size() && all_space(content)) { + rstrip(reasoning_content); + trim_potential_partial_word(reasoning_content); + rstrip(reasoning_content); + if (reasoning_content.empty()) { + rstrip(unclosed_reasoning_content); + trim_potential_partial_word(unclosed_reasoning_content); + rstrip(unclosed_reasoning_content); + if (unclosed_reasoning_content.empty()) continue; + } + } + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { + builder.add_content(start_think); + builder.add_content(unclosed_reasoning_content); + builder.add_content(reasoning_content); + if (builder.pos() != builder.input().size() || !all_space(content)) + builder.add_content(end_think); + } else { + builder.add_reasoning_content(unclosed_reasoning_content); + builder.add_reasoning_content(reasoning_content); + } + unclosed_reasoning_content.clear(); + } + } + + // Handle multiple think block + bool toolcall_in_think = false; + for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) { + if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) { + if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { + auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size()); + builder.add_reasoning_content(reasoning_content); + think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1); + } else { + think_start = think_end + end_think.size() - 1; + } + } else { + // This start is in thinking block, skip this tool call + auto pos = think_start + 
start_think.size(); + unclosed_reasoning_content = content.substr(pos) + tool_call_start; + reasoning_unclosed = true; + content.resize(think_start); + toolcall_in_think = true; + } + } + + if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { + rstrip(content); + // Handle unclosed token from content: delete all token + if (auto pos = content.rfind(end_think); pos != std::string::npos) { + while (pos != std::string::npos) { + pos = erase_spaces(content, pos, pos + end_think.size() - 1); + pos = content.rfind(end_think, pos); + } + } + // Strip if needed + if (content.size() > 0 && std::isspace(static_cast(content[0]))) { + content = string_strip(content); + } + } + + // remove potential partial suffix + if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) { + rstrip(content); + trim_potential_partial_word(content); + rstrip(content); + } + + // Add content + if (content.size() != 0) { + // If there are multiple content blocks + if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) { + builder.add_content("\n\n"); + } + builder.add_content(content); + } + + // This start is in thinking block, skip this tool call + if (toolcall_in_think && !form.allow_toolcall_in_think) { + continue; + } + + // There is no tool call and all content is parsed + if (!tc) { + GGML_ASSERT(builder.pos() == builder.input().size()); + GGML_ASSERT(unclosed_reasoning_content.empty()); + GGML_ASSERT(!reasoning_unclosed); + break; + } + + builder.move_to(tc->groups[0].begin); + if (builder.try_consume_xml_tool_calls(form)) { + auto end_of_tool = builder.pos(); + builder.consume_spaces(); + if (builder.pos() != builder.input().size()) { + builder.move_to(end_of_tool); + if (!builder.result().content.empty()) { + builder.add_content("\n\n"); + } + } + } else { + static const common_regex next_char_regex("."); + auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]); + rstrip(c); + builder.add_content(c); + } + } +} + +/** + * Parse content uses reasoning and XML-Style tool call + * TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed. + */ +void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) { + parse_msg_with_xml_tool_calls(*this, form, start_think, end_think); +} diff --git a/common/chat-parser-xml-toolcall.h b/common/chat-parser-xml-toolcall.h new file mode 100644 index 0000000000..67face2b94 --- /dev/null +++ b/common/chat-parser-xml-toolcall.h @@ -0,0 +1,45 @@ +#pragma once + +#include "chat.h" + +#include + +#include +#include +#include + + +// Sample config: +// MiniMax-M2 (left): \n\nvalue\n...\n... +// GLM 4.5 (right): function_name\nkey\nvalue\n +struct xml_tool_call_format { + std::string scope_start; // \n // \n // can be empty + std::string tool_start; // + std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls + std::string key_start; // + std::string key_val_sep; // \"> // \n + std::string val_end; // \n // \n + std::string tool_end; // \n // \n + std::string scope_end; // // // can be empty + // Set this if there can be dynamic spaces inside key_val_sep. + // e.g. 
key_val_sep= key_val_sep2= for GLM4.5 + std::optional key_val_sep2 = std::nullopt; + // Set true if argval should only be raw string. e.g. Hello "world" hi + // Set false if argval should only be json string. e.g. "Hello \"world\" hi" + // Defaults to std::nullopt, both will be allowed. + std::optional raw_argval = std::nullopt; + std::optional last_val_end = std::nullopt; + std::optional last_tool_end = std::nullopt; + bool trim_raw_argval = false; + bool allow_toolcall_in_think = false; // TODO: UNTESTED!!! +}; + +// make a GBNF that accept any strings except those containing any of the forbidden strings. +std::string make_gbnf_excluding(std::vector forbids); + +/** + * Build grammar for xml-style tool call + * form.scope_start and form.scope_end can be empty. + * Requires data.format for model-specific hacks. + */ +void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form); diff --git a/common/chat-parser.h b/common/chat-parser.h index c8cdc63fb5..78c4b74c2d 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -1,6 +1,7 @@ #pragma once #include "chat.h" +#include "chat-parser-xml-toolcall.h" #include "json-partial.h" #include "regex-partial.h" @@ -119,5 +120,14 @@ class common_chat_msg_parser { const std::vector> & content_paths = {} ); + /** + * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched. + * form.scope_start, form.tool_sep and form.scope_end can be empty. + */ + bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form); + + // Parse content uses reasoning and XML-Style tool call + void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "", const std::string & end_think = ""); + void clear_tools(); }; diff --git a/common/chat.cpp b/common/chat.cpp index 938872e82e..6fa05a6041 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -643,6 +643,12 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools"; + case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2"; + case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5"; + case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2"; + case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder"; + case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5"; + case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo"; default: throw std::runtime_error("Unknown chat format"); } @@ -1807,6 +1813,278 @@ static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { } } + +static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) { + common_chat_params data; + data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + data.prompt = apply(tmpl, params); + data.format = COMMON_CHAT_FORMAT_MINIMAX_M2; + + // Handle thinking tags based on prompt ending + if (string_ends_with(data.prompt, "\n")) { + if (!params.enable_thinking) { + // Close the thinking tag immediately if thinking is disabled + data.prompt += "\n\n"; + } else { + // Mark thinking as forced open (template started with ) + data.thinking_forced_open = true; + } + } + + // Preserve MiniMax-M2 special 
tokens + data.preserved_tokens = { + "", + "", + "", + "", + }; + + // build grammar for tool call + static const xml_tool_call_format form { + /* form.scope_start = */ "\n", + /* form.tool_start = */ "\n", + /* form.key_start = */ "", + /* form.val_end = */ "\n", + /* form.tool_end = */ "\n", + /* form.scope_end = */ "", + }; + build_grammar_xml_tool_call(data, params.tools, form); + + return data; +} + +static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) { + common_chat_params data; + data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + data.prompt = apply(tmpl, params); + data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML; + + data.preserved_tokens = { + "", + "", + "", + "", + }; + + // build grammar for tool call + static const xml_tool_call_format form { + /* form.scope_start = */ "\n", + /* form.tool_start = */ "\n", + /* form.key_start = */ "\n", + /* form.val_end = */ "\n\n", + /* form.tool_end = */ "\n", + /* form.scope_end = */ "", + }; + build_grammar_xml_tool_call(data, params.tools, form); + + return data; +} + +static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "", + "", + "<|tool_calls_section_begin|>", + "<|tool_call_begin|>", + "<|tool_call_argument_begin|>", + "<|tool_call_end|>", + "<|tool_calls_section_end|>", + "<|im_end|>", + "<|im_system|>", + "<|im_middle|>", + }; + + data.additional_stops.insert(data.additional_stops.end(), { + "<|im_end|>", + "<|im_middle|>" + }); + // build grammar for tool call + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "<|tool_calls_section_begin|>"; + form.tool_start = "<|tool_call_begin|>"; + form.tool_sep = "<|tool_call_argument_begin|>{"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}<|tool_call_end|>"; + form.scope_end = "<|tool_calls_section_end|>"; + form.raw_argval = false; + form.last_val_end = ""; + return form; + })(); + build_grammar_xml_tool_call(data, params.tools, form); + + return data; +} + +static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "<|tool_calls_section_begin|>"; + form.tool_start = "<|tool_call_begin|>"; + form.tool_sep = "<|tool_call_argument_begin|>{"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}<|tool_call_end|>"; + form.scope_end = "<|tool_calls_section_end|>"; + form.raw_argval = false; + form.last_val_end = ""; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) { + common_chat_params data; + data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && 
params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + data.prompt = apply(tmpl, params); + data.format = COMMON_CHAT_FORMAT_APRIEL_1_5; + + data.preserved_tokens = { + "", + "", + "", + "", + }; + + // build grammar for tool call + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "["; + form.tool_start = "{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}, "; + form.scope_end = "]"; + form.raw_argval = false; + form.last_val_end = ""; + form.last_tool_end = "}"; + return form; + })(); + build_grammar_xml_tool_call(data, params.tools, form); + + return data; +} + +static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "["; + form.tool_start = "{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}, "; + form.scope_end = "]"; + form.raw_argval = false; + form.last_val_end = ""; + form.last_tool_end = "}"; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) { + common_chat_params data; + data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + data.prompt = apply(tmpl, params); + data.format = COMMON_CHAT_FORMAT_XIAOMI_MIMO; + + data.preserved_tokens = { + "", + "", + }; + + // build grammar for tool call + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "\n"; + form.tool_start = "\n{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}\n"; + form.scope_end = ""; + form.raw_argval = false; + form.last_val_end = ""; + return form; + })(); + build_grammar_xml_tool_call(data, params.tools, form); + + return data; +} + +static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "\n{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}\n"; + form.scope_end = ""; + form.raw_argval = false; + form.last_val_end = ""; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form); +} + static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2041,6 +2319,100 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { } } +static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + data.grammar_lazy = inputs.tools.is_array() && !inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + std::string prompt = apply(tmpl, inputs); + + // match the existing trimming behavior + if (inputs.add_bos && string_starts_with(prompt, tmpl.bos_token())) { + prompt.erase(0, tmpl.bos_token().size()); + } + if (inputs.add_eos && 
string_ends_with(prompt, tmpl.eos_token())) { + prompt.erase(prompt.size() - tmpl.eos_token().size()); + } + if (string_ends_with(prompt, "")) { + if (!inputs.enable_thinking) { + prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + // add GLM preserved tokens + data.preserved_tokens = { + "<|endoftext|>", + "[MASK]", + "[gMASK]", + "[sMASK]", + "", + "", + "<|system|>", + "<|user|>", + "<|assistant|>", + "<|observation|>", + "<|begin_of_image|>", + "<|end_of_image|>", + "<|begin_of_video|>", + "<|end_of_video|>", + "<|begin_of_audio|>", + "<|end_of_audio|>", + "<|begin_of_transcription|>", + "<|end_of_transcription|>", + "<|code_prefix|>", + "<|code_middle|>", + "<|code_suffix|>", + "/nothink", + "", + "", + "", + "", + "", + "", + "", + "" + }; + + // extra GLM 4.5 stop word + data.additional_stops.insert(data.additional_stops.end(), { + "<|user|>", + "<|observation|>" + }); + + // build grammar for tool call + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "\n", + /* form.tool_sep = */ "\n", + /* form.key_start = */ "", + /* form.key_val_sep = */ "\n", + /* form.val_end = */ "\n", + /* form.tool_end = */ "\n", + /* form.scope_end = */ "", + }; + build_grammar_xml_tool_call(data, inputs.tools, form); + + data.prompt = prompt; + data.format = COMMON_CHAT_FORMAT_GLM_4_5; + return data; +} + +static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.tool_sep = */ "", + /* form.key_start = */ "", + /* form.key_val_sep = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + /* form.key_val_sep2 = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; @@ -2704,91 +3076,17 @@ static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { } static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { - // Parse thinking tags first - this handles the main reasoning content - builder.try_parse_reasoning("", ""); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Parse tool calls - Seed-OSS uses format - static const common_regex tool_call_begin_regex(""); - static const common_regex tool_call_end_regex(""); - static const common_regex function_regex("]+)>"); - static const common_regex param_regex("]+)>"); - - while (auto tool_res = builder.try_find_regex(tool_call_begin_regex)) { - builder.consume_spaces(); // Consume whitespace after - - // Look for function call inside tool call, ignore any content before it - if (auto func_res = builder.try_find_regex(function_regex, std::string::npos, false)) { - auto function_name = builder.str(func_res->groups[1]); - - // Parse Seed-OSS parameters value - json args = json::object(); - // Parse all parameters - while (auto param_res = builder.try_find_regex(param_regex, std::string::npos, false)) { - // again, ignore noise around parameters - auto param_name = builder.str(param_res->groups[1]); - builder.move_to(param_res->groups[0].end); - builder.consume_spaces(); // Consume whitespace after parameter - auto savedPos = builder.pos(); - if (auto param_parse = builder.try_find_literal("")) { - auto param = 
param_parse->prelude; - builder.move_to(savedPos); - try { - if (auto param_res = builder.try_consume_json()) { - args[param_name] = param_res->json; - } else { - args[param_name] = param; - } - } catch (json::exception &) { - args[param_name] = param; - } - } else { - throw common_chat_msg_partial_exception("Incomplete tool parameter"); - } - } - // Look for closing function tag - auto end_func = builder.try_find_literal(""); - if (end_func) { - builder.move_to(end_func->groups[0].end); - builder.consume_spaces(); // Consume whitespace after - - // Add the tool call with parsed arguments, but only if we REALLY got the literal - auto eaten_fragment = builder.input().substr(end_func->groups[0].begin, end_func->groups[0].end); - auto funlen = std::string("").length(); - if (eaten_fragment.length() >= funlen && eaten_fragment.substr(0, funlen) == std::string("")) { - if (!builder.add_tool_call(function_name, "", args.dump())) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - // Look for closing tool call tag - if (auto end_tool = builder.try_find_regex(tool_call_end_regex, std::string::npos, false)) { - builder.move_to(end_tool->groups[0].end); - builder.consume_spaces(); // Consume trailing whitespace after tool call - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } else { - // No function found - don't consume content here, let it be handled at the end - break; - } - } - - // Consume any remaining whitespace after all tool call processing - builder.consume_spaces(); - auto remaining = builder.consume_rest(); - // If there's any non-whitespace content remaining, add it as content - if (!string_strip(remaining).empty()) { - builder.add_content(remaining); - } + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); } static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -2927,6 +3225,35 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_granite(tmpl, params); } + // GLM 4.5: detect by and tags (check before Hermes since both use ) + if (src.find("[gMASK]") != std::string::npos && + src.find("") != std::string::npos && + src.find("") != std::string::npos && + params.json_schema.is_null()) { + return common_chat_params_init_glm_4_5(tmpl, params); + } + + // Qwen3-Coder XML format detection (must come before Hermes 2 Pro) + // Detect via explicit XML markers unique to Qwen3-Coder to avoid false positives in other templates. + // Require presence of , , and blocks. 
+ if (src.find("") != std::string::npos && + src.find("") != std::string::npos && + src.find("") != std::string::npos && + src.find("") != std::string::npos && + src.find("# Tools") != std::string::npos && + src.find("") != std::string::npos && + src.find("") != std::string::npos && + src.find("") != std::string::npos && + src.find("") != std::string::npos) { + return common_chat_params_init_xiaomi_mimo(tmpl, params); + } + // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) if (src.find("") != std::string::npos && params.json_schema.is_null()) { return common_chat_params_init_hermes_2_pro(tmpl, params); @@ -2958,6 +3285,29 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_lfm2(tmpl, params); } + // MiniMax-M2 format detection + if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) { + return common_chat_params_init_minimax_m2(tmpl, params); + } + + // Kimi K2 format detection + if (src.find("<|im_system|>tool_declare<|im_middle|>") != std::string::npos && + src.find("<|tool_calls_section_begin|>") != std::string::npos && + src.find("## Return of") != std::string::npos) { + return common_chat_params_init_kimi_k2(tmpl, params); + } + + // Apriel 1.5 format detection + if (src.find("") != std::string::npos && + src.find("") != std::string::npos && + src.find("") != std::string::npos && + src.find("<|assistant|>") != std::string::npos && + src.find("<|tool_result|>") != std::string::npos && + src.find("[") != std::string::npos && + src.find("]") != std::string::npos) { + return common_chat_params_init_apriel_1_5(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -3009,7 +3359,7 @@ static common_chat_params common_chat_templates_apply_legacy( const struct common_chat_templates * tmpls, const struct common_chat_templates_inputs & inputs) { - int alloc_size = 0; + size_t alloc_size = 0; std::vector chat; std::vector contents; @@ -3031,7 +3381,8 @@ static common_chat_params common_chat_templates_apply_legacy( const auto & msg = inputs.messages[i]; const auto & content = contents[i]; chat.push_back({msg.role.c_str(), content.c_str()}); - alloc_size += (msg.role.size() + content.size()) * 1.25; + size_t msg_size = msg.role.size() + content.size(); + alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops } std::vector buf(alloc_size); @@ -3053,6 +3404,11 @@ static common_chat_params common_chat_templates_apply_legacy( res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); } + // for safety, we check the result again + if (res < 0 || (size_t) res > buf.size()) { + throw std::runtime_error("failed to apply chat template, try using --jinja"); + } + common_chat_params params; params.prompt = std::string(buf.data(), res); if (!inputs.json_schema.empty()) { @@ -3139,6 +3495,24 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: common_chat_parse_lfm2(builder); break; + case COMMON_CHAT_FORMAT_MINIMAX_M2: + common_chat_parse_minimax_m2(builder); + break; + case COMMON_CHAT_FORMAT_GLM_4_5: + common_chat_parse_glm_4_5(builder); + break; + case COMMON_CHAT_FORMAT_KIMI_K2: + common_chat_parse_kimi_k2(builder); + break; + case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: + common_chat_parse_qwen3_coder_xml(builder); + break; + case 
COMMON_CHAT_FORMAT_APRIEL_1_5: + common_chat_parse_apriel_1_5(builder); + break; + case COMMON_CHAT_FORMAT_XIAOMI_MIMO: + common_chat_parse_xiaomi_mimo(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index 50efb0d4e5..754c411e23 100644 --- a/common/chat.h +++ b/common/chat.h @@ -117,6 +117,12 @@ enum common_chat_format { COMMON_CHAT_FORMAT_NEMOTRON_V2, COMMON_CHAT_FORMAT_APERTUS, COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, + COMMON_CHAT_FORMAT_GLM_4_5, + COMMON_CHAT_FORMAT_MINIMAX_M2, + COMMON_CHAT_FORMAT_KIMI_K2, + COMMON_CHAT_FORMAT_QWEN3_CODER_XML, + COMMON_CHAT_FORMAT_APRIEL_1_5, + COMMON_CHAT_FORMAT_XIAOMI_MIMO, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/common/json-partial.cpp b/common/json-partial.cpp index 919927dc32..aaf11310ab 100644 --- a/common/json-partial.cpp +++ b/common/json-partial.cpp @@ -297,8 +297,25 @@ bool common_json_parse( it = temptative_end; return true; } - // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...) - // fprintf(stderr, "Closing: TODO\n"); + // handle unclosed top-level primitive + if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) { + std::string str(it, temptative_end); + const auto & magic_seed = out.healing_marker.marker = healing_marker; + if (can_parse(str + "\"")) { + // Was inside an string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\""; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) { + // Was inside an string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\""; + } else { + // TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...) 
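+            // (In the two string cases above, json_dump_marker records the marker as it was
+            //  injected, including the escaping backslash when needed, so a caller can locate
+            //  the artificial tail in the healed value and truncate it again.)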
+ // fprintf(stderr, "Closing: TODO\n"); + return false; + } + out.json = json::parse(str); + it = temptative_end; + return true; + } return false; } out.json = json::parse(it, end); diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 478aa1be7b..e64dc059f3 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -303,6 +303,8 @@ static std::string format_literal(const std::string & literal) { return "\"" + escaped + "\""; } +std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); } + class SchemaConverter { private: friend std::string build_grammar(const std::function & cb, const common_grammar_options & options); diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 362991b542..c89ab7f997 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -18,4 +18,6 @@ struct common_grammar_options { bool dotall = false; }; +std::string gbnf_format_literal(const std::string & literal); + std::string build_grammar(const std::function & cb, const common_grammar_options & options = {}); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0cc3df0975..8743202ad6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1673,11 +1673,9 @@ class GPTNeoXModel(TextModel): model_arch = gguf.MODEL_ARCH.GPTNEOX def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] - self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count( int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])), @@ -1735,7 +1733,7 @@ class BloomModel(TextModel): self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) self.gguf_writer.add_feed_forward_length(4 * n_embed) - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -1798,10 +1796,9 @@ class MPTModel(TextModel): self.gguf_writer.add_unk_token_id(0) def set_gguf_parameters(self): - block_count = self.hparams["n_layers"] self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(4 * self.hparams["d_model"]) self.gguf_writer.add_head_count(self.hparams["n_heads"]) if kv_n_heads := self.hparams["attn_config"].get("kv_n_heads"): @@ -1834,7 +1831,6 @@ class OrionModel(TextModel): self._set_vocab_sentencepiece() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -1852,7 +1848,7 @@ class OrionModel(TextModel): self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - 
self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count_kv(head_count_kv) @@ -1869,7 +1865,6 @@ class BaichuanModel(TextModel): self._set_vocab_sentencepiece() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -1886,7 +1881,7 @@ class BaichuanModel(TextModel): self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count(head_count) @@ -1993,7 +1988,6 @@ class XverseModel(TextModel): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -2010,7 +2004,7 @@ class XverseModel(TextModel): self.gguf_writer.add_tensor_data_layout("Meta AI original pth") self.gguf_writer.add_context_length(ctx_length) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count(head_count) @@ -2053,10 +2047,6 @@ class FalconModel(TextModel): model_arch = gguf.MODEL_ARCH.FALCON def set_gguf_parameters(self): - block_count = self.hparams.get("num_hidden_layers") - if block_count is None: - block_count = self.hparams["n_layer"] # old name - n_head = self.hparams.get("num_attention_heads") if n_head is None: n_head = self.hparams["n_head"] # old name @@ -2069,7 +2059,7 @@ class FalconModel(TextModel): self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -2107,12 +2097,10 @@ class StarCoderModel(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER def set_gguf_parameters(self): - block_count = self.hparams["n_layer"] - self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(1) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -2142,14 +2130,12 @@ class RefactModel(TextModel): 
multiple_of = 256 ff_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - block_count = self.hparams["n_layer"] - # refact uses Alibi. So this is from config.json which might be used by training. self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(ff_dim) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(1) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) @@ -2196,11 +2182,10 @@ class StableLMModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) @@ -3151,7 +3136,7 @@ class DbrxModel(TextModel): def set_gguf_parameters(self): ffn_config = self.hparams["ffn_config"] attn_config = self.hparams["attn_config"] - self.gguf_writer.add_block_count(self.hparams["n_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) @@ -3353,7 +3338,7 @@ class QwenModel(TextModel): def set_gguf_parameters(self): self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) - self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) @@ -4384,7 +4369,7 @@ class GPT2Model(TextModel): model_arch = gguf.MODEL_ARCH.GPT2 def set_gguf_parameters(self): - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams["n_ctx"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) @@ -4416,8 +4401,6 @@ class Phi2Model(TextModel): model_arch = gguf.MODEL_ARCH.PHI2 def set_gguf_parameters(self): - block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) - rot_pct = self.find_hparam(["partial_rotary_factor"]) n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) @@ -4426,7 +4409,7 @@ class Phi2Model(TextModel): self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_feed_forward_length(4 * n_embd) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head) self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"])) @@ -4544,8 +4527,6 @@ class Phi3MiniModel(TextModel): special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): - 
block_count = self.find_hparam(["num_hidden_layers", "n_layer"]) - n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) @@ -4559,7 +4540,7 @@ class Phi3MiniModel(TextModel): self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_rms_eps(rms_eps) @@ -4679,12 +4660,11 @@ class PlamoModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(4096) # not in config.json self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) @@ -4807,7 +4787,6 @@ class Plamo2Model(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) # Which layers are Mamba layers @@ -4819,10 +4798,10 @@ class Plamo2Model(TextModel): num_attention_heads = [] if mamba_enabled: - for i in range(block_count): - if block_count <= (mamba_step // 2): + for i in range(self.block_count): + if self.block_count <= (mamba_step // 2): # use attention in last layer - is_mamba = (i != block_count - 1) + is_mamba = (i != self.block_count - 1) else: is_mamba = (i % mamba_step) != (mamba_step // 2) if is_mamba: @@ -4840,7 +4819,7 @@ class Plamo2Model(TextModel): self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096)) self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128)) self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128)) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06)) self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000)) @@ -4897,12 +4876,10 @@ class CodeShellModel(TextModel): model_arch = gguf.MODEL_ARCH.CODESHELL def set_gguf_parameters(self): - block_count = self.hparams["n_layer"] - self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["n_head"]) self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) @@ -5044,7 +5021,7 @@ class InternLM2Model(TextModel): def set_gguf_parameters(self): self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) - self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"]) + self.gguf_writer.add_block_count(self.block_count) 
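+        # NOTE: self.block_count is presumably resolved once in the TextModel base class
+        # (find_hparam over the various n_layer/num_hidden_layers/num_layers spellings),
+        # which is what makes these per-model block-count lookups removable.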
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) @@ -5665,11 +5642,10 @@ class GemmaModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) @@ -5705,11 +5681,10 @@ class Gemma2Model(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) @@ -5753,12 +5728,11 @@ class Gemma3Model(TextModel): def set_gguf_parameters(self): hparams = self.hparams - block_count = hparams["num_hidden_layers"] # some default values are not specified in the hparams self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072)) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8)) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6)) @@ -6034,7 +6008,6 @@ class Rwkv6Model(TextModel): self._set_vocab_rwkv_world() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] head_size = self.hparams["head_size"] hidden_size = self.hparams["hidden_size"] layer_norm_eps = self.hparams["layer_norm_epsilon"] @@ -6046,7 +6019,7 @@ class Rwkv6Model(TextModel): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_eps(layer_norm_eps) self.gguf_writer.add_rescale_every_n_layers(rescale_every_n_layers) self.gguf_writer.add_wkv_head_size(head_size) @@ -6110,7 +6083,6 @@ class RWKV6Qwen2Model(Rwkv6Model): self._set_vocab_gpt2() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] num_attention_heads = self.hparams["num_attention_heads"] num_key_value_heads = self.hparams["num_key_value_heads"] hidden_size = self.hparams["hidden_size"] @@ -6123,7 +6095,7 @@ class RWKV6Qwen2Model(Rwkv6Model): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) 
self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim) self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim) @@ -6164,7 +6136,6 @@ class Rwkv7Model(TextModel): return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] try: head_size = self.hparams["head_size"] layer_norm_eps = self.hparams["layer_norm_epsilon"] @@ -6189,7 +6160,7 @@ class Rwkv7Model(TextModel): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_eps(layer_norm_eps) self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_decay_lora_rank(lora_rank_decay) @@ -6283,7 +6254,6 @@ class ARwkv7Model(Rwkv7Model): self._set_vocab_gpt2() def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] hidden_size = self.hparams["hidden_size"] head_size = self.hparams["head_size"] rms_norm_eps = self.hparams["rms_norm_eps"] @@ -6300,7 +6270,7 @@ class ARwkv7Model(Rwkv7Model): # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) self.gguf_writer.add_embedding_length(hidden_size) - self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) self.gguf_writer.add_wkv_head_size(head_size) self.gguf_writer.add_decay_lora_rank(lora_rank_decay) @@ -7524,7 +7494,7 @@ class T5Model(TextModel): self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) - self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_block_count(self.block_count) if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None: self.gguf_writer.add_decoder_block_count(dec_n_layer) self.gguf_writer.add_head_count(self.hparams["num_heads"]) @@ -7663,7 +7633,7 @@ class T5EncoderModel(TextModel): self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) - self.gguf_writer.add_block_count(self.hparams["num_layers"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(self.hparams["num_heads"]) self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"]) @@ -7726,7 +7696,7 @@ class JaisModel(TextModel): self._set_vocab_gpt2() def set_gguf_parameters(self): - self.gguf_writer.add_block_count(self.hparams["n_layer"]) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_context_length(self.hparams["n_positions"]) self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"]) @@ -8068,7 +8038,7 @@ class ChatGLMModel(TextModel): self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed))) - self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"])) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_head_count(n_head) 
self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5)) @@ -8150,7 +8120,6 @@ class ExaoneModel(TextModel): num_kv_heads = hparams.get("num_key_value_heads", num_heads) layer_norm_eps = hparams["layer_norm_epsilon"] intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim - num_layers = hparams["num_layers"] # ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0 # attention_dropout_rate = hparams["attention_dropout"] # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0 @@ -8161,7 +8130,7 @@ class ExaoneModel(TextModel): self.gguf_writer.add_context_length(max_position_embeddings) self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps) self.gguf_writer.add_feed_forward_length(intermediate_size) - self.gguf_writer.add_block_count(num_layers) + self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_file_type(self.ftype) if (rope_theta := self.hparams.get("rope_theta")) is not None: diff --git a/docs/ops.md b/docs/ops.md index 4ada4384fc..62a921e8f7 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -17,12 +17,12 @@ Legend: | ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | -| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | +| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | +| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | -| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ | @@ -43,9 +43,9 @@ Legend: | ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ | | EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | -| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | -| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ | @@ -87,7 +87,7 @@ Legend: | ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | -| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -99,7 +99,7 @@ Legend: | SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | +| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | | SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | @@ -107,7 +107,7 @@ Legend: | SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | | SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | | SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | -| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ | +| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ | | SUB | ❌ | ✅ | ✅ | ✅ 
| 🟡 | 🟡 | ✅ | ✅ | ❌ | | SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | @@ -116,6 +116,6 @@ Legend: | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | +| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | diff --git a/docs/ops/Vulkan.csv b/docs/ops/Vulkan.csv index 290bdd1215..8073930e94 100644 --- a/docs/ops/Vulkan.csv +++ b/docs/ops/Vulkan.csv @@ -5,8 +5,8 @@ "Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","STEP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" @@ -29,18 +29,18 @@ "Vulkan0","EXP","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","EXPM1","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","SOFTPLUS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","SOFTPLUS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" 
"Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan" "Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" @@ -89,8 +89,8 @@ "Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","STEP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","TANH","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" @@ -113,18 +113,18 @@ "Vulkan0","EXP","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","EXPM1","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" "Vulkan0","EXPM1","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","SOFTPLUS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","SOFTPLUS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan" "Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","Vulkan" @@ -5654,7 +5654,7 @@ "Vulkan0","SUB","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" "Vulkan0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","Vulkan" 
-"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" +"Vulkan0","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","Vulkan" "Vulkan0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","Vulkan" @@ -8632,10 +8632,10 @@ "Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" "Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" @@ -8643,10 +8643,10 @@ "Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" "Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan" "Vulkan0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","Vulkan" -"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" @@ -8654,10 +8654,10 @@ "Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan" 
"Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" @@ -8665,10 +8665,10 @@ "Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","Vulkan" -"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" -"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan" +"Vulkan0","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" +"Vulkan0","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","Vulkan" "Vulkan0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","Vulkan" @@ -9478,7 +9478,7 @@ "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","Vulkan" "Vulkan0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","Vulkan" -"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","Vulkan" +"Vulkan0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","Vulkan" "Vulkan0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","Vulkan" "Vulkan0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","Vulkan" "Vulkan0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan" @@ -9487,9 +9487,9 @@ "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","Vulkan" "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","Vulkan" "Vulkan0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","Vulkan" -"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","Vulkan" -"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","Vulkan" -"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","Vulkan" +"Vulkan0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","Vulkan" +"Vulkan0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Vulkan" +"Vulkan0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Vulkan" "Vulkan0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Vulkan" diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index c8d9854635..606c6d1783 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2544,7 +2544,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { int64_t shifts[] = { 1 }; int64_t dims[] 
= { 3 }; - aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); + aclnn_roll(ctx, acl_input_tensor.get(), acl_input_roll_tensor.get(), shifts, dims); // init [-1, 1, -1, 1, ...] minus_one_scale_buffer = minus_one_scale_allocator.get(); @@ -2564,7 +2564,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } int64_t index_num = src0->ne[0]; float value = -1; - aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index, index_num, value); + aclnn_index_fill_tensor(ctx, acl_minus_one_tensor.get(), dim, index, index_num, value); } else { // roll input: [q0,q1,q2,...] -> // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1] @@ -2576,7 +2576,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { int64_t shifts[] = { src0->ne[0] / 2 }; int64_t dims[] = { 3 }; - aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); + aclnn_roll(ctx, acl_input_tensor.get(), acl_input_roll_tensor.get(), shifts, dims); // init [-1, -1, -1, 1, 1,1,...] minus_one_scale_buffer = minus_one_scale_allocator.get(); @@ -2599,7 +2599,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { first_half_ne, first_half_nb, GGML_MAX_DIMS); bool inplace = true; float scale = -1; - aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace); + aclnn_muls(ctx, acl_first_half_tensor.get(), scale, nullptr, inplace); } // TODO: n_dims < ne0 @@ -2620,14 +2620,15 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_cann_create_tensor(input_roll_buffer, ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor, acl_input_roll_mul_scale_tensor); + aclnn_mul(ctx, acl_input_roll_reshape_tensor.get(), acl_minus_one_tensor.get(), + acl_input_roll_mul_scale_tensor.get()); // output void * output_fp32_buffer; if (src0->type == GGML_TYPE_F32) { - aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor); - aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor); - aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst); + aclnn_mul(ctx, acl_src.get(), acl_cos_reshape_tensor.get()); + aclnn_mul(ctx, acl_input_roll_mul_scale_tensor.get(), acl_sin_reshape_tensor.get()); + aclnn_add(ctx, acl_src.get(), acl_input_roll_mul_scale_tensor.get(), acl_dst.get()); // TODO: ne0 != n_dims in mode2 } else if (src0->type == GGML_TYPE_F16) { size_t input_fp32_nb[GGML_MAX_DIMS]; @@ -2648,10 +2649,10 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { output_fp32_buffer = fp32_allocator.get(); acl_tensor_ptr output_fp32_tensor = ggml_cann_create_tensor(output_fp32_buffer, ACL_FLOAT, sizeof(float), dst->ne, input_fp32_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1); - aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, input_fp32_tensor2); - aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2, output_fp32_tensor); - aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16); + aclnn_mul(ctx, acl_src.get(), acl_cos_reshape_tensor.get(), input_fp32_tensor1.get()); + aclnn_mul(ctx, acl_input_roll_mul_scale_tensor.get(), acl_sin_reshape_tensor.get(), input_fp32_tensor2.get()); + aclnn_add(ctx, input_fp32_tensor1.get(), input_fp32_tensor2.get(), output_fp32_tensor.get()); + aclnn_cast(ctx, output_fp32_tensor.get(), acl_dst.get(), ACL_FLOAT16); } return; #endif diff --git a/ggml/src/ggml-cpu/CMakeLists.txt 
b/ggml/src/ggml-cpu/CMakeLists.txt index e52e050a81..d0cab0bcb9 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -145,26 +145,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name) include(CheckCXXSourceRuns) - function(check_arm_feature tag code) + macro(check_arm_feature tag feature code) set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}") check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag}) if (GGML_MACHINE_SUPPORTS_${tag}) - set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}" PARENT_SCOPE) + set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}") else() set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}") check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag}) if (GGML_MACHINE_SUPPORTS_no${tag}) - set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}" PARENT_SCOPE) + set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}") + list(APPEND ARCH_FLAGS -U__ARM_FEATURE_${feature}) endif() endif() set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) - endfunction() + endmacro() - check_arm_feature(dotprod "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }") - check_arm_feature(i8mm "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }") - check_arm_feature(sve "#include \nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }") - check_arm_feature(sme "#include \n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }") + check_arm_feature(dotprod DOTPROD "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }") + check_arm_feature(i8mm MATMUL_INT8 "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }") + check_arm_feature(sve SVE "#include \nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }") + check_arm_feature(sme SME "#include \n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }") list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}") else() @@ -216,35 +217,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() endif() - # show enabled features - if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") - set(FEAT_INPUT_FILE "NUL") - else() - set(FEAT_INPUT_FILE "/dev/null") - endif() + message(STATUS "Checking for ARM features using flags:") + foreach(flag IN LISTS ARCH_FLAGS) + message(STATUS " ${flag}") + endforeach() - execute_process( - COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E - - INPUT_FILE ${FEAT_INPUT_FILE} - OUTPUT_VARIABLE ARM_FEATURE - RESULT_VARIABLE ARM_FEATURE_RESULT - ) - if (ARM_FEATURE_RESULT) - message(WARNING "Failed to get ARM features") - else() - foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) - string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos) - if (NOT ${feature_pos} EQUAL -1) - # Special handling for MATMUL_INT8 when machine doesn't support i8mm - if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm) - message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm") - list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8) - else() - message(STATUS "ARM feature ${feature} enabled") - endif() - endif() - endforeach() - endif() + 
include(CheckCXXSourceCompiles) + set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + set(CMAKE_REQUIRED_FLAGS "${ARCH_FLAGS}") + foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) + set(ARM_FEATURE "HAVE_${feature}") + check_cxx_source_compiles( + " + #if !defined(__ARM_FEATURE_${feature}) + # error \"Feature ${feature} is not defined\" + #endif + int main() { return 0; } + " + ${ARM_FEATURE} + ) + endforeach() + set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) endif() elseif (GGML_SYSTEM_ARCH STREQUAL "x86") message(STATUS "x86 detected") @@ -399,9 +392,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}") if (EXTRACTED_NUMBER GREATER_EQUAL 10) - list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64) + list(APPEND ARCH_FLAGS -mcpu=power10) elseif (EXTRACTED_NUMBER EQUAL 9) - list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64) + list(APPEND ARCH_FLAGS -mcpu=power9) elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native) else() diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 50612237c8..c1afde9627 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -384,7 +384,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src1_ddc = (char *) src1->data; const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); - const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1; + const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && + src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0); if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8648e31413..685dd37d13 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3005,6 +3005,10 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, const ggml_tensor * view, const ggml_tensor * set_rows) { + + if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) { + return false; + } // ne3 not tested if (rope->src[0]->ne[3] != 1) { return false; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index bb3eb977ce..f83dfdaef6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -406,8 +406,8 @@ enum shader_reduction_mode { SHADER_REDUCTION_MODE_COUNT, }; +// argsort pipelines for up to 1<<10 invocations per workgroup static constexpr uint32_t num_argsort_pipelines = 11; -static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); static constexpr uint32_t num_topk_moe_pipelines = 10; static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, @@ -526,6 +526,7 @@ struct vk_device_struct { bool multi_add; bool shader_int64; bool buffer_device_address; + bool vulkan_memory_model; bool add_rms_fusion; uint32_t partials_binding_alignment; @@ -539,6 +540,9 @@ struct vk_device_struct { uint32_t subgroup_max_size; bool subgroup_require_full_support; + // floor(log2(maxComputeWorkGroupInvocations)) + uint32_t max_workgroup_size_log2 {}; + bool coopmat_support; bool 
coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; @@ -638,6 +642,7 @@ struct vk_device_struct { vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32; vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT]; vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; @@ -664,6 +669,20 @@ struct vk_device_struct { vk_pipeline pipeline_hardsigmoid[2]; vk_pipeline pipeline_hardswish[2]; vk_pipeline pipeline_abs[2]; + vk_pipeline pipeline_softplus[2]; + vk_pipeline pipeline_step[2]; + vk_pipeline pipeline_round[2]; + vk_pipeline pipeline_ceil[2]; + vk_pipeline pipeline_floor[2]; + vk_pipeline pipeline_trunc[2]; + + vk_pipeline pipeline_add1_f16_f16; + vk_pipeline pipeline_add1_f16_f32; + vk_pipeline pipeline_add1_f32_f32; + + vk_pipeline pipeline_arange_f32; + + vk_pipeline pipeline_fill_f32; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -683,6 +702,7 @@ struct vk_device_struct { vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; + vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; @@ -1173,8 +1193,14 @@ struct vk_op_soft_max_push_constants { struct vk_op_argsort_push_constants { uint32_t ncols; + uint32_t ncols_padded; + uint32_t ncols_padded_log2; uint32_t nrows; - int32_t order; + uint32_t order; + uint32_t outer_start; + uint32_t outer_end; + uint32_t inner_start; + uint32_t inner_end; }; struct vk_op_im2col_push_constants { @@ -2901,15 +2927,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ } \ } \ } \ @@ -3697,6 +3723,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -3793,8 +3822,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + } ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -3820,6 +3855,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(hardsigmoid) CREATE_UNARY(hardswish) CREATE_UNARY(abs) + CREATE_UNARY(softplus) + CREATE_UNARY(step) + CREATE_UNARY(round) + CREATE_UNARY(ceil) + CREATE_UNARY(floor) + CREATE_UNARY(trunc) #undef CREATE_UNARY #define CREATE_UNARY_RTE(name) \ @@ -3833,6 +3874,14 @@ static 
void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY_RTE(exp) #undef CREATE_UNARY_RTE + ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + #define CREATE_GLU(name) \ if (device->float_controls_rte_fp16) { \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ @@ -3885,7 +3934,15 @@ static void ggml_vk_load_shaders(vk_device& device) { } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<max_workgroup_size_log2); + if (i <= device->max_workgroup_size_log2 && + 2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) { + const uint32_t NCOLS_PADDED_LOG2 = i; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true); + } + const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 
2 : 1; + BLOCK_SIZE /= WG_UNROLL_FACTOR; + ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); @@ -4286,6 +4343,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues @@ -4425,6 +4484,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_int64 = device_features2.features.shaderInt64; device->buffer_device_address = vk12_features.bufferDeviceAddress; + device->vulkan_memory_model = vk12_features.vulkanMemoryModel; if (device->subgroup_size_control) { device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; @@ -6241,6 +6301,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const // Choose "contiguous copy" shader if src/dst are contiguous bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); + // Use optimized "transpose" shader if src dim1 is the innermost dimension. + bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src); + + if (transpose && src->type == to) { + if (ggml_type_size(to) == 4) { + return ctx->device->pipeline_cpy_transpose_32; + } else if (ggml_type_size(to) == 2) { + return ctx->device->pipeline_cpy_transpose_16; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_f32; @@ -8236,6 +8307,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_ABS: return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SOFTPLUS: + return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_STEP: + return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_ROUND: + return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_CEIL: + return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_FLOOR: + return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_TRUNC: + return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16]; default: break; } @@ -8338,19 +8421,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; } - case GGML_OP_ARGSORT: - if (ctx->num_additional_fused_ops) { - uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); - GGML_ASSERT(idx < num_topk_moe_pipelines); - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); - return ctx->device->pipeline_topk_moe[idx][mode]; - } - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { - uint32_t idx = 
(uint32_t)ceilf(log2f(float(dst->ne[0]))); - return ctx->device->pipeline_argsort_f32[idx]; - } - return nullptr; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -8443,7 +8513,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_CONV_TRANSPOSE_2D: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - std::array elements; + std::array elements{}; if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); vk_conv_shapes shape; @@ -8521,6 +8591,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } } return nullptr; + case GGML_OP_ADD1: + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_add1_f16_f16; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_add1_f16_f32; + } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_add1_f32_f32; + } + return nullptr; + case GGML_OP_ARANGE: + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_arange_f32; + } + return nullptr; + case GGML_OP_FILL: + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_fill_f32; + } + return nullptr; default: return nullptr; } @@ -8742,8 +8833,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; case GGML_OP_ARGSORT: - elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; - elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + GGML_ASSERT(0); break; case GGML_OP_IM2COL: { @@ -8811,6 +8901,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SUB: case GGML_OP_DIV: case GGML_OP_MUL: + case GGML_OP_ADD1: + case GGML_OP_ARANGE: + case GGML_OP_FILL: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -8852,6 +8945,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } else { elements = { ne, 1, 1 }; } + + if (pipeline == ctx->device->pipeline_cpy_transpose_32 || + pipeline == ctx->device->pipeline_cpy_transpose_16) { + // 32x32 tiles + elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32); + elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32); + elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]); + elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]); + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } } break; case GGML_OP_ADD_ID: { @@ -9417,6 +9521,63 @@ static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, cons ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst)); } +static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); 
+ const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }); +} + +static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")"); + + vk_op_push_constants pc = { + (uint32_t)ggml_nelements(dst), + 1, + ggml_get_op_params_f32(dst, 0), + ggml_get_op_params_f32(dst, 2), + }; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false); + + std::array elements = { (uint32_t)ggml_nelements(dst), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements); +} + +static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")"); + + vk_op_push_constants pc = { + (uint32_t)ggml_nelements(dst), + 1, + ggml_get_op_params_f32(dst, 0), + 0.0f, + }; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false); + + std::array elements = { (uint32_t)ggml_nelements(dst), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements); +} + static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst)); } @@ -9859,16 +10020,89 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons } static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - int32_t * op_params = (int32_t *)dst->op_params; + const uint32_t * op_params = (const uint32_t *)dst->op_params; uint32_t ncols = src0->ne[0]; uint32_t nrows = ggml_nrows(src0); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, { - ncols, - nrows, - op_params[0], - }); + uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols))); + uint32_t ncolsp2 = 1 << ncols_pad_log2; + + vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, }; + + // Pick the largest workgroup size <= ncolsp2 + uint32_t pipeline_idx = std::min(ncols_pad_log2, 
num_argsort_pipelines - 1);
+
+    // Use the "small" argsort shader if the whole sort can be done by a single workgroup.
+    bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
+                     ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;
+
+    vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
+                                     : ctx->device->pipeline_argsort_large_f32[pipeline_idx];
+
+    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer subbuf1 = dst_buf;
+
+    // Reserve space for an ivec2 per element, with each row padded to a power-of-two column count
+    if (!use_small) {
+        const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
+
+        if (ctx->prealloc_size_x < x_sz) {
+            ctx->prealloc_size_x = x_sz;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_x_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+        subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
+    }
+
+    std::array<uint32_t, 3> elements;
+
+    elements[0] = ncolsp2;
+    elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+    elements[2] = 1;
+
+    // First dispatch initializes tmp_idx and does the first N passes where
+    // there is only communication between threads in the same workgroup.
+    {
+        vk_op_argsort_push_constants pc2 = pc;
+        pc2.outer_start = 0;
+        pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
+        pc2.inner_start = 0;
+        pc2.inner_end = 100;
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
+    }
+    if (!use_small) {
+        ggml_vk_sync_buffers(ctx, subctx);
+        // Loop over outer/inner passes, synchronizing between each pass.
+        for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
+            for (uint32_t inner = 0; inner < outer + 1; ++inner) {
+                vk_op_argsort_push_constants pc2 = pc;
+                pc2.outer_start = outer;
+                pc2.outer_end = outer + 1;
+                pc2.inner_start = inner;
+                pc2.inner_end = inner + 1;
+                // When the inner idx is large enough, there's only communication
+                // within a workgroup. So the remaining inner iterations can all
+                // run in the same dispatch.
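+                // Illustrative example (assuming maxComputeWorkGroupInvocations == 1024,
+                // i.e. max_workgroup_size_log2 == 10): for ncols_pad_log2 == 12 the first
+                // dispatch covers outer passes 0..9 within a single workgroup; outer passes
+                // 10 and 11 are then issued here one inner pass at a time until
+                // outer - inner drops below pipeline_idx, after which the remaining inner
+                // passes for that outer pass collapse into a single dispatch.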
+ if (outer - inner < pipeline_idx) { + pc2.inner_end = 100; + inner = outer; + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx]; + } else { + // Smaller workgroup empirically seems to perform better + pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2]; + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements); + ggml_vk_sync_buffers(ctx, subctx); + } + } + ctx->prealloc_x_need_sync = true; + } } static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -11176,6 +11410,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: break; default: return false; @@ -11217,6 +11457,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_ADD1: + case GGML_OP_ARANGE: + case GGML_OP_FILL: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SCALE: @@ -11429,6 +11672,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_UPSCALE: ggml_vk_upscale(ctx, compute_ctx, src0, node); + break; + case GGML_OP_ADD1: + ggml_vk_add1(ctx, compute_ctx, src0, src1, node); + + break; + case GGML_OP_ARANGE: + ggml_vk_arange(ctx, compute_ctx, node); + + break; + case GGML_OP_FILL: + ggml_vk_fill(ctx, compute_ctx, node); + break; case GGML_OP_SCALE: ggml_vk_scale(ctx, compute_ctx, src0, node); @@ -11513,6 +11768,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: ggml_vk_unary(ctx, compute_ctx, src0, node); break; default: @@ -11715,6 +11976,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_ADD1: + case GGML_OP_ARANGE: + case GGML_OP_FILL: case GGML_OP_ADD_ID: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: @@ -11786,6 +12050,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: buf = tensor->buffer; break; default: @@ -13388,6 +13658,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_TRUNC: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -13638,10 +13914,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t 
dev, const ggm } // We can handle copying from a type to the same type if it's - // contiguous (memcpy). We use f16 or f32 shaders to do the copy, + // either not quantized or is quantized and contiguous. + // We use f16 or f32 shaders to do the copy, // so the type/block size must be a multiple of 4. if (src0_type == src1_type && - ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) && + (!ggml_is_quantized(src0_type) || (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op))) && (ggml_type_size(src0_type) % 2) == 0) { return true; } @@ -13688,10 +13965,25 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_LOG: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_OP_ARGSORT: - return op->ne[0] <= max_argsort_cols; + { + if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { + return false; + } + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + // pipeline_argsort_large_f32 requires vulkan memory model. + if (device->vulkan_memory_model) { + return true; + } else { + return op->ne[0] <= (1 << device->max_workgroup_size_log2); + } + } case GGML_OP_UPSCALE: case GGML_OP_ACC: case GGML_OP_CONCAT: + case GGML_OP_ADD1: + case GGML_OP_ARANGE: + case GGML_OP_FILL: case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -14174,6 +14466,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_ADD1) { + tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ARANGE) { + const float start = ggml_get_op_params_f32(tensor, 0); + const float stop = ggml_get_op_params_f32(tensor, 1); + const float step = ggml_get_op_params_f32(tensor, 2); + tensor_clone = ggml_arange(ggml_ctx, start, stop, step); + } else if (tensor->op == GGML_OP_FILL) { + const float value = ggml_get_op_params_f32(tensor, 0); + tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SQRT) { @@ -14287,6 +14589,24 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_ABS: tensor_clone = ggml_abs(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_SOFTPLUS: + tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_STEP: + tensor_clone = ggml_step(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_ROUND: + tensor_clone = ggml_round(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_CEIL: + tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_FLOOR: + tensor_clone = ggml_floor(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_TRUNC: + tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp new file mode 100644 index 0000000000..db60725d4c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp @@ -0,0 +1,28 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include 
"types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset()])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp b/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp new file mode 100644 index 0000000000..f4936eeada --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // p.param1 = start, p.param2 = step + float value = p.param1 + p.param2 * float(i); + data_d[i] = D_TYPE(value); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index c4e68bc023..0fc2b9b725 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -4,28 +4,27 @@ #include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; -layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; +layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10; #define ASC 0 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) buffer D {int data_d[];}; +layout (binding = 2) writeonly buffer D {int data_d[];}; layout (push_constant) uniform parameter { uint ncols; + uint ncols_padded; + uint ncols_padded_log2; uint nrows; uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; } p; -shared int dst_row[BLOCK_SIZE]; -shared A_TYPE a_sh[BLOCK_SIZE]; - -void swap(uint idx0, uint idx1) { - int tmp = dst_row[idx0]; - dst_row[idx0] = dst_row[idx1]; - dst_row[idx1] = tmp; -} +shared ivec2 dst_row[BLOCK_SIZE]; void argsort(bool needs_bounds_check, const uint row) { // bitonic sort @@ -34,11 +33,10 @@ void argsort(bool needs_bounds_check, const uint row) { const uint row_offset = row * p.ncols; // initialize indices - dst_row[col] = col; - a_sh[col] = data_a[row_offset + col]; + dst_row[col] = ivec2(col, floatBitsToInt(data_a[row_offset + col])); barrier(); - uint num_outer_loop_iters = BLOCK_SIZE_LOG2; + uint num_outer_loop_iters = NCOLS_PADDED_LOG2; [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { uint num_inner_loop_iters = outer_idx + 1; [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { @@ -47,14 +45,15 @@ void argsort(bool needs_bounds_check, const uint row) { int idx_0 = (col & k) == 0 ? col : ixj; int idx_1 = (col & k) == 0 ? ixj : col; - int sh_idx_0 = dst_row[idx_0]; - int sh_idx_1 = dst_row[idx_1]; - bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false; - bool idx_1_oob = needs_bounds_check ? 
sh_idx_1 >= p.ncols : false; + ivec2 sh_idx_0 = dst_row[idx_0]; + ivec2 sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false; if ((idx_0_oob || - (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) { - swap(idx_0, idx_1); + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y))) && (ixj > col)) { + dst_row[idx_0] = sh_idx_1; + dst_row[idx_1] = sh_idx_0; } barrier(); @@ -63,9 +62,9 @@ void argsort(bool needs_bounds_check, const uint row) { if (col < p.ncols) { if (p.order == ASC) { - data_d[row_offset + col] = dst_row[col]; + data_d[row_offset + col] = dst_row[col].x; } else { - data_d[row_offset + p.ncols - col - 1] = dst_row[col]; + data_d[row_offset + p.ncols - col - 1] = dst_row[col].x; } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp new file mode 100644 index 0000000000..920bac6bb8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp @@ -0,0 +1,114 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_memory_scope_semantics : enable +#pragma use_vulkan_memory_model + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2; +#define ASC 0 + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];}; +layout (binding = 2) workgroupcoherent buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint ncols_padded; + uint ncols_padded_log2; + uint nrows; + uint order; + uint outer_start; + uint outer_end; + uint inner_start; + uint inner_end; +} p; + +void argsort(bool needs_bounds_check, const uint row) { + // bitonic sort + int col = int(gl_GlobalInvocationID.x); + col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR; + + const uint row_offset = row * p.ncols; + uint idx_offset = row * p.ncols_padded; + + bool need_barrier = false; + + // initialize indices + if (p.outer_start == 0 && p.inner_start == 0) { + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols_padded) { + ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c])); + tmp_idx[idx_offset + c] = v; + } + } + need_barrier = true; + } + + [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) { + uint inner_end = min(p.inner_end, outer_idx + 1); + for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) { + if (need_barrier) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + } + need_barrier = true; + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + int c = u*BLOCK_SIZE + col; + const int ixj = int(c ^ j); + + if (ixj < c) { + continue; + } + + int idx_0 = (c & k) == 0 ? c : ixj; + int idx_1 = (c & k) == 0 ? ixj : c; + + ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0]; + ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? 
sh_idx_1.x >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) { + tmp_idx[idx_offset + idx_0] = sh_idx_1; + tmp_idx[idx_offset + idx_1] = sh_idx_0; + } + } + } + } + + if (p.outer_end == p.ncols_padded_log2 && + p.inner_end >= p.ncols_padded_log2 + 1) { + controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease); + [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) { + uint c = u*BLOCK_SIZE + col; + if (c < p.ncols) { + if (p.order == ASC) { + data_d[row_offset + c] = tmp_idx[idx_offset + c].x; + } else { + data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x; + } + } + } + } +} + +void main() { + if (p.ncols == p.ncols_padded) { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(false, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } else { + uint row = gl_WorkGroupID.y; + while (row < p.nrows) { + argsort(true, row); + row += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp new file mode 100644 index 0000000000..0028d3721d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(ceil(x)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp new file mode 100644 index 0000000000..220ccc9111 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp @@ -0,0 +1,67 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +// workgroup does 32x32 tile, but uses 32x8 threads +#define TILE_DIM 32 +layout(local_size_x = 32, local_size_y = 8, local_size_z = 1) in; + +shared uint sh[TILE_DIM][TILE_DIM + 1]; + +void iter(uvec3 wg_id) { + const uint tile_col = wg_id.x; + const uint tile_row = wg_id.y; + + const uint tid_col = gl_LocalInvocationID.x; + const uint tid_row = gl_LocalInvocationID.y; + + const uint i2 = wg_id.z % p.ne12; + const uint i3 = wg_id.z / p.ne12; + const uint i02 = i2; + const uint i03 = i3; + + // The workgroup does TILE_DIM x TILE_DIM, but swaps the LSBs of the + // src coords to make memory accesses contiguous, dst has tid.x in i0, + // src has tid.x in i01 + + [[unroll]] for (uint y = 0; y < 4; ++y) { + const uint i00 = tile_col * TILE_DIM + tid_row + 8 * y; + const uint i01 = tile_row * TILE_DIM + tid_col; + if (i00 < p.ne00 && i01 < p.ne01 && i02 < p.ne02 && i03 < p.ne03) { + const uint src_idx = i00 * p.nb00 + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + sh[tid_row + 8 * y][tid_col] = uint(data_a[get_aoffset() + src_idx]); + } + } + + barrier(); + + [[unroll]] for (uint y = 0; y < 4; ++y) { + const uint i0 = tile_col * TILE_DIM + tid_col; + const uint i1 = tile_row * TILE_DIM + tid_row + 8 * y; + if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) { + const uint dst_idx = i0 * p.nb10 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; 
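+            // The shared-memory indices below are swapped relative to the store above,
+            // which is what performs the transpose; the TILE_DIM + 1 padding on sh is
+            // the usual trick to avoid shared-memory bank conflicts on these reads.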
+ // load transposed + data_d[get_doffset() + dst_idx] = D_TYPE(sh[tid_col][tid_row + 8 * y]); + } + } +} + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +void main() { + uint z = gl_WorkGroupID.z; + uint y = gl_WorkGroupID.y; + bool need_barrier = false; + for (uint z = gl_WorkGroupID.z; z < p.ne12 * p.ne13; z += gl_NumWorkGroups.z) { + for (uint y = gl_WorkGroupID.y; y < CEIL_DIV(p.ne11, TILE_DIM); y += gl_NumWorkGroups.y) { + for (uint x = gl_WorkGroupID.x; x < CEIL_DIV(p.ne10, TILE_DIM); x += gl_NumWorkGroups.x) { + if (need_barrier) { + barrier(); + } + need_barrier = true; + iter(uvec3(x, y, z)); + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp new file mode 100644 index 0000000000..a56be76c61 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp @@ -0,0 +1,19 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // p.param1 = fill value + data_d[i] = D_TYPE(p.param1); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp b/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp new file mode 100644 index 0000000000..20017eb184 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(floor(x)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp index 4aef724e90..ff2812d3d7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp @@ -1,5 +1,6 @@ #version 450 +#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" @@ -12,6 +13,6 @@ void main() { return; } - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + const float val = float(data_a[get_aoffset() + src0_idx(idx)]); data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val)); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/round.comp b/ggml/src/ggml-vulkan/vulkan-shaders/round.comp new file mode 100644 index 0000000000..e6155dcbf3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/round.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + float result; + // Round halfway cases away from zero as roundf does. 
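+    // e.g. 2.5 -> 3.0 and -2.5 -> -3.0, whereas GLSL round()/roundEven() may round
+    // halfway cases to the nearest even value.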
+ if (x >= 0.0) { + result = floor(x + 0.5); + } else { + result = ceil(x - 0.5); + } + data_d[i] = D_TYPE(result); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp b/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp new file mode 100644 index 0000000000..323e3cdea4 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + const float result = (x > 20.0f) ? x : log(1.0f + exp(x)); + data_d[i] = D_TYPE(result); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/step.comp b/ggml/src/ggml-vulkan/vulkan-shaders/step.comp new file mode 100644 index 0000000000..654a2124e0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/step.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x >= 0.0f ? 1.0f : 0.0f); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp new file mode 100644 index 0000000000..cf1b76d3bb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(trunc(x)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8a2ce321dc..bc992068f8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -734,6 +734,9 @@ void process_shaders() { string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); + string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}); + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("cpy_f32_" + t + 
"_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); @@ -802,9 +805,6 @@ void process_shaders() { string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -819,6 +819,9 @@ void process_shaders() { std::string suffix = rte ? "_rte" : ""; string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}}); + + string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); } string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -843,6 +846,25 @@ void process_shaders() { string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("add1_f16_f16", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("add1_f16_f32", "add1.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("round_f32", "round.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("ceil_f16", "ceil.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("ceil_f32", "ceil.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("floor_f16", "floor.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("floor_f32", "floor.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + for (auto rte : {false, true}) { std::string suffix = rte ? 
"_rte" : ""; string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); @@ -889,6 +911,7 @@ void process_shaders() { string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); diff --git a/models/templates/GLM-4.6.jinja b/models/templates/GLM-4.6.jinja new file mode 100644 index 0000000000..4a09db53c8 --- /dev/null +++ b/models/templates/GLM-4.6.jinja @@ -0,0 +1,106 @@ +[gMASK] +{%- if tools -%} +<|system|> +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{% for tool in tools %} +{{ tool | tojson(ensure_ascii=False) }} +{% endfor %} + + +For each function call, output the function name and arguments within the following XML format: +{function-name} +{arg-key-1} +{arg-value-1} +{arg-key-2} +{arg-value-2} +... +{%- endif -%} +{%- macro visible_text(content) -%} + {%- if content is string -%} + {{- content }} + {%- elif content is iterable and content is not mapping -%} + {%- for item in content -%} + {%- if item is mapping and item.type == 'text' -%} + {{- item.text }} + {%- elif item is string -%} + {{- item }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- content }} + {%- endif -%} +{%- endmacro -%} +{%- set ns = namespace(last_user_index=-1) %} +{%- for m in messages %} + {%- if m.role == 'user' %} + {% set ns.last_user_index = loop.index0 -%} + {%- endif %} +{%- endfor %} +{% for m in messages %} +{%- if m.role == 'user' -%}<|user|> +{{ visible_text(m.content) }} +{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}} +{%- elif m.role == 'assistant' -%} +<|assistant|> +{%- set reasoning_content = '' %} +{%- set content = visible_text(m.content) %} +{%- if m.reasoning_content is string %} + {%- set reasoning_content = m.reasoning_content %} +{%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} +{%- endif %} +{%- if loop.index0 > ns.last_user_index and reasoning_content -%} +{{ '\n' + reasoning_content.strip() + ''}} +{%- else -%} +{{ '\n' }} +{%- endif -%} +{%- if content.strip() -%} +{{ '\n' + content.strip() }} +{%- endif -%} +{% if m.tool_calls %} +{% for tc in m.tool_calls %} +{%- if tc.function %} + {%- set tc = tc.function %} +{%- endif %} +{{ '\n' + tc.name }} +{% set _args = tc.arguments or {} %} +{% if _args is not mapping %} + {{ raise_exception("Invalid tool call arguments passed: " + _args | string) }} +{% endif %} +{% for k, v in _args.items() %} +{{ k }} +{{ v | tojson(ensure_ascii=False) if v is not string else v }} +{% endfor %} +{% endfor %} +{% endif %} +{%- elif m.role == 'tool' -%} +{%- if m.content is string -%} +{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|observation|>' }} +{%- endif %} +{{- '\n\n' }} +{{- m.content }} +{{- '\n' }} +{%- else -%} +<|observation|>{% for tr in m.content %} + + +{{ 
tr.output if tr.output is defined else tr }} +{% endfor -%} +{% endif -%} +{%- elif m.role == 'system' -%} +<|system|> +{{ visible_text(m.content) }} +{%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + <|assistant|>{{- '\n' if (enable_thinking is defined and not enable_thinking) else '' -}} +{%- endif -%} diff --git a/models/templates/Kimi-K2-Instruct.jinja b/models/templates/Kimi-K2-Instruct.jinja new file mode 100644 index 0000000000..a9439135ba --- /dev/null +++ b/models/templates/Kimi-K2-Instruct.jinja @@ -0,0 +1,64 @@ +{% macro render_content(msg) -%} + {%- set c = msg.get('content') -%} + {%- if c is string -%} + {{ c }} + {%- elif c is not none -%} + {% for content in c -%} + {% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%} + <|media_start|>image<|media_content|><|media_pad|><|media_end|> + {% else -%} + {{ content['text'] }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} +{%- endmacro %} + +{%- set tool_response_queue = namespace(ids=[]) -%} +{%- set tool_call_counter = namespace(value=1) -%} + +{%- if tools -%} + <|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|> +{%- endif -%} +{% for message in messages %} + {%- if loop.first and messages[0]['role'] != 'system' -%} + <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|> + {% endif %} + + {%- set role_name = message.get('name') or message['role'] -%} + {%- if message['role'] == 'user' -%} + <|im_user|>{{role_name}}<|im_middle|> + {%- elif message['role'] == 'assistant' -%} + <|im_assistant|>{{role_name}}<|im_middle|> + {%- else -%} + <|im_system|>{{role_name}}<|im_middle|> + {%- endif -%} + + {%- if message['role'] == 'assistant' and message.get('tool_calls') -%} + {{render_content(message)}}<|tool_calls_section_begin|> + {%- for tool_call in message['tool_calls'] -%} + {%- if tool_call['id'] is defined -%} + {%- set formatted_id = tool_call['id'] -%} + {%- else -%} + {%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%} + {%- set tool_call_counter.value = tool_call_counter.value + 1 -%} + {%- endif -%} + {%- set _ = tool_response_queue.ids.append(formatted_id) -%} + <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|> + {%- endfor -%} + <|tool_calls_section_end|> + {%- elif message['role'] == 'tool' -%} + {%- if tool_response_queue.ids -%} + {%- set tool_call_id = tool_response_queue.ids.pop(0) -%} + {%- else -%} + {%- set tool_call_id = 'functions.' 
+ message.get('name', 'unknown') + ':' + (tool_call_counter.value | string) -%} + {%- endif -%} + ## Return of {{ tool_call_id }} +{{render_content(message)}} + {%- elif message['content'] is not none -%} + {{render_content(message)}} + {%- endif -%} + <|im_end|> +{%- endfor -%} +{%- if add_generation_prompt -%} + <|im_assistant|>assistant<|im_middle|> +{%- endif -%} diff --git a/models/templates/Kimi-K2-Thinking.jinja b/models/templates/Kimi-K2-Thinking.jinja new file mode 100644 index 0000000000..4c2af6a783 --- /dev/null +++ b/models/templates/Kimi-K2-Thinking.jinja @@ -0,0 +1,112 @@ +{%- macro render_content(msg) -%} + {%- set c = msg.get('content') -%} + {%- if c is string -%} + {{ c }} + {%- elif c is not none -%} + {% for content in c -%} + {% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%} + <|media_start|>image<|media_content|><|media_pad|><|media_end|> + {% else -%} + {{ content['text'] }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} +{%- endmacro -%} + +{% macro set_roles(message) -%} + {%- set role_name = message.get('name') or message['role'] -%} + {%- if message['role'] == 'user' -%} + <|im_user|>{{role_name}}<|im_middle|> + {%- elif message['role'] == 'assistant' -%} + <|im_assistant|>{{role_name}}<|im_middle|> + {%- else -%} + <|im_system|>{{role_name}}<|im_middle|> + {%- endif -%} +{%- endmacro -%} + +{%- set tool_response_queue = namespace(ids=[]) -%} +{%- set tool_call_counter = namespace(value=1) -%} + +{%- macro render_toolcalls(message) -%} + <|tool_calls_section_begin|> + {%- for tool_call in message['tool_calls'] -%} + {%- if tool_call['id'] is defined -%} + {%- set formatted_id = tool_call['id'] -%} + {%- else -%} + {%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%} + {%- set tool_call_counter.value = tool_call_counter.value + 1 -%} + {%- endif -%} + {%- set _ = tool_response_queue.ids.append(formatted_id) -%} + <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|> + {%- endfor -%} + <|tool_calls_section_end|> +{%- endmacro -%} + + +{# Find last non-tool-call assisitant message #} +{%- set ns = namespace(last_non_tool_call_assistant_msg=-1) -%} +{%- for idx in range(messages|length-1, -1, -1) -%} + {%- if messages[idx]['role'] == 'assistant' and not messages[idx].get('tool_calls') -%} + {%- set ns.last_non_tool_call_assistant_msg = idx -%} + {%- endif -%} +{%- endfor -%} + +{# split all messages into history & suffix, reasoning_content in suffix should be reserved.#} +{%- set hist_msgs = messages[:ns.last_non_tool_call_assistant_msg+1] -%} +{%- set suffix_msgs = messages[ns.last_non_tool_call_assistant_msg+1:] -%} + +{%- if tools -%} + <|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|> +{%- endif -%} + +{%- if messages|length == 0 or messages[0]['role'] != 'system' -%} + <|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|> +{%- endif -%} + +{%- for message in hist_msgs -%} + {{set_roles(message)}} + {%- if message['role'] == 'assistant' -%} + {{render_content(message)}} + {%- if message.get('tool_calls') -%} + {{render_toolcalls(message)}} + {%- endif -%} + {%- elif message['role'] == 'tool' -%} + {%- if tool_response_queue.ids -%} + {%- set tool_call_id = tool_response_queue.ids.pop(0) -%} + {%- else -%} + {%- set 
tool_call_id = 'functions.' + message.get('name', 'unknown') + ':' + (tool_call_counter.value | string) -%} + {%- endif -%} + ## Return of {{ tool_call_id }} +{{render_content(message)}} + {%- elif message['content'] is not none -%} + {{render_content(message)}} + {%- endif -%} + <|im_end|> +{%- endfor -%} + +{%- for message in suffix_msgs -%} + {{set_roles(message)}} + {%- if message['role'] == 'assistant' -%} + {%- set rc = message.get('reasoning_content', '') -%} + {{rc}}{{render_content(message)}} + {%- if message.get('tool_calls') -%} + {{render_toolcalls(message)}} + {%- endif -%} + {%- elif message['role'] == 'tool' -%} + {%- if tool_response_queue.ids -%} + {%- set tool_call_id = tool_response_queue.ids.pop(0) -%} + {%- else -%} + {%- set tool_call_id = 'functions.' + message.get('name', 'unknown') + ':' + (tool_call_counter.value | string) -%} + {%- endif -%} + ## Return of {{ tool_call_id }} +{{render_content(message)}} + {%- elif message['content'] is not none -%} + {{render_content(message)}} + {%- endif -%} + <|im_end|> +{%- endfor -%} + + +{%- if add_generation_prompt -%} + <|im_assistant|>assistant<|im_middle|> +{%- endif -%} diff --git a/models/templates/MiMo-VL.jinja b/models/templates/MiMo-VL.jinja new file mode 100644 index 0000000000..9c1b1696a4 --- /dev/null +++ b/models/templates/MiMo-VL.jinja @@ -0,0 +1,54 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- 'You are MiMo, an AI assistant developed by Xiaomi.' }} + {%- endif %} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nYou are MiMo, an AI assistant developed by Xiaomi.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/models/templates/MiniMax-M2.jinja b/models/templates/MiniMax-M2.jinja new file mode 100644 index 0000000000..9302ccedb2 --- /dev/null +++ 
b/models/templates/MiniMax-M2.jinja @@ -0,0 +1,159 @@ +{# ----------‑‑‑ special token variables ‑‑‑---------- #} +{%- set toolcall_begin_token = '' -%} +{%- set toolcall_end_token = '' -%} +{#- Tool Rendering Functions ============================================== -#} +{%- macro render_tool_namespace(namespace_name, tool_list) -%} +{%- for tool in tool_list -%} +{{ tool.function | tojson(ensure_ascii=False) }} +{% endfor -%} +{%- endmacro -%} +{%- macro visible_text(content) -%} + {%- if content is string -%} + {{ content }} + {%- elif content is iterable and content is not mapping -%} + {%- for item in content -%} + {%- if item is mapping and item.type == 'text' -%} + {{- item.text }} + {%- elif item is string -%} + {{- item }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- content }} + {%- endif -%} +{%- endmacro -%} +{#- System Message Construction ============================================ -#} +{%- macro build_system_message(system_message) -%} + {%- if system_message and system_message.content -%} + {{- visible_text(system_message.content) }} + {%- else -%} + {%- if model_identity is not defined -%} + {%- set model_identity = "You are a helpful assistant." -%} + {%- endif -%} + {{- model_identity }} + {%- endif -%} + + {#- Handle current_date -#} + {%- if system_message and system_message.current_date -%} + {{- '\n' ~ 'Current date: ' + system_message.current_date }} + {%- endif -%} + {#- Handle current_location -#} + {%- if system_message and system_message.current_location -%} + {{- '\n' ~ 'Current location: ' + system_message.current_location }} + {%- endif -%} +{%- endmacro -%} +{#- Main Template Logic ================================================= -#} +{#- Extract system message (only first message if it's system) -#} +{%- set system_message = none -%} +{%- set conversation_messages = messages -%} +{%- if messages and messages[0].role == "system" -%} + {%- set system_message = messages[0] -%} + {%- set conversation_messages = messages[1:] -%} +{%- endif -%} +{#- Get the last user message turn, for interleved thinking -#} +{%- set ns = namespace(last_user_index=-1) %} +{% for m in conversation_messages %} + {%- if m.role == 'user' %} + {% set ns.last_user_index = loop.index0 -%} + {%- endif %} +{%- endfor %} +{#- Render system message -#} +{{- ']~!b[' ~ ']~b]system' ~ '\n' }} +{{- build_system_message(system_message) }} +{#- Render tools if available -#} +{%- if tools -%} + {{- '\n\n' ~ '# Tools' ~ '\n' ~ 'You may call one or more tools to assist with the user query.\nHere are the tools available in JSONSchema format:' ~ '\n' }} + {{- '\n' ~ '' ~ '\n' }} + {{- render_tool_namespace("functions", tools) }} + {{- '' ~ '\n\n' }} +{{- 'When making tool calls, use XML format to invoke tools and pass parameters:' ~ '\n' }} +{{- '\n' ~ toolcall_begin_token }} + +param-value-1 +param-value-2 +... 
+ +{{- '\n' ~ toolcall_end_token }} +{%- endif -%} +{{- '[e~[\n' }} + +{#- Render messages -#} +{%- set last_tool_call = namespace(name=none) -%} +{%- for message in conversation_messages -%} + {%- if message.role == 'assistant' -%} + {#- Only render reasoning_content if no user message follows -#} + {{- ']~b]ai' ~ '\n' }} + + {%- set reasoning_content = '' %} + {%- set content = visible_text(message.content) %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].strip('\n').split('')[-1].strip('\n') %} + {%- set content = content.split('')[-1].strip('\n') %} + {%- endif %} + {%- endif %} + {%- if reasoning_content and loop.index0 > ns.last_user_index -%} + {{- '' ~ '\n' ~ reasoning_content ~ '\n' ~ '' ~ '\n\n' }} + {%- endif -%} + {%- if content -%} + {{- content }} + {%- endif -%} + {%- if message.tool_calls -%} + {{- '\n' ~ toolcall_begin_token ~ '\n' }} + + {%- for tool_call in message.tool_calls -%} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '' }} + {% set _args = tool_call.arguments %} + {%- for k, v in _args.items() %} + {{- '' }} + {{- v | tojson(ensure_ascii=False) if v is not string else v }} + {{- '' }} + {% endfor %} + {{- '' ~ '\n' }} + {%- endfor -%} + + {{- toolcall_end_token}} + {%- set last_tool_call.name = message.tool_calls[-1].function.name -%} + {%- else -%} + {%- set last_tool_call.name = none -%} + {%- endif -%} + {{- '[e~[' ~ '\n' }} + + {%- elif message.role == 'tool' -%} + {%- if last_tool_call.name is none -%} + {{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }} + {%- endif -%} + {%- if loop.first or (conversation_messages[loop.index0 - 1].role != 'tool') -%} + {{- ']~b]tool' }} + {%- endif -%} + {%- if message.content is string -%} + {{- '\n' }} + {{- message.content }} + {{- '' }} + {%- else -%} + {%- for tr in message.content -%} + {{- '\n' }} + {{- tr.output if tr.output is defined else (tr.text if tr.type == 'text' and tr.text is defined else tr) }} + {{- '\n' }} + {%- endfor -%} + {%- endif -%} + {%- if loop.last or (conversation_messages[loop.index0 + 1].role != 'tool') -%} + {{- '[e~[\n' -}} + {%- endif -%} + + {%- elif message.role == 'user' -%} + {{- ']~b]user' ~ '\n' }} + {{- visible_text(message.content) }} + {{- '[e~[' ~ '\n' }} + {%- endif -%} +{%- endfor -%} + +{#- Generation prompt -#} +{%- if add_generation_prompt -%} +{{- ']~b]ai' ~ '\n' ~ '' ~ '\n' }} +{%- endif -%} diff --git a/models/templates/Qwen3-Coder.jinja b/models/templates/Qwen3-Coder.jinja new file mode 100644 index 0000000000..49b0e8d0ee --- /dev/null +++ b/models/templates/Qwen3-Coder.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} 
+ {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" 
%} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {{- '\n\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/models/templates/unsloth-Apriel-1.5.jinja b/models/templates/unsloth-Apriel-1.5.jinja new file mode 100644 index 0000000000..29e582fbf6 --- /dev/null +++ b/models/templates/unsloth-Apriel-1.5.jinja @@ -0,0 +1,126 @@ +{# Unsloth template fixes #} +{%- set available_tools_string = '' -%} +{%- set thought_instructions = '' -%} +{%- set add_tool_id = true -%} +{%- set tool_output_format = "default" -%} +{%- if tools is not none and tools|length > 0 -%} + {%- set available_tools_string -%} +You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about the arguments. You should infer the argument values from previous user responses and the system message. Here are the available tools: + +{% for tool in tools %} +{{ tool|string }} +{% endfor %} + +{%- endset -%} +{%- endif -%} +{%- if tool_output_format is none or tool_output_format == "default" -%} +{%- set tool_output_instructions -%} +Return all function calls as a list of json objects within XML tags. Each json object should contain a function name and arguments as follows: +[{"name": , "arguments": }, {"name": , "arguments": },...] +{%- endset -%} +{%- elif tool_output_format == "yaml" -%} +{%- set tool_output_instructions -%} +Return all function calls as a list of yaml objects within XML tags. Each yaml object should contain a function name and arguments as follows: + +- name: + arguments: +- name: + arguments: +... + +{%- endset -%} +{%- endif -%} +{%- if add_thoughts -%} +{%- set thought_instructions -%} +Prior to generating the function calls, you should generate the reasoning for why you're calling the function. Please generate these reasoning thoughts between and XML tags. +{%- endset -%} +{%- endif -%} +{{- bos_token -}} +{%- set reasoning_prompt='You are a thoughtful and systematic AI assistant built by ServiceNow Language Models (SLAM) lab. Before providing an answer, analyze the problem carefully and present your reasoning step by step. After explaining your thought process, provide the final solution in the following format: [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE].' 
-%} +{%- if messages[0]['role'] != 'system' and tools is not none and tools|length > 0 -%} + {{- '<|system|>\n' + reasoning_prompt + available_tools_string + "\n" + tool_output_instructions + '\n<|end|>\n' -}} +{%- endif -%} +{%- if messages|selectattr('role', 'equalto', 'system')|list|length == 0 -%} +{{- '<|system|>\n' + reasoning_prompt + '\n<|end|>\n' -}} +{%- endif -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {{- '<|user|>\n' }} + {%- if message['content'] is not string %} + {%- for chunk in message['content'] %} + {%- if chunk['type'] == 'text' %} + {{- chunk['text'] }} + {%- elif chunk['type'] == 'image' or chunk['type'] == 'image_url'%} + {{- '[IMG]' }} + {%- else %} + {{- raise_exception('Unrecognized content type!') }} + {%- endif %} + {%- endfor %} + {%- else %} + {{- message['content'] }} + {%- endif %} + {{- '\n<|end|>\n' }} + {%- elif message['role'] == 'content' -%} + {%- if message['content'] is not string %} + {{- '<|content|>\n' + message['content'][0]['text'] + '\n<|end|>\n' -}} + {%- else %} + {{- '<|content|>\n' + message['content'] + '\n<|end|>\n' -}} + {%- endif -%} + {%- elif message['role'] == 'system' -%} + {%- if message['content'] is not none and message['content']|length > 0 %} + {%- if message['content'] is string %} + {%- set system_message = message['content'] %} + {%- else %} + {%- set system_message = message['content'][0]['text'] %} + {%- endif %} + {%- else %} + {%- set system_message = '' %} + {%- endif %} + {%- if tools is not none and tools|length > 0 -%} + {{- '<|system|>\n' + reasoning_prompt + system_message + '\n' + available_tools_string + '\n<|end|>\n' -}} + {%- else -%} + {{- '<|system|>\n' + reasoning_prompt + system_message + '\n<|end|>\n' -}} + {%- endif -%} + {%- elif message['role'] == 'assistant' -%} + {%- if loop.last -%} + {%- set add_tool_id = false -%} + {%- endif -%} + {{- '<|assistant|>\n' -}} + {%- if message['content'] is not none and message['content']|length > 0 -%} + {%- if message['content'] is not string and message['content'][0]['text'] is not none %} + {{- message['content'][0]['text'] }} + {%- else %} + {{- message['content'] -}} + {%- endif -%} + {%- elif message['chosen'] is not none and message['chosen']|length > 0 -%} + {{- message['chosen'][0] -}} + {%- endif -%} + {%- if add_thoughts and 'thought' in message and message['thought'] is not none -%} + {{- '' + message['thought'] + '' -}} + {%- endif -%} + {%- if message['tool_calls'] is not none and message['tool_calls']|length > 0 -%} + {{- '\n[' -}} + {%- for tool_call in message["tool_calls"] -%} + {{- '{"name": "' + tool_call['function']['name'] + '", "arguments": ' + tool_call['function']['arguments']|string -}} + {%- if add_tool_id == true -%} + {{- ', "id": "' + tool_call['id'] + '"' -}} + {%- endif -%} + {{- '}' -}} + {%- if not loop.last -%}{{- ', ' -}}{%- endif -%} + {%- endfor -%} + {{- ']' -}} + {%- endif -%} + {{- '\n<|end|>\n' + eos_token -}} + {%- elif message['role'] == 'tool' -%} + {%- if message['content'] is string %} + {%- set tool_message = message['content'] %} + {%- else %} + {%- set tool_message = message['content'][0]['text'] %} + {%- endif -%} + {{- '<|tool_result|>\n' + tool_message|string + '\n<|end|>\n' -}} + {%- endif -%} + {%- if loop.last and add_generation_prompt and message['role'] != 'assistant' -%} + {{- '<|assistant|>\n' -}} + {%- endif -%} +{%- endfor -%} +{# Copyright 2025-present Unsloth. Apache 2.0 License. 
#} diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 29e31cecd1..a73c4c448b 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1281,6 +1281,7 @@ struct llm_tokenizer_plamo2 : llm_tokenizer { // Build suffix list in lexicographical order of reversed strings std::vector suffixes; + suffixes.reserve(suffix_to_score.size() + 1); for (const auto & pair : suffix_to_score) { suffixes.push_back(pair.first); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index fab8465934..8f96250ddc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2776,24 +2776,34 @@ struct test_cpy : public test_case { struct test_cont : public test_case { const ggml_type type; const std::array ne; + bool use_view_slice; std::string vars() override { - return VARS_TO_STR2(type, ne); + return VARS_TO_STR3(type, ne, use_view_slice); } test_cont(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 10, 10, 1}) - : type(type), ne(ne) {} + std::array ne = {10, 10, 10, 1}, + bool use_view_slice = false) + : type(type), ne(ne), use_view_slice(use_view_slice) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(src); ggml_set_name(src, "src"); - src = ggml_transpose(ctx, src); - ggml_set_name(src, "src_transposed"); - ggml_tensor * out = ggml_cont(ctx, src); + ggml_tensor * dst; + if (use_view_slice) { + dst = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3], + src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1)); + ggml_set_name(dst, "src_view_slice"); + } else { + dst = ggml_transpose(ctx, src); + ggml_set_name(dst, "src_transposed"); + } + + ggml_tensor * out = ggml_cont(ctx, dst); ggml_set_name(out, "out"); return out; @@ -6945,16 +6955,17 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true)); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0})); - test_cases.emplace_back(new test_cont()); - test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 3 ,5})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 3, 5 ,7})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 1 ,1})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 3 ,5})); - test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 3, 5 ,7})); - test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 1 ,1})); - test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5})); - test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7})); + for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) { + for (bool use_view_slice : { true, false }) { + for (std::array ne : std::initializer_list>{ {2, 1, 1, 1}, {2, 1, 3, 5}, + {2, 3, 5, 7}, {1, 4, 4, 1}, {1, 8, 17, 1}, {10, 10, 10, 1} }) { + if (use_view_slice && (type_dst == GGML_TYPE_F16 || type_dst == GGML_TYPE_BF16)) { + continue; // TODO: add after WebGPU is fixed + } + test_cases.emplace_back(new test_cont(type_dst, ne, use_view_slice)); + } + } + } auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) { for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { @@ -7015,6 +7026,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, 
{1, 1, 1, 1}, 16)); test_cases.emplace_back(new test_add1()); + test_cases.emplace_back(new test_add1(GGML_TYPE_F32, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_scale()); test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f)); test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test @@ -7354,9 +7366,13 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3})); test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3})); test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 })); test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 })); test_cases.emplace_back(new test_round (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 })); test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 })); } test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); @@ -7501,13 +7517,15 @@ static std::vector> make_test_cases_eval() { } for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) { - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order)); + for (uint32_t i = 4; i <= 1024*1024; i *= 2) { + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i-1, 1, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {i, 1, 1, 1})); + } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1023, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1025, 2, 1, 3}, order)); - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // many backends only handle up to 1024 test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2047, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2048, 2, 1, 3}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2049, 2, 1, 3}, order)); @@ -7556,6 +7574,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1})); test_cases.emplace_back(new test_roll()); test_cases.emplace_back(new test_arange()); + test_cases.emplace_back(new test_arange(GGML_TYPE_F32, 0.0f, 1048576.0f, 1.0f)); test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); @@ -7583,6 +7602,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_fill(0.0f)); test_cases.emplace_back(new test_fill(2.0f, GGML_TYPE_F32, { 303, 207, 11, 3 })); test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 })); + test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 })); test_cases.emplace_back(new test_solve_tri()); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 })); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 4a8ba849b3..62dd1583fa 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -75,6 +75,8 @@ static common_chat_msg normalize(const 
common_chat_msg & msg) { } return normalized; } + + template <> bool equals(const common_chat_msg & expected, const common_chat_msg & actual) { return normalize(expected) == normalize(actual); @@ -148,17 +150,29 @@ static std::string renormalize_json(const std::string & json_str) { return json_str; } } -static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) { +static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual, bool ignore_whitespace_differences = false) { assert_equals(expected.role, actual.role); - assert_equals(expected.content, actual.content); + if (ignore_whitespace_differences) { + assert_equals(string_strip(expected.content), string_strip(actual.content)); + } else { + assert_equals(expected.content, actual.content); + } assert_equals(expected.content_parts.size(), actual.content_parts.size()); for (size_t i = 0; i < expected.content_parts.size(); i++) { const auto & expected_part = expected.content_parts[i]; const auto & actual_part = actual.content_parts[i]; assert_equals(expected_part.type, actual_part.type); - assert_equals(expected_part.text, actual_part.text); + if (ignore_whitespace_differences) { + assert_equals(string_strip(expected_part.text), string_strip(actual_part.text)); + } else { + assert_equals(expected_part.text, actual_part.text); + } + } + if (ignore_whitespace_differences) { + assert_equals(string_strip(expected.reasoning_content), string_strip(actual.reasoning_content)); + } else { + assert_equals(expected.reasoning_content, actual.reasoning_content); } - assert_equals(expected.reasoning_content, actual.reasoning_content); assert_equals(expected.tool_calls.size(), actual.tool_calls.size()); for (size_t i = 0; i < expected.tool_calls.size(); i++) { const auto & expected_tool_call = expected.tool_calls[i]; @@ -183,6 +197,24 @@ common_chat_tool special_function_tool { "required": ["arg1"] })", }; +common_chat_tool special_function_tool_with_optional_param { + /* .name = */ "special_function_with_opt", + /* .description = */ "I'm special but have optional stuff", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "arg1": { + "type": "integer", + "description": "The arg." + }, + "arg2": { + "type": "integer", + "description": "The optional arg." 
+ } + }, + "required": ["arg1"] + })", +}; common_chat_tool python_tool { /* .name = */ "python", /* .description = */ "an ipython interpreter", @@ -211,7 +243,7 @@ common_chat_tool code_interpreter_tool { "required": ["code"] })", }; -std::vector tools { special_function_tool, python_tool }; +std::vector tools { special_function_tool, special_function_tool_with_optional_param, python_tool }; std::vector llama_3_1_tools { special_function_tool, code_interpreter_tool }; struct delta_data { @@ -219,6 +251,17 @@ struct delta_data { common_chat_params params; }; +static common_chat_msg simple_assist_msg(const std::string & content, const std::string & reasoning_content = "", const std::string & tool_name = "", const std::string & arguments = "", const std::string & id = "") { + common_chat_msg msg; + msg.role = "assistant"; + msg.content = content; + msg.reasoning_content = reasoning_content; + if (!tool_name.empty()) { + msg.tool_calls.push_back({ tool_name, arguments, id }); + } + return msg; +} + static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector & end_tokens, const common_chat_msg & user_message, const common_chat_msg & delta_message, @@ -280,7 +323,9 @@ static void test_templates(const struct common_chat_templates * tmpls, const std const std::string & expected_delta = "", bool expect_grammar_triggered = true, bool test_grammar_if_triggered = true, - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE) { + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE, + bool ignore_whitespace_differences = false + ) { common_chat_msg user_message; user_message.role = "user"; user_message.content = "Hello, world!"; @@ -288,7 +333,11 @@ static void test_templates(const struct common_chat_templates * tmpls, const std for (const auto & tool_choice : std::vector {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) { auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice); if (!expected_delta.empty()) { - assert_equals(expected_delta, data.delta); + if (ignore_whitespace_differences) { + assert_equals(string_strip(expected_delta), string_strip(data.delta)); + } else { + assert_equals(expected_delta, data.delta); + } } if (expect_grammar_triggered) { @@ -296,7 +345,7 @@ static void test_templates(const struct common_chat_templates * tmpls, const std syntax.format = data.params.format; syntax.reasoning_format = reasoning_format; const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, syntax); - assert_msg_equals(test_message, msg); + assert_msg_equals(test_message, msg, ignore_whitespace_differences); } if (!test_message.tool_calls.empty()) { @@ -373,6 +422,44 @@ static void test_templates(const struct common_chat_templates * tmpls, const std } } +/** + * Test that streaming=true is consistent with streaming=false for a given partial parser + * Also test that partial messages do not cause any problems + */ +template +static void test_parser_with_streaming(const common_chat_msg & expected, const std::string & raw_message, T parse_msg) { + auto merged = simple_assist_msg(""); + auto last_msg = parse_msg(""); + for (size_t i = 1; i <= raw_message.size(); ++i) { + auto curr_msg = parse_msg(raw_message.substr(0, i)); + if (curr_msg == simple_assist_msg("")) continue; + LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({curr_msg}).dump().c_str()); + for (auto diff: common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) { + LOG_INF("Streaming diff: %s\n",
common_chat_msg_diff_to_json_oaicompat(diff).dump().c_str()); + if (!diff.reasoning_content_delta.empty()) { + merged.reasoning_content += diff.reasoning_content_delta; + } + if (!diff.content_delta.empty()) { + merged.content += diff.content_delta; + } + if (diff.tool_call_index != std::string::npos) { + if (!diff.tool_call_delta.name.empty()) { + merged.tool_calls.push_back({diff.tool_call_delta.name, "", ""}); + } + if (!diff.tool_call_delta.arguments.empty()) { + GGML_ASSERT(!merged.tool_calls.empty()); + merged.tool_calls.back().arguments += diff.tool_call_delta.arguments; + } + } + LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({merged}).dump().c_str()); + } + assert_msg_equals(curr_msg, merged, true); + last_msg = curr_msg; + } + assert_msg_equals(expected, parse_msg(raw_message), true); + assert_msg_equals(expected, merged, true); +} + const common_chat_msg message_user { "user", "Hey there!", @@ -395,16 +482,7 @@ const common_chat_msg message_user_parts { /* .tool_name = */ "", /* .tool_call_id = */ "", }; -static common_chat_msg simple_assist_msg(const std::string & content, const std::string & reasoning_content = "", const std::string & tool_name = "", const std::string & arguments = "", const std::string & id = "") { - common_chat_msg msg; - msg.role = "assistant"; - msg.content = content; - msg.reasoning_content = reasoning_content; - if (!tool_name.empty()) { - msg.tool_calls.push_back({ tool_name, arguments, id }); - } - return msg; -} + const common_chat_msg message_assist = simple_assist_msg("Hello, world!\nWhat's up?"); const common_chat_msg message_assist_empty = simple_assist_msg(""); const common_chat_msg message_assist_thoughts_unparsed_deepseek = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); @@ -417,6 +495,8 @@ const common_chat_msg message_assist_thoughts = simple_assist const common_chat_msg message_assist_thoughts_unopened_unparsed = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); const common_chat_msg message_assist_thoughts_no_content = simple_assist_msg("", "I'm\nthinking"); const common_chat_msg message_assist_call = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); +const common_chat_msg message_assist_call_noopt = simple_assist_msg("", "", "special_function_with_opt", "{\"arg1\": 1}"); +const common_chat_msg message_assist_call_withopt = simple_assist_msg("", "", "special_function_with_opt", "{\"arg1\": 1, \"arg2\": 2}"); const common_chat_msg message_assist_call_content = simple_assist_msg("Hello, world!\nWhat's up?", "", "special_function", "{\"arg1\":1}"); const common_chat_msg message_assist_call_empty_args = simple_assist_msg("", "", "special_function"); const common_chat_msg message_assist_call_cutoff_args = simple_assist_msg("", "", "special_function", "{\"arg"); @@ -1834,7 +1914,7 @@ static void test_template_output_parsers() { // Test partial parsing for incomplete tool call - don't actually add the call until parsing parameters is done assert_msg_equals( - simple_assist_msg("", ""), + simple_assist_msg("", "", "calculate_sum", "{\"numbers\":"), common_chat_parse( "\n" "\n" @@ -2288,6 +2368,940 @@ Hey there!<|im_end|> // above verify edge cases and format variations for the tool call output format. 
} + { + auto tmpls = read_templates("models/templates/MiniMax-M2.jinja"); + std::vector end_tokens{ "[e~[" }; + + assert_equals(COMMON_CHAT_FORMAT_MINIMAX_M2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_MINIMAX_M2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_MINIMAX_M2})); + + // Test parsing content with thinking + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + + // Test parsing tool calls + assert_msg_equals(message_assist_call, + common_chat_parse( + "1", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_MINIMAX_M2})); + + // Test parsing tool calls with thinking + assert_msg_equals(message_assist_call_thoughts, + common_chat_parse( + "I'm\nthinking1", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + })); + + // Test tool calls with extra content + assert_msg_equals(message_assist_call_content, + common_chat_parse( + "1Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_MINIMAX_M2} + )); + + // Test tool calls with extra content AND thinking + assert_msg_equals(message_assist_call_thoughts_content, + common_chat_parse( + "I'm\nthinking1Hello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + })); + + // Test streaming + test_parser_with_streaming(message_assist_call_thoughts_content, + "I'm\nthinking\nHello, world!\nWhat's up?\n1", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming(message_assist_call_thoughts_unparsed, + "I'm\nthinking\n\n1", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE + }); }); + test_parser_with_streaming(message_assist_call_thoughts_content, + "I'm\nthinking\n\n\nHello, world!\nWhat's up?\n\n\n\n1\n\n\n", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming(message_assist_call_withopt, + "\n\n1\n2\n\n", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE + }); }); + + // Test template generation for regular content + test_templates(tmpls.get(), end_tokens, message_assist, tools, + "Hello, world!\nWhat's up?", + /* expect_grammar_triggered= */ false); + + // Test template generation for tool calls + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, + "\n\n1\n\n", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + /* common_reasoning_format= */ COMMON_REASONING_FORMAT_NONE, + /* 
ignore_whitespace_differences= */ true + ); + + // Test template generation for tools with optional parameters + test_templates(tmpls.get(), end_tokens, message_assist_call_noopt, tools, + "\n\n1\n\n", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + /* common_reasoning_format= */ COMMON_REASONING_FORMAT_NONE, + /* ignore_whitespace_differences= */ true + ); + test_templates(tmpls.get(), end_tokens, message_assist_call_withopt, tools, + "\n\n1\n2\n\n", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + /* common_reasoning_format= */ COMMON_REASONING_FORMAT_NONE, + /* ignore_whitespace_differences= */ true + ); + } + + { + auto tmpls = read_templates("models/templates/GLM-4.6.jinja"); + std::vector end_tokens{ "<|assistant|>", "<|observation|>" }; + + assert_equals(COMMON_CHAT_FORMAT_GLM_4_5, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_GLM_4_5, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_GLM_4_5})); + + // Test parsing content with thinking + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "\nI'm\nthinking\nHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + }), true); + + // Test parsing tool calls + assert_msg_equals(message_assist_call, + common_chat_parse( + "\nspecial_function\narg1\n1\n", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_GLM_4_5}), true); + + // Test parsing tool calls with thinking + assert_msg_equals(message_assist_call_thoughts, + common_chat_parse( + "\nI'm\nthinking\nspecial_function\narg1\n1\n", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }), true); + + // Test tool calls with extra content + assert_msg_equals(message_assist_call_content, + common_chat_parse( + "\nspecial_function\narg1\n1\nHello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_GLM_4_5} + ), true); + + // Test tool calls with extra content AND thinking + assert_msg_equals(message_assist_call_thoughts_content, + common_chat_parse( + "\nI'm\nthinkingHello, world!\nWhat's up?\nspecial_function\narg1\n1\n", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }), true); + + // Test streaming + test_parser_with_streaming(message_assist_call_thoughts_content, + "\nI'm\nthinkingHello, world!\nWhat's up?\nspecial_function\narg1\n1\n", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming(message_assist_call_thoughts_unparsed, + "\nI'm\nthinking\n\nspecial_function\narg1\n1\n", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE + }); }); + test_parser_with_streaming(message_assist_call_withopt, + "\n\nspecial_function_with_opt\narg1\n1\narg2\n2\n\n", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* 
.format = */ COMMON_CHAT_FORMAT_GLM_4_5, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming( + simple_assist_msg("", "", "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}"), + "complex_function\n" + "name\n" + "John Doe\n" + "age\n" + "30\n" + "active\n" + "true\n" + "score\n" + "95.5\n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_GLM_4_5}); }); + test_parser_with_streaming( + simple_assist_msg("", "", "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}"), + "web_search\n" + "query\n" + "\"From Zero\" Linkin Park album tracklist complete songs\n" + "limit\n" + "3\n" + "type\n" + "text\n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_GLM_4_5}); }); + + // Test interleaved thinking + test_parser_with_streaming(simple_assist_msg("Hello, world!\n\nWhat's up?", "I'm\nthinkingThinking2", "special_function", "{\"arg1\": 1}"), + "\nI'm\nthinkingHello, world!\nThinking2What's up?\nspecial_function\narg1\n1\n", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming(simple_assist_msg("\nI'm\nthinkingHello, world!\nThinking2What's up?", "", "special_function", "{\"arg1\": 1}"), + "\nI'm\nthinkingHello, world!\nThinking2What's up?\nspecial_function\narg1\n1\n", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE + }); }); + + // Test template generation for regular content + test_templates(tmpls.get(), end_tokens, message_assist, tools, + "\n\nHello, world!\nWhat's up?", + /* expect_grammar_triggered= */ false); + + // Test template generation for tool calls + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, + "\n\nspecial_function\narg1\n1\n\n", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ false, + /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* ignore_whitespace_differences= */ true + ); + + // Test template generation for tools with optional parameters + test_templates(tmpls.get(), end_tokens, message_assist_call_noopt, tools, + "\n\nspecial_function_with_opt\narg1\n1\n\n", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ false, + /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* ignore_whitespace_differences= */ true + ); + test_templates(tmpls.get(), end_tokens, message_assist_call_withopt, tools, + "\n\nspecial_function_with_opt\narg1\n1\narg2\n2\n\n", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ false, + /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* ignore_whitespace_differences= */ true + ); + } + + { + auto tmpls = read_templates("models/templates/Kimi-K2-Thinking.jinja"); + std::vector end_tokens{ "<|im_end|>" }; + + assert_equals(COMMON_CHAT_FORMAT_KIMI_K2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_KIMI_K2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + 
"Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_KIMI_K2})); + + // Test parsing content with thinking + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + + // Test parsing tool calls + assert_msg_equals(message_assist_call, + common_chat_parse( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_KIMI_K2})); + + // Test parsing tool calls with thinking + assert_msg_equals(message_assist_call_thoughts, + common_chat_parse( + "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + })); + + // Test tool calls with extra content + assert_msg_equals(message_assist_call_content, + common_chat_parse( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_KIMI_K2} + )); + + // Test tool calls with extra content AND thinking + assert_msg_equals(message_assist_call_thoughts_content, + common_chat_parse( + "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + })); + + // Test streaming + test_parser_with_streaming(message_assist_call_thoughts_content, + "I'm\nthinking\nHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming(message_assist_call_thoughts_unparsed, + "I'm\nthinking\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE + }); }); + test_parser_with_streaming(message_assist_call_thoughts_content, + "I'm\nthinking\n\n\nHello, world!\nWhat's up?\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>\n", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming(message_assist_call_withopt, + 
"<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:1<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 2}<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE + }); }); + test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": \"123456\"}"), + "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": \"123456\"}<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": [1, 2, \"345\", 6]}"), + "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": [1, 2, \"345\", 6]}<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": {\"12\": 34, \"5\": [67, 8], \"9\": \"10\"}}"), + "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": {\"12\": 34, \"5\": [67, 8], \"9\": \"10\"}}<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + + // Test template generation for regular content + test_templates(tmpls.get(), end_tokens, message_assist, tools, + "Hello, world!\nWhat's up?", + /* expect_grammar_triggered= */ false); + + // Test template generation for tool calls + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* ignore_whitespace_differences= */ true + ); + + // Test template generation for tools with optional parameters + test_templates(tmpls.get(), end_tokens, message_assist_call_noopt, tools, + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* ignore_whitespace_differences= */ true + ); + test_templates(tmpls.get(), end_tokens, message_assist_call_withopt, tools, + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:1<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 
2}<|tool_call_end|><|tool_calls_section_end|>", + /* expect_grammar_triggered= */ true, + /* test_grammar_if_triggered= */ true, + /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* ignore_whitespace_differences= */ true + ); + } + + // Test Qwen3-Coder XML format + { + // Basic XML tool call parsing + assert_msg_equals( + message_assist_call, + common_chat_parse( + "\n" + " \n" + " \n" + " 1\n" + " \n" + " \n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_QWEN3_CODER_XML})); + + // Multiple parameters with different types + common_chat_msg expected_multi_param; + expected_multi_param.role = "assistant"; + expected_multi_param.tool_calls = { + { "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}", "" } + }; + + test_parser_with_streaming(expected_multi_param, + "\n" + " \n" + " \n" + " John Doe\n" + " \n" + " \n" + " 30\n" + " \n" + " \n" + " true\n" + " \n" + " \n" + " 95.5\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Special characters and Unicode + common_chat_msg expected_special_chars; + expected_special_chars.role = "assistant"; + expected_special_chars.tool_calls = { + { "unicode_function", "{\"message\":\"Hello 世界! 🌍 Special chars: @#$%^&*()\"}", "" } + }; + + test_parser_with_streaming(expected_special_chars, + "\n" + " \n" + " \n" + " Hello 世界! 🌍 Special chars: @#$%^&*()\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Multiline content with newlines and indentation + common_chat_msg expected_multiline; + expected_multiline.role = "assistant"; + expected_multiline.tool_calls = { + { "code_function", "{\"code\":\"def hello():\\n print(\\\"Hello, World!\\\")\\n return True\"}", "" } + }; + + test_parser_with_streaming(expected_multiline, + "\n" + " \n" + " \n" + "def hello():\n" + " print(\"Hello, World!\")\n" + " return True\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // JSON object as parameter value + common_chat_msg expected_json_param; + expected_json_param.role = "assistant"; + expected_json_param.tool_calls = { + { "json_function", "{\"config\":{\"host\":\"localhost\",\"port\":8080,\"ssl\":false}}", "" } + }; + + test_parser_with_streaming( + expected_json_param, + "\n" + " \n" + " \n" + " {\"host\": \"localhost\", \"port\": 8080, \"ssl\": false}\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Array as parameter value + common_chat_msg expected_array_param; + expected_array_param.role = "assistant"; + expected_array_param.tool_calls = { + { "array_function", "{\"items\":[\"apple\",\"banana\",\"cherry\"]}", "" } + }; + + test_parser_with_streaming( + expected_array_param, + "\n" + " \n" + " \n" + " [\"apple\", \"banana\", \"cherry\"]\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Empty parameter + common_chat_msg expected_empty_param; + expected_empty_param.role = "assistant"; + expected_empty_param.tool_calls = { + { "empty_function", "{\"empty_param\":\"\"}", "" } + }; + + test_parser_with_streaming( + expected_empty_param, + "\n" + " \n" + " \n" + " 
\n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Boolean values (true/false) + common_chat_msg expected_boolean; + expected_boolean.role = "assistant"; + expected_boolean.tool_calls = { + { "boolean_function", "{\"enabled\":true,\"debug\":false}", "" } + }; + + test_parser_with_streaming( + expected_boolean, + "\n" + " \n" + " \n" + " true\n" + " \n" + " \n" + " false\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Null value + common_chat_msg expected_null; + expected_null.role = "assistant"; + expected_null.tool_calls = { + { "null_function", "{\"optional_param\":null}", "" } + }; + + test_parser_with_streaming( + expected_null, + "\n" + " \n" + " \n" + " null\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Negative numbers and scientific notation + common_chat_msg expected_numbers; + expected_numbers.role = "assistant"; + expected_numbers.tool_calls = { + { "math_function", "{\"negative\":-42,\"decimal\":-3.14,\"scientific\":1.23e-4}", "" } + }; + + test_parser_with_streaming( + expected_numbers, + "\n" + " \n" + " \n" + " -42\n" + " \n" + " \n" + " -3.14\n" + " \n" + " \n" + " 1.23e-4\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // XML-like content in parameters (should be escaped) + common_chat_msg expected_xml_content; + expected_xml_content.role = "assistant"; + expected_xml_content.tool_calls = { + { "xml_function", "{\"xml_content\":\"value\"}", "" } + }; + + test_parser_with_streaming( + expected_xml_content, + "\n" + " \n" + " \n" + " value\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Quotes and escape characters + common_chat_msg expected_quotes; + expected_quotes.role = "assistant"; + expected_quotes.tool_calls = { + { "quote_function", "{\"message\":\"She said \\\"Hello!\\\" and left.\"}", "" } + }; + + test_parser_with_streaming( + expected_quotes, + "\n" + " \n" + " \n" + " She said \"Hello!\" and left.\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Long parameter value (simplified) + std::string long_text = "This is a long text parameter that should test the parser's ability to handle larger amounts of text data."; + + common_chat_msg expected_long_text; + expected_long_text.role = "assistant"; + expected_long_text.tool_calls = { + { "long_function", "{\"long_text\":\"" + long_text + "\"}", "" } + }; + + test_parser_with_streaming( + expected_long_text, + "\n" + " \n" + " \n" + " " + long_text + "\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Mixed content with text before and after tool call + common_chat_msg expected_mixed_content; + expected_mixed_content.role = "assistant"; + expected_mixed_content.content = "I'll help you search for products. 
"; + expected_mixed_content.tool_calls = { + { "search_function", "{\"query\":\"laptops\"}", "" } + }; + + test_parser_with_streaming( + expected_mixed_content, + "I'll help you search for products. \n" + " \n" + " \n" + " laptops\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Compact format (no extra whitespace) + common_chat_msg expected_compact; + expected_compact.role = "assistant"; + expected_compact.tool_calls = { + { "compact_function", "{\"param\":\"value\"}", "" } + }; + + test_parser_with_streaming( + expected_compact, + "value", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Function name with underscores and numbers + common_chat_msg expected_complex_name; + expected_complex_name.role = "assistant"; + expected_complex_name.tool_calls = { + { "get_user_data_v2", "{\"user_id\":12345}", "" } + }; + + test_parser_with_streaming( + expected_complex_name, + "\n" + " \n" + " \n" + " 12345\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Parameter names with underscores and numbers + common_chat_msg expected_complex_params; + expected_complex_params.role = "assistant"; + expected_complex_params.tool_calls = { + { "test_function", "{\"param_1\":\"value1\",\"param_2_name\":\"value2\",\"param3\":123}", "" } + }; + + test_parser_with_streaming( + expected_complex_params, + "\n" + " \n" + " \n" + " value1\n" + " \n" + " \n" + " value2\n" + " \n" + " \n" + " 123\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Very deeply nested XML content in parameter + common_chat_msg expected_deep_xml; + expected_deep_xml.role = "assistant"; + expected_deep_xml.tool_calls = { + { "xml_parser", "{\"xml\":\"deep content\"}", "" } + }; + + test_parser_with_streaming( + expected_deep_xml, + "\n" + " \n" + " \n" + " deep content\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Parameter with only whitespace + common_chat_msg expected_whitespace_param; + expected_whitespace_param.role = "assistant"; + expected_whitespace_param.tool_calls = { + { "whitespace_function", "{\"spaces\":\"\"}", "" } + }; + + test_parser_with_streaming( + expected_whitespace_param, + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Parameter with tabs and mixed whitespace + common_chat_msg expected_mixed_whitespace; + expected_mixed_whitespace.role = "assistant"; + expected_mixed_whitespace.tool_calls = { + { "tab_function", "{\"content\":\"line1\\n\\tindented line\\n spaces\"}", "" } + }; + + test_parser_with_streaming( + expected_mixed_whitespace, + "\n" + " \n" + " \n" + "line1\n" + "\tindented line\n" + " spaces\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Control characters and special Unicode + common_chat_msg expected_control_chars; + expected_control_chars.role = "assistant"; + expected_control_chars.tool_calls = { + { "control_function", 
"{\"text\":\"Line1\\nLine2\\tTabbed\\rCarriage return\"}", "" } + }; + + test_parser_with_streaming( + expected_control_chars, + "\n" + " \n" + " \n" + "Line1\nLine2\tTabbed\rCarriage return\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Emoji and extended Unicode characters + common_chat_msg expected_emoji; + expected_emoji.role = "assistant"; + expected_emoji.tool_calls = { + { "emoji_function", "{\"message\":\"Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\"}", "" } + }; + + test_parser_with_streaming( + expected_emoji, + "\n" + " \n" + " \n" + " Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Mathematical expressions and formulas + common_chat_msg expected_math; + expected_math.role = "assistant"; + expected_math.tool_calls = { + { "math_function", "{\"formula\":\"E = mc² and ∫f(x)dx = F(x) + C\"}", "" } + }; + + test_parser_with_streaming( + expected_math, + "\n" + " \n" + " \n" + " E = mc² and ∫f(x)dx = F(x) + C\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // SQL injection-like content (should be safely escaped) + common_chat_msg expected_sql; + expected_sql.role = "assistant"; + expected_sql.tool_calls = { + { "sql_function", "{\"query\":\"SELECT * FROM users WHERE id = 1; DROP TABLE users; --\"}", "" } + }; + + test_parser_with_streaming( + expected_sql, + "\n" + " \n" + " \n" + " SELECT * FROM users WHERE id = 1; DROP TABLE users; --\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // HTML/XML injection content + common_chat_msg expected_html; + expected_html.role = "assistant"; + expected_html.tool_calls = { + { "html_function", "{\"content\":\"\"}", "" } + }; + + test_parser_with_streaming( + expected_html, + "\n" + " \n" + " \n" + " \n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Binary-like content (base64) + common_chat_msg expected_binary; + expected_binary.role = "assistant"; + expected_binary.tool_calls = { + { "binary_function", "{\"data\":\"SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\"}", "" } + }; + + test_parser_with_streaming( + expected_binary, + "\n" + " \n" + " \n" + " SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + + // Very large numbers (should be parsed as scientific notation) + common_chat_msg expected_large_numbers; + expected_large_numbers.role = "assistant"; + expected_large_numbers.tool_calls = { + { "number_function", "{\"big_int\":1e+60}", "" } // Large number becomes scientific notation + }; + + test_parser_with_streaming( + expected_large_numbers, + "\n" + " \n" + " \n" + " 999999999999999999999999999999999999999999999999999999999999\n" + " \n" + " \n" + "", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); + } + + { + // Qwen3-Coder template + auto tmpls = 
read_templates("models/templates/Qwen3-Coder.jinja"); + common_chat_templates_inputs inputs; + inputs.messages = { message_user }; + + common_chat_tool qwen_union_tool { + /* .name = */ "qwen_union", + /* .description = */ "Test tool for union/anyOf handling", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "priority": { "type": ["number", "null"] }, + "maybe_text": { "anyOf": [ { "type": "string" } ] }, + "config": { "anyOf": [ { "type": "object" }, { "type": "null" } ] } + }, + "required": [] + })", + }; + inputs.tools = { qwen_union_tool }; + + auto params = common_chat_templates_apply(tmpls.get(), inputs); + assert_equals(COMMON_CHAT_FORMAT_QWEN3_CODER_XML, params.format); + assert_equals(false, params.grammar.empty()); + + // Grammar should compile successfully + auto grammar = build_grammar(params.grammar); + GGML_ASSERT(grammar && "Failed to build Qwen3-Coder grammar with union types"); + } + } static void test_msg_diffs_compute() { diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index c801e84c3d..1fccfdd17f 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -14,6 +14,8 @@ endif() set(TARGET_SRCS server.cpp utils.hpp + server-http.cpp + server-http.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index f3f73e9069..8842620a7a 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp new file mode 100644 index 0000000000..bebe0b49c4 --- /dev/null +++ b/tools/server/server-http.cpp @@ -0,0 +1,387 @@ +#include "utils.hpp" +#include "common.h" +#include "server-http.h" + +#include + +#include +#include +#include + +// auto generated files (see README.md for details) +#include "index.html.gz.hpp" +#include "loading.html.hpp" + +// +// HTTP implementation using cpp-httplib +// + +class server_http_context::Impl { +public: + std::unique_ptr srv; +}; + +server_http_context::server_http_context() + : pimpl(std::make_unique()) +{} + +server_http_context::~server_http_context() = default; + +static void log_server_request(const httplib::Request & req, const httplib::Response & res) { + // skip GH copilot requests when using default port + if (req.path == "/v1/health") { + return; + } + + // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch + + SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); + + SRV_DBG("request: %s\n", req.body.c_str()); + SRV_DBG("response: %s\n", res.body.c_str()); +} + +bool server_http_context::init(const common_params & params) { + path_prefix = params.api_prefix; + port = params.port; + hostname = params.hostname; + + auto & srv = pimpl->srv; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { + LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); + srv.reset( + new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) + ); + } else { + LOG_INF("Running without SSL\n"); + srv.reset(new httplib::Server()); + } +#else + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { + LOG_ERR("Server is built without SSL support\n"); + return false; + } + srv.reset(new httplib::Server()); +#endif + + srv->set_default_headers({{"Server", 
"llama.cpp"}}); + srv->set_logger(log_server_request); + srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { + // this is fail-safe; exceptions should already handled by `ex_wrapper` + + std::string message; + try { + std::rethrow_exception(ep); + } catch (const std::exception & e) { + message = e.what(); + } catch (...) { + message = "Unknown Exception"; + } + + res.status = 500; + res.set_content(message, "text/plain"); + LOG_ERR("got exception: %s\n", message.c_str()); + }); + + srv->set_error_handler([](const httplib::Request &, httplib::Response & res) { + if (res.status == 404) { + res.set_content( + safe_json_to_str(json { + {"error", { + {"message", "File Not Found"}, + {"type", "not_found_error"}, + {"code", 404} + }} + }), + "application/json; charset=utf-8" + ); + } + // for other error codes, we skip processing here because it's already done by res->error() + }); + + // set timeouts and change hostname and port + srv->set_read_timeout (params.timeout_read); + srv->set_write_timeout(params.timeout_write); + + if (params.api_keys.size() == 1) { + auto key = params.api_keys[0]; + std::string substr = key.substr(std::max((int)(key.length() - 4), 0)); + LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str()); + } else if (params.api_keys.size() > 1) { + LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size()); + } + + // + // Middlewares + // + + auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) { + static const std::unordered_set public_endpoints = { + "/health", + "/v1/health", + "/models", + "/v1/models", + "/api/tags" + }; + + // If API key is not set, skip validation + if (api_keys.empty()) { + return true; + } + + // If path is public or is static file, skip validation + if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { + return true; + } + + // Check for API key in the header + auto auth_header = req.get_header_value("Authorization"); + + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) { + return true; // API key is valid + } + } + + // API key is invalid or not provided + res.status = 401; + res.set_content( + safe_json_to_str(json { + {"error", { + {"message", "Invalid API Key"}, + {"type", "authentication_error"}, + {"code", 401} + }} + }), + "application/json; charset=utf-8" + ); + + LOG_WRN("Unauthorized: Invalid API Key\n"); + + return false; + }; + + auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) { + bool ready = is_ready.load(); + if (!ready) { + auto tmp = string_split(req.path, '.'); + if (req.path == "/" || tmp.back() == "html") { + res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); + res.status = 503; + } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { + // allow the models endpoint to be accessed during loading + return true; + } else { + res.status = 503; + res.set_content( + safe_json_to_str(json { + {"error", { + {"message", "Loading model"}, + {"type", "unavailable_error"}, + {"code", 503} + }} + }), + "application/json; charset=utf-8" + ); + } + return false; + } + return true; + }; + + // register server middlewares + 
srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + // If this is OPTIONS request, skip validation because browsers don't include Authorization header + if (req.method == "OPTIONS") { + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "GET, POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + res.set_content("", "text/html"); // blank response, no data + return httplib::Server::HandlerResponse::Handled; // skip further processing + } + if (!middleware_server_state(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + if (!middleware_validate_api_key(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + return httplib::Server::HandlerResponse::Unhandled; + }); + + int n_threads_http = params.n_threads_http; + if (n_threads_http < 1) { + // +2 threads for monitoring endpoints + n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); + } + LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http); + srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); }; + + // + // Web UI setup + // + + if (!params.webui) { + LOG_INF("Web UI is disabled\n"); + } else { + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path); + if (!is_found) { + LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); + return 1; + } + } else { + // using embedded static index.html + srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) { + if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { + res.set_content("Error: gzip is not supported by this browser", "text/plain"); + } else { + res.set_header("Content-Encoding", "gzip"); + // COEP and COOP headers, required by pyodide (python interpreter) + res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); + res.set_header("Cross-Origin-Opener-Policy", "same-origin"); + res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + } + return false; + }); + } + } + return true; +} + +bool server_http_context::start() { + // Bind and listen + + auto & srv = pimpl->srv; + bool was_bound = false; + bool is_sock = false; + if (string_ends_with(std::string(hostname), ".sock")) { + is_sock = true; + LOG_INF("%s: setting address family to AF_UNIX\n", __func__); + srv->set_address_family(AF_UNIX); + // bind_to_port requires a second arg, any value other than 0 should + // simply get ignored + was_bound = srv->bind_to_port(hostname, 8080); + } else { + LOG_INF("%s: binding port with default address family\n", __func__); + // bind HTTP listen port + if (port == 0) { + int bound_port = srv->bind_to_any_port(hostname); + was_bound = (bound_port >= 0); + if (was_bound) { + port = bound_port; + } + } else { + was_bound = srv->bind_to_port(hostname, port); + } + } + + if (!was_bound) { + LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port); + return false; + } + + // run the HTTP server in a thread + thread = std::thread([this]() { pimpl->srv->listen_after_bind(); }); + 
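A caller-side sketch of the new server_http_context lifecycle, assuming only the members declared in server-http.h in this patch: handlers are registered with get()/post() before start(), and flipping is_ready lifts the 503 "Loading model" gate enforced by the server-state middleware. The /ping route and run_http_sketch() are illustrative names, not part of the patch.

    #include "server-http.h"

    #include <memory>

    static int run_http_sketch(const common_params & params) {
        server_http_context http;
        if (!http.init(params)) {
            return 1;
        }

        // a handler maps server_http_req -> server_http_res_ptr and must not throw
        http.get("/ping", [](const server_http_req &) {
            auto res = std::make_unique<server_http_res>();
            res->data = "{\"status\":\"ok\"}"; // content_type defaults to application/json
            return res;
        });

        if (!http.start()) {            // bind the socket and spawn the listen thread
            return 1;
        }
        http.is_ready.store(true);      // allow non-static endpoints through the middleware

        // ... run the inference / task loop here ...

        http.stop();
        http.thread.join();
        return 0;
    }
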
srv->wait_until_ready(); + + listening_address = is_sock ? string_format("unix://%s", hostname.c_str()) + : string_format("http://%s:%d", hostname.c_str(), port); + return true; +} + +void server_http_context::stop() const { + if (pimpl->srv) { + pimpl->srv->stop(); + } +} + +static void set_headers(httplib::Response & res, const std::map & headers) { + for (const auto & [key, value] : headers) { + res.set_header(key, value); + } +} + +static std::map get_params(const httplib::Request & req) { + std::map params; + for (const auto & [key, value] : req.params) { + params[key] = value; + } + for (const auto & [key, value] : req.path_params) { + params[key] = value; + } + return params; +} + +static std::map get_headers(const httplib::Request & req) { + std::map headers; + for (const auto & [key, value] : req.headers) { + headers[key] = value; + } + return headers; +} + +static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) { + if (response->is_stream()) { + res.status = response->status; + set_headers(res, response->headers); + std::string content_type = response->content_type; + // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it + std::shared_ptr r_ptr = std::move(response); + const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool { + std::string chunk; + bool has_next = response->next(chunk); + if (!chunk.empty()) { + // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed() + sink.write(chunk.data(), chunk.size()); + SRV_DBG("http: streamed chunk: %s\n", chunk.c_str()); + } + if (!has_next) { + sink.done(); + SRV_DBG("%s", "http: stream ended\n"); + } + return has_next; + }; + const auto on_complete = [response = r_ptr](bool) mutable { + response.reset(); // trigger the destruction of the response object + }; + res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete); + } else { + res.status = response->status; + set_headers(res, response->headers); + res.set_content(response->data, response->content_type); + } +} + +void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const { + pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { + server_http_res_ptr response = handler(server_http_req{ + get_params(req), + get_headers(req), + req.path, + req.body, + req.is_connection_closed + }); + process_handler_response(response, res); + }); +} + +void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const { + pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { + server_http_res_ptr response = handler(server_http_req{ + get_params(req), + get_headers(req), + req.path, + req.body, + req.is_connection_closed + }); + process_handler_response(response, res); + }); +} + diff --git a/tools/server/server-http.h b/tools/server/server-http.h new file mode 100644 index 0000000000..24c0b40117 --- /dev/null +++ b/tools/server/server-http.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include +#include + +struct common_params; + +// generator-like API for HTTP response generation +// this object response with one of the 2 modes: +// 1) normal response: `data` contains the full response body +// 2) streaming response: each call to next(output) generates the next chunk +// when next(output) returns false, no more 
data after the current chunk +// note: some chunks can be empty, in which case no data is sent for that chunk +struct server_http_res { + std::string content_type = "application/json; charset=utf-8"; + int status = 200; + std::string data; + std::map headers; + + // TODO: move this to a virtual function once we have proper polymorphism support + std::function next = nullptr; + bool is_stream() const { + return next != nullptr; + } + + virtual ~server_http_res() = default; +}; + +// unique pointer, used by set_chunked_content_provider +// httplib requires the stream provider to be stored in heap +using server_http_res_ptr = std::unique_ptr; + +struct server_http_req { + std::map params; // path_params + query_params + std::map headers; // reserved for future use + std::string path; // reserved for future use + std::string body; + const std::function & should_stop; + + std::string get_param(const std::string & key, const std::string & def = "") const { + auto it = params.find(key); + if (it != params.end()) { + return it->second; + } + return def; + } +}; + +struct server_http_context { + class Impl; + std::unique_ptr pimpl; + + std::thread thread; // server thread + std::atomic is_ready = false; + + std::string path_prefix; + std::string hostname; + int port; + + server_http_context(); + ~server_http_context(); + + bool init(const common_params & params); + bool start(); + void stop() const; + + // note: the handler should never throw exceptions + using handler_t = std::function; + + void get(const std::string & path, const handler_t & handler) const; + void post(const std::string & path, const handler_t & handler) const; + + // for debugging + std::string listening_address; +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 21d2f0f417..18e30a9589 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,5 +1,6 @@ #include "chat.h" #include "utils.hpp" +#include "server-http.h" #include "arg.h" #include "common.h" @@ -10,13 +11,6 @@ #include "speculative.h" #include "mtmd.h" -// mime type for sending response -#define MIMETYPE_JSON "application/json; charset=utf-8" - -// auto generated files (see README.md for details) -#include "index.html.gz.hpp" -#include "loading.html.hpp" - #include #include #include @@ -25,11 +19,20 @@ #include #include #include +#include #include #include -#include #include +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + using json = nlohmann::ordered_json; constexpr int HTTP_POLLING_SECONDS = 1; @@ -1680,7 +1683,7 @@ struct server_slot { server_prompt prompt; void prompt_save(server_prompt_cache & prompt_cache) const { - assert(prompt.data.size() == 0); + GGML_ASSERT(prompt.data.size() == 0); const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); @@ -2404,6 +2407,7 @@ struct server_context { llama_batch_free(batch); } + // load the model and initialize llama_context bool load_model(const common_params & params) { SRV_INF("loading model '%s'\n", params.model.path.c_str()); @@ -2523,6 +2527,7 @@ struct server_context { return true; } + // initialize slots and server-related data void init() { SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); @@ -2622,6 +2627,11 @@ struct server_context { /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, /* enable_thinking */ enable_thinking, }; + + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(chat_templates.get()), + common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); } server_slot * get_slot_by_id(int id) { @@ -4379,6 +4389,7 @@ struct server_context { } }; + // generator-like API for server responses, support pooling connection state and aggregating results struct server_response_reader { std::unordered_set id_tasks; @@ -4397,7 +4408,7 @@ struct server_response_reader { ctx_server.queue_tasks.post(std::move(tasks)); } - bool has_next() { + bool has_next() const { return !cancelled && received_count < id_tasks.size(); } @@ -4477,281 +4488,46 @@ struct server_response_reader { } }; -static void log_server_request(const httplib::Request & req, const httplib::Response & res) { - // skip GH copilot requests when using default port - if (req.path == "/v1/health") { - return; +// generator-like API for HTTP response generation +struct server_res_generator : server_http_res { + server_response_reader rd; + server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {} + void ok(const json & response_data) { + status = 200; + data = safe_json_to_str(response_data); } - - // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch - - SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); - - SRV_DBG("request: %s\n", req.body.c_str()); - SRV_DBG("response: %s\n", res.body.c_str()); -} - -static void res_err(httplib::Response & res, const json & error_data) { - json final_response {{"error", error_data}}; - res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); - res.status = json_value(error_data, "code", 500); -} - -static void res_ok(httplib::Response & res, const json & data) { - res.set_content(safe_json_to_str(data), MIMETYPE_JSON); - res.status = 200; -} - -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; - -inline void signal_handler(int signal) { - if (is_terminating.test_and_set()) { - // in case it hangs, we can force terminate the server by hitting Ctrl+C twice - // this is for better developer experience, we can remove when the server is stable enough - fprintf(stderr, "Received second interrupt, terminating immediately.\n"); - exit(1); + void error(const json & error_data) { + status = json_value(error_data, "code", 500); + data = safe_json_to_str({{ "error", error_data }}); } +}; - shutdown_handler(signal); -} +struct server_routes { + const common_params & params; + server_context & ctx_server; + server_http_context & ctx_http; // for reading is_ready + server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) + : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {} -int main(int argc, char ** argv) { - // own arguments required by this example - common_params params; +public: + // handlers using lambda function, so that they can capture `this` without `std::bind` - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { - return 1; - } - - // TODO: should we have a separate n_parallel parameter for the server? 
- // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 - // TODO: this is a common configuration that is suitable for most local use cases - // however, overriding the parameters is a bit confusing - figure out something more intuitive - if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { - LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); - - params.n_parallel = 4; - params.kv_unified = true; - } - - common_init(); - - // struct that contains llama context and inference - server_context ctx_server; - - llama_backend_init(); - llama_numa_init(params.numa); - - LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); - LOG_INF("\n"); - LOG_INF("%s\n", common_params_get_system_info(params).c_str()); - LOG_INF("\n"); - - std::unique_ptr svr; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); - svr.reset( - new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) - ); - } else { - LOG_INF("Running without SSL\n"); - svr.reset(new httplib::Server()); - } -#else - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_ERR("Server is built without SSL support\n"); - return 1; - } - svr.reset(new httplib::Server()); -#endif - - std::atomic state{SERVER_STATE_LOADING_MODEL}; - - svr->set_default_headers({{"Server", "llama.cpp"}}); - svr->set_logger(log_server_request); - svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { - std::string message; - try { - std::rethrow_exception(ep); - } catch (const std::exception & e) { - message = e.what(); - } catch (...) 
{ - message = "Unknown Exception"; - } - - try { - json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); - LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); - res_err(res, formatted_error); - } catch (const std::exception & e) { - LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); - } - }); - - svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { - if (res.status == 404) { - res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); - } - // for other error codes, we skip processing here because it's already done by res_err() - }); - - // set timeouts and change hostname and port - svr->set_read_timeout (params.timeout_read); - svr->set_write_timeout(params.timeout_write); - - std::unordered_map log_data; - - log_data["hostname"] = params.hostname; - log_data["port"] = std::to_string(params.port); - - if (params.api_keys.size() == 1) { - auto key = params.api_keys[0]; - log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0)); - } else if (params.api_keys.size() > 1) { - log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; - } - - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - - // - // Middlewares - // - - auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { - static const std::unordered_set public_endpoints = { - "/health", - "/v1/health", - "/models", - "/v1/models", - "/api/tags" - }; - - // If API key is not set, skip validation - if (params.api_keys.empty()) { - return true; - } - - // If path is public or is static file, skip validation - if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { - return true; - } - - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); - - std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) { - return true; // API key is valid - } - } - - // API key is invalid or not provided - res_err(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); - - LOG_WRN("Unauthorized: Invalid API Key\n"); - - return false; - }; - - auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) { - server_state current_state = state.load(); - if (current_state == SERVER_STATE_LOADING_MODEL) { - auto tmp = string_split(req.path, '.'); - if (req.path == "/" || tmp.back() == "html") { - res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); - res.status = 503; - } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { - // allow the models endpoint to be accessed during loading - return true; - } else { - res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } - return false; - } - return true; - }; - - // register server middlewares - svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - // If this is OPTIONS request, skip validation because browsers don't include 
Authorization header - if (req.method == "OPTIONS") { - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "GET, POST"); - res.set_header("Access-Control-Allow-Headers", "*"); - res.set_content("", "text/html"); // blank response, no data - return httplib::Server::HandlerResponse::Handled; // skip further processing - } - if (!middleware_server_state(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - if (!middleware_validate_api_key(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - return httplib::Server::HandlerResponse::Unhandled; - }); - - // - // Route handlers (or controllers) - // - - const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_health = [this](const server_http_req &) { // error and loading states are handled by middleware - json health = {{"status", "ok"}}; - res_ok(res, health); + auto res = std::make_unique(ctx_server); + res->ok({{"status", "ok"}}); + return res; }; - const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { - if (!params.endpoint_slots) { - res_err(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - // request slots data using task queue - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // optionally return "fail_on_no_slot" error - if (req.has_param("fail_on_no_slot")) { - if (res_task->n_idle_slots == 0) { - res_err(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); - return; - } - } - - res_ok(res, res_task->slots_data); - }; - - const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_metrics = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); if (!params.endpoint_metrics) { - res_err(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); - return; + res->error(format_error_response("This server does not support metrics endpoint. 
Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } // request slots data using task queue + // TODO: use server_response_reader int task_id = ctx_server.queue_tasks.get_new_id(); { server_task task(SERVER_TASK_TYPE_METRICS); @@ -4765,8 +4541,8 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(task_id); if (result->is_error()) { - res_err(res, result->to_json()); - return; + res->error(result->to_json()); + return res; } // TODO: get rid of this dynamic_cast @@ -4840,130 +4616,86 @@ int main(int argc, char ** argv) { } } - res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start)); - - res.set_content(prometheus.str(), "text/plain; version=0.0.4"); - res.status = 200; // HTTP OK + res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); + res->content_type = "text/plain; version=0.0.4"; + res->ok(prometheus.str()); + return res; }; - const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; + server_http_context::handler_t get_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_slots) { + res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } - std::string filepath = params.slot_save_path + filename; + // request slots data using task queue int task_id = ctx_server.queue_tasks.get_new_id(); { - server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + server_task task(SERVER_TASK_TYPE_METRICS); task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task } + // get the result server_task_result_ptr result = ctx_server.queue_results.recv(task_id); ctx_server.queue_results.remove_waiting_task_id(task_id); if (result->is_error()) { - res_err(res, result->to_json()); - return; + res->error(result->to_json()); + return res; } - res_ok(res, result->to_json()); + // TODO: get rid of this dynamic_cast + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); + + // optionally return "fail_on_no_slot" error + if (!req.get_param("fail_on_no_slot").empty()) { + if (res_task->n_idle_slots == 0) { + res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return res; + } + } + + res->ok(res_task->slots_data); + return res; }; - const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath 
= filepath; - - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); - }; - - const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); - }; - - const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + server_http_context::handler_t post_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); if (params.slot_save_path.empty()) { - res_err(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); - return; + res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } - std::string id_slot_str = req.path_params.at("id_slot"); + std::string id_slot_str = req.get_param("id_slot"); int id_slot; try { id_slot = std::stoi(id_slot_str); } catch (const std::exception &) { - res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return res; } - std::string action = req.get_param_value("action"); + std::string action = req.get_param("action"); if (action == "save") { - handle_slots_save(req, res, id_slot); + return handle_slots_save(req, id_slot); } else if (action == "restore") { - handle_slots_restore(req, res, id_slot); + return handle_slots_restore(req, id_slot); } else if (action == "erase") { - handle_slots_erase(req, res, id_slot); + return handle_slots_erase(req, id_slot); } else { - res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + return res; } }; - const auto handle_props = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); json default_generation_settings_for_props; { @@ -5002,23 +4734,24 @@ int main(int argc, char ** argv) { } } - res_ok(res, data); + res->ok(data); + return res; }; - const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - if (!ctx_server.params_base.endpoint_props) { - res_err(res, format_error_response("This server does not support changing global properties. 
Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); - return; + server_http_context::handler_t post_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_props) { + res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } - - json data = json::parse(req.body); - // update any props here - res_ok(res, {{ "success", true }}); + res->ok({{ "success", true }}); + return res; }; - const auto handle_api_show = [&ctx_server](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_api_show = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); bool has_mtmd = ctx_server.mctx != nullptr; json data = { { @@ -5044,24 +4777,404 @@ int main(int argc, char ** argv) { {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})} }; - res_ok(res, data); + res->ok(data); + return res; }; - // handle completion-like requests (completion, chat, infill) - // we can optionally provide a custom format for partial results and final results - const auto handle_completions_impl = [&ctx_server]( - server_task_type type, - json & data, - const std::vector & files, - const std::function & is_connection_closed, - httplib::Response & res, - oaicompat_type oaicompat) -> void { + server_http_context::handler_t post_infill = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + // check model compatibility + std::string err; + if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "prefix token is missing. "; + } + if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "suffix token is missing. "; + } + if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "middle token is missing. 
"; + } + if (!err.empty()) { + res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // validate input + json data = json::parse(req.body); + if (data.contains("prompt") && !data.at("prompt").is_string()) { + // prompt is optional + res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_prefix")) { + res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_suffix")) { + res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (data.contains("input_extra") && !data.at("input_extra").is_array()) { + // input_extra is optional + res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + json input_extra = json_value(data, "input_extra", json::array()); + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + if (!chunk.contains("text") || !chunk.at("text").is_string()) { + res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + // filename is optional + if (chunk.contains("filename") && !chunk.at("filename").is_string()) { + res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + data["input_extra"] = input_extra; // default to empty array if it's not exist + + std::string prompt = json_value(data, "prompt", std::string()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); + SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + data["prompt"] = format_infill( + ctx_server.vocab, + data.at("input_prefix"), + data.at("input_suffix"), + data.at("input_extra"), + ctx_server.params_base.n_batch, + ctx_server.params_base.n_predict, + ctx_server.slots[0].n_ctx, // TODO: there should be a better way + ctx_server.params_base.spm_infill, + tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. 
+ ); + + std::vector files; // dummy + return handle_completions_impl( + SERVER_TASK_TYPE_INFILL, + data, + files, + req.should_stop, + OAICOMPAT_TYPE_NONE); // infill is not OAI compatible + }; + + server_http_context::handler_t post_completions = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + OAICOMPAT_TYPE_NONE); + }; + + server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + OAICOMPAT_TYPE_COMPLETION); + }; + + server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) { + std::vector files; + json body = json::parse(req.body); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + OAICOMPAT_TYPE_CHAT); + }; + + // same with handle_chat_completions, but without inference part + server_http_context::handler_t post_apply_template = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; // dummy, unused + json body = json::parse(req.body); + json data = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + res->ok({{ "prompt", std::move(data.at("prompt")) }}); + return res; + }; + + server_http_context::handler_t get_models = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + bool is_model_ready = ctx_http.is_ready.load(); + json model_meta = nullptr; + if (is_model_ready) { + model_meta = ctx_server.model_meta(); + } + bool has_mtmd = ctx_server.mctx != nullptr; + json models = { + {"models", { + { + {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"modified_at", ""}, + {"size", ""}, + {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash + {"type", "model"}, + {"description", ""}, + {"tags", {""}}, + {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, + {"parameters", ""}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }} + } + }}, + {"object", "list"}, + {"data", { + { + {"id", params.model_alias.empty() ? 
params.model.path : params.model_alias}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta}, + }, + }} + }; + + res->ok(models); + return res; + }; + + server_http_context::handler_t post_tokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + json tokens_response = json::array(); + if (body.count("content") != 0) { + const bool add_special = json_value(body, "add_special", false); + const bool parse_special = json_value(body, "parse_special", true); + const bool with_pieces = json_value(body, "with_pieces", false); + + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); + + if (with_pieces) { + for (const auto& token : tokens) { + std::string piece = common_token_to_piece(ctx_server.ctx, token); + json piece_json; + + // Check if the piece is valid UTF-8 + if (is_valid_utf8(piece)) { + piece_json = piece; + } else { + // If not valid UTF-8, store as array of byte values + piece_json = json::array(); + for (unsigned char c : piece) { + piece_json.push_back(static_cast(c)); + } + } + + tokens_response.push_back({ + {"id", token}, + {"piece", piece_json} + }); + } + } else { + tokens_response = tokens; + } + } + + const json data = format_tokenizer_response(tokens_response); + res->ok(data); + return res; + }; + + server_http_context::handler_t post_detokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + + std::string content; + if (body.count("tokens") != 0) { + const llama_tokens tokens = body.at("tokens"); + content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); + } + + const json data = format_detokenized_response(content); + res->ok(data); + return res; + }; + + server_http_context::handler_t post_embeddings = [this](const server_http_req & req) { + return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE); + }; + + server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) { + return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING); + }; + + server_http_context::handler_t post_rerank = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { + res->error(format_error_response("This server does not support reranking. 
Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + const json body = json::parse(req.body); + + // if true, use TEI API format, otherwise use Jina API format + // Jina: https://jina.ai/reranker/ + // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank + bool is_tei_format = body.contains("texts"); + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } else { + res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + std::vector documents = json_value(body, "documents", + json_value(body, "texts", std::vector())); + if (documents.empty()) { + res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int top_n = json_value(body, "top_n", (int)documents.size()); + + // create and queue the task + json responses = json::array(); + server_response_reader rd(ctx_server); + { + std::vector tasks; + tasks.reserve(documents.size()); + for (size_t i = 0; i < documents.size(); i++) { + auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tmp); + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); + + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = format_response_rerank( + body, + responses, + is_tei_format, + documents, + top_n); + + res->ok(root); + return res; + }; + + server_http_context::handler_t get_lora_adapters = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json result = json::array(); + const auto & loras = ctx_server.params_base.lora_adapters; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; + json entry = { + {"id", i}, + {"path", lora.path}, + {"scale", lora.scale}, + {"task_name", lora.task_name}, + {"prompt_prefix", lora.prompt_prefix}, + }; + std::string alora_invocation_string = ""; + const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); + std::vector alora_invocation_tokens; + if (n_alora_tokens) { + const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); + for (uint64_t i = 0; i < n_alora_tokens; ++i) { + alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); + alora_invocation_tokens.push_back(alora_tokens[i]); + } + entry["alora_invocation_string"] = alora_invocation_string; + entry["alora_invocation_tokens"] = alora_invocation_tokens; + } + result.push_back(std::move(entry)); + } + res->ok(result); + return res; + }; + + server_http_context::handler_t post_lora_adapters = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + if (!body.is_array()) { + 
res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = task_id; + task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; + }; + +private: + std::unique_ptr handle_completions_impl( + server_task_type type, + const json & data, + const std::vector & files, + const std::function & should_stop, + oaicompat_type oaicompat) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + auto res = std::make_unique(ctx_server); auto completion_id = gen_chatcmplid(); - // need to store the reader as a pointer, so that it won't be destroyed when the handle returns - // use shared_ptr as it's shared between the chunked_content_provider() and on_complete() - const auto rd = std::make_shared(ctx_server); + auto & rd = res->rd; try { std::vector tasks; @@ -5102,22 +5215,22 @@ int main(int argc, char ** argv) { tasks.push_back(std::move(task)); } - rd->post_tasks(std::move(tasks)); + rd.post_tasks(std::move(tasks)); } catch (const std::exception & e) { - res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return res; } bool stream = json_value(data, "stream", false); if (!stream) { // non-stream, wait for the results - auto all_results = rd->wait_for_all(is_connection_closed); + auto all_results = rd.wait_for_all(should_stop); if (all_results.is_terminated) { - return; // connection is closed + return res; // connection is closed } else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; + res->error(all_results.error->to_json()); + return res; } else { json arr = json::array(); for (auto & res : all_results.results) { @@ -5125,19 +5238,19 @@ int main(int argc, char ** argv) { arr.push_back(res->to_json()); } // if single request, return single object instead of array - res_ok(res, arr.size() == 1 ? arr[0] : arr); + res->ok(arr.size() == 1 ? 
arr[0] : arr); } } else { // in streaming mode, the first error must be treated as non-stream response // this is to match the OAI API behavior // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd->next(is_connection_closed); + server_task_result_ptr first_result = rd.next(should_stop); if (first_result == nullptr) { - return; // connection is closed + return res; // connection is closed } else if (first_result->is_error()) { - res_err(res, first_result->to_json()); - return; + res->error(first_result->to_json()); + return res; } else { GGML_ASSERT( dynamic_cast(first_result.get()) != nullptr @@ -5146,307 +5259,171 @@ int main(int argc, char ** argv) { } // next responses are streamed - json first_result_json = first_result->to_json(); - const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink & sink) mutable -> bool { - // flush the first result as it's not an error - if (!first_result_json.empty()) { - if (!server_sent_event(sink, first_result_json)) { - sink.done(); - return false; // sending failed, go to on_complete() + res->data = format_sse(first_result->to_json()); // to be sent immediately + res->status = 200; + res->content_type = "text/event-stream"; + res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool { + if (should_stop()) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + if (!res_this->data.empty()) { + // flush the first chunk + output = std::move(res_this->data); + res_this->data.clear(); + return true; + } + + server_response_reader & rd = res_this->rd; + + // check if there is more data + if (!rd.has_next()) { + if (oaicompat != OAICOMPAT_TYPE_NONE) { + output = "data: [DONE]\n\n"; + } else { + output = ""; } - first_result_json.clear(); // mark as sent + SRV_DBG("%s", "all results received, terminating stream\n"); + return false; // no more data, terminate } // receive subsequent results - auto result = rd->next([&sink]{ return !sink.is_writable(); }); + auto result = rd.next(should_stop); if (result == nullptr) { - sink.done(); - return false; // connection is closed, go to on_complete() + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met } // send the results json res_json = result->to_json(); - bool ok = false; if (result->is_error()) { - ok = server_sent_event(sink, json {{ "error", result->to_json() }}); - sink.done(); - return false; // go to on_complete() + output = format_sse(json {{ "error", res_json }}); + SRV_DBG("%s", "error received during streaming, terminating stream\n"); + return false; // terminate on error } else { GGML_ASSERT( dynamic_cast(result.get()) != nullptr || dynamic_cast(result.get()) != nullptr ); - ok = server_sent_event(sink, res_json); - } - - if (!ok) { - sink.done(); - return false; // sending failed, go to on_complete() - } - - // check if there is more data - if (!rd->has_next()) { - if (oaicompat != OAICOMPAT_TYPE_NONE) { - static const std::string ev_done = "data: [DONE]\n\n"; - sink.write(ev_done.data(), ev_done.size()); - } - sink.done(); - return false; // no more data, go to on_complete() + output = format_sse(res_json); } // has next data, continue return true; }; - - auto on_complete = [rd](bool) { - rd->stop(); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }; - - const auto 
handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - json data = json::parse(req.body); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_NONE); - }; - - const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - json data = oaicompat_completion_params_parse(json::parse(req.body)); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_COMPLETION); - }; - - const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - // check model compatibility - std::string err; - if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "prefix token is missing. "; - } - if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "suffix token is missing. "; - } - if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "middle token is missing. "; - } - if (!err.empty()) { - res_err(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); - return; } - json data = json::parse(req.body); + return res; + } - // validate input - if (data.contains("prompt") && !data.at("prompt").is_string()) { - // prompt is optional - res_err(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); } - if (!data.contains("input_prefix")) { - res_err(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; } - if (!data.contains("input_suffix")) { - res_err(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + res->ok(result->to_json()); + return res; + } + + std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = task_id; + 
task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); } - if (data.contains("input_extra") && !data.at("input_extra").is_array()) { - // input_extra is optional - res_err(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); - return; + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; } - json input_extra = json_value(data, "input_extra", json::array()); - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - if (!chunk.contains("text") || !chunk.at("text").is_string()) { - res_err(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); - return; - } - // filename is optional - if (chunk.contains("filename") && !chunk.at("filename").is_string()) { - res_err(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); - return; - } - } - data["input_extra"] = input_extra; // default to empty array if it's not exist + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; + } - std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); - SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - data["prompt"] = format_infill( - ctx_server.vocab, - data.at("input_prefix"), - data.at("input_suffix"), - data.at("input_extra"), - ctx_server.params_base.n_batch, - ctx_server.params_base.n_predict, - ctx_server.slots[0].n_ctx, // TODO: there should be a better way - ctx_server.params_base.spm_infill, - tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. 
- ); + std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot) { + auto res = std::make_unique(ctx_server); + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = task_id; + task.slot_action.slot_id = id_slot; - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_INFILL, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_NONE); // infill is not OAI compatible - }; - - const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - LOG_DBG("request: %s\n", req.body.c_str()); - - auto body = json::parse(req.body); - std::vector files; - json data = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_CHAT); - }; - - // same with handle_chat_completions, but without inference part - const auto handle_apply_template = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - auto body = json::parse(req.body); - std::vector files; // dummy, unused - json data = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); - }; - - const auto handle_models = [¶ms, &ctx_server, &state](const httplib::Request &, httplib::Response & res) { - server_state current_state = state.load(); - json model_meta = nullptr; - if (current_state == SERVER_STATE_READY) { - model_meta = ctx_server.model_meta(); - } - bool has_mtmd = ctx_server.mctx != nullptr; - json models = { - {"models", { - { - {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"modified_at", ""}, - {"size", ""}, - {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash - {"type", "model"}, - {"description", ""}, - {"tags", {""}}, - {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, - {"parameters", ""}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }} - } - }}, - {"object", "list"}, - {"data", { - { - {"id", params.model_alias.empty() ? 
params.model.path : params.model_alias}, - {"object", "model"}, - {"created", std::time(0)}, - {"owned_by", "llamacpp"}, - {"meta", model_meta}, - }, - }} - }; - - res_ok(res, models); - }; - - const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - const json body = json::parse(req.body); - - json tokens_response = json::array(); - if (body.count("content") != 0) { - const bool add_special = json_value(body, "add_special", false); - const bool parse_special = json_value(body, "parse_special", true); - const bool with_pieces = json_value(body, "with_pieces", false); - - llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); - - if (with_pieces) { - for (const auto& token : tokens) { - std::string piece = common_token_to_piece(ctx_server.ctx, token); - json piece_json; - - // Check if the piece is valid UTF-8 - if (is_valid_utf8(piece)) { - piece_json = piece; - } else { - // If not valid UTF-8, store as array of byte values - piece_json = json::array(); - for (unsigned char c : piece) { - piece_json.push_back(static_cast(c)); - } - } - - tokens_response.push_back({ - {"id", token}, - {"piece", piece_json} - }); - } - } else { - tokens_response = tokens; - } + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); } - const json data = format_tokenizer_response(tokens_response); - res_ok(res, data); - }; + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); - const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - const json body = json::parse(req.body); - - std::string content; - if (body.count("tokens") != 0) { - const llama_tokens tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); + if (result->is_error()) { + res->error(result->to_json()); + return res; } - const json data = format_detokenized_response(content); - res_ok(res, data); - }; + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; + } - const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) { + std::unique_ptr handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) { + auto res = std::make_unique(ctx_server); if (!ctx_server.params_base.embedding) { - res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return; + res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("Pooling type 'none' is not OAI compatible. 
Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return res; } const json body = json::parse(req.body); @@ -5459,8 +5436,8 @@ int main(int argc, char ** argv) { oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible prompt = body.at("content"); } else { - res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; } bool use_base64 = false; @@ -5469,8 +5446,8 @@ int main(int argc, char ** argv) { if (format == "base64") { use_base64 = true; } else if (format != "float") { - res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return res; } } @@ -5478,8 +5455,8 @@ int main(int argc, char ** argv) { for (const auto & tokens : tokenized_prompts) { // this check is necessary for models that do not add BOS token to the input if (tokens.empty()) { - res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return res; } } @@ -5513,14 +5490,14 @@ int main(int argc, char ** argv) { } // wait for the results - auto all_results = rd.wait_for_all(req.is_connection_closed); + auto all_results = rd.wait_for_all(req.should_stop); // collect results if (all_results.is_terminated) { - return; // connection is closed + return res; // connection is closed } else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; + res->error(all_results.error->to_json()); + return res; } else { for (auto & res : all_results.results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); @@ -5532,292 +5509,170 @@ int main(int argc, char ** argv) { json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses); - res_ok(res, root); + res->ok(root); + return res; + } +}; + +std::function shutdown_handler; +std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; + +inline void signal_handler(int signal) { + if (is_terminating.test_and_set()) { + // in case it hangs, we can force terminate the server by hitting Ctrl+C twice + // this is for better developer experience, we can remove when the server is stable enough + fprintf(stderr, "Received second interrupt, terminating immediately.\n"); + exit(1); + } + + shutdown_handler(signal); +} + +// wrapper function that handles exceptions and logs errors +// this is to make sure handler_t never throws exceptions; instead, it returns an error response +static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) { + return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr { + std::string message; + try { + return func(req); + } catch (const std::exception & e) { + message = e.what(); + } catch (...) 
{
+            message = "unknown error";
+        }
+
+        auto res = std::make_unique();
+        res->status = 500;
+        try {
+            json error_data = format_error_response(message, ERROR_TYPE_SERVER);
+            res->status = json_value(error_data, "code", 500);
+            res->data = safe_json_to_str({{ "error", error_data }});
+            LOG_WRN("got exception: %s\n", res->data.c_str());
+        } catch (const std::exception & e) {
+            LOG_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
+            res->data = "Internal Server Error";
+        }
+        return res;
     };
+}
-    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
-    };
+int main(int argc, char ** argv) {
+    // own arguments required by this example
+    common_params params;
-    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
-    };
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
+        return 1;
+    }
-    const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
-            res_err(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }
+    // TODO: should we have a separate n_parallel parameter for the server?
+    // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+    // TODO: this is a common configuration that is suitable for most local use cases
+    // however, overriding the parameters is a bit confusing - figure out something more intuitive
+    if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);
-        const json body = json::parse(req.body);
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
-        // if true, use TEI API format, otherwise use Jina API format
-        // Jina: https://jina.ai/reranker/
-        // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
-        bool is_tei_format = body.contains("texts");
+    common_init();
-        json query;
-        if (body.count("query") == 1) {
-            query = body.at("query");
-            if (!query.is_string()) {
-                res_err(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
-                return;
-            }
-        } else {
-            res_err(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
-            return;
-        }
+    // struct that contains llama context and inference
+    server_context ctx_server;
-        std::vector<std::string> documents = json_value(body, "documents",
-            json_value(body, "texts", std::vector<std::string>()));
-        if (documents.empty()) {
-            res_err(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
-            return;
-        }
+    // Necessary similarity of prompt for slot selection
+    ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
-        int top_n = json_value(body, "top_n", (int)documents.size());
+    llama_backend_init();
+    llama_numa_init(params.numa);
-        // create and queue the task
-        json responses = json::array();
-        server_response_reader rd(ctx_server);
-        {
-            std::vector<server_task> tasks;
-            tasks.reserve(documents.size());
-            for (size_t i = 0; i < documents.size(); i++) {
-                auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, 
ctx_server.mctx, query, documents[i]); - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tmp); - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } + LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); - // wait for the results - auto all_results = rd.wait_for_all(req.is_connection_closed); - - // collect results - if (all_results.is_terminated) { - return; // connection is closed - } else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = format_response_rerank( - body, - responses, - is_tei_format, - documents, - top_n); - - res_ok(res, root); - }; - - const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { - json result = json::array(); - const auto & loras = ctx_server.params_base.lora_adapters; - for (size_t i = 0; i < loras.size(); ++i) { - auto & lora = loras[i]; - json entry = { - {"id", i}, - {"path", lora.path}, - {"scale", lora.scale}, - {"task_name", lora.task_name}, - {"prompt_prefix", lora.prompt_prefix}, - }; - std::string alora_invocation_string = ""; - const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); - std::vector alora_invocation_tokens; - if (n_alora_tokens) { - const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); - for (uint64_t i = 0; i < n_alora_tokens; ++i) { - alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); - alora_invocation_tokens.push_back(alora_tokens[i]); - } - entry["alora_invocation_string"] = alora_invocation_string; - entry["alora_invocation_tokens"] = alora_invocation_tokens; - } - result.push_back(std::move(entry)); - } - res_ok(res, result); - res.status = 200; // HTTP OK - }; - - const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - const json body = json::parse(req.body); - if (!body.is_array()) { - res_err(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); - return; - } - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SET_LORA); - task.id = task_id; - task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); - }; + server_http_context ctx_http; + if (!ctx_http.init(params)) { + LOG_ERR("%s: failed to initialize HTTP server\n", __func__); + return 1; + } // // Router // - if (!params.webui) { - LOG_INF("Web UI is disabled\n"); - } else { - // register static assets routes - if (!params.public_path.empty()) { - // Set the base directory for serving static 
files - bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path); - if (!is_found) { - LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); - return 1; - } - } else { - // using embedded static index.html - svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) { - if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { - res.set_content("Error: gzip is not supported by this browser", "text/plain"); - } else { - res.set_header("Content-Encoding", "gzip"); - // COEP and COOP headers, required by pyodide (python interpreter) - res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); - res.set_header("Cross-Origin-Opener-Policy", "same-origin"); - res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); - } - return false; - }); - } - } - // register API routes - svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check) - svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check) - svr->Get (params.api_prefix + "/metrics", handle_metrics); - svr->Get (params.api_prefix + "/props", handle_props); - svr->Post(params.api_prefix + "/props", handle_props_change); - svr->Post(params.api_prefix + "/api/show", handle_api_show); - svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check) - svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check) - svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check) - svr->Post(params.api_prefix + "/completion", handle_completions); // legacy - svr->Post(params.api_prefix + "/completions", handle_completions); - svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai); - svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions); - svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions); - svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint - svr->Post(params.api_prefix + "/infill", handle_infill); - svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy - svr->Post(params.api_prefix + "/embeddings", handle_embeddings); - svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai); - svr->Post(params.api_prefix + "/rerank", handle_rerank); - svr->Post(params.api_prefix + "/reranking", handle_rerank); - svr->Post(params.api_prefix + "/v1/rerank", handle_rerank); - svr->Post(params.api_prefix + "/v1/reranking", handle_rerank); - svr->Post(params.api_prefix + "/tokenize", handle_tokenize); - svr->Post(params.api_prefix + "/detokenize", handle_detokenize); - svr->Post(params.api_prefix + "/apply-template", handle_apply_template); + server_routes routes(params, ctx_server, ctx_http); + + ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) + ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) + ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); + ctx_http.get ("/props", ex_wrapper(routes.get_props)); + ctx_http.post("/props", ex_wrapper(routes.post_props)); + ctx_http.post("/api/show", ex_wrapper(routes.get_api_show)); + ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) + ctx_http.get ("/v1/models", 
ex_wrapper(routes.get_models)); // public endpoint (no API key check) + ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check) + ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy + ctx_http.post("/completions", ex_wrapper(routes.post_completions)); + ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai)); + ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); + ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); + ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint + ctx_http.post("/infill", ex_wrapper(routes.post_infill)); + ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy + ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); + ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai)); + ctx_http.post("/rerank", ex_wrapper(routes.post_rerank)); + ctx_http.post("/reranking", ex_wrapper(routes.post_rerank)); + ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank)); + ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank)); + ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize)); + ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize)); + ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template)); // LoRA adapters hotswap - svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list); - svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply); + ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters)); + ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters)); // Save & load slots - svr->Get (params.api_prefix + "/slots", handle_slots); - svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action); + ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); + ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); // // Start the server // - if (params.n_threads_http < 1) { - // +2 threads for monitoring endpoints - params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); - } - log_data["n_threads_http"] = std::to_string(params.n_threads_http); - svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); }; - // clean up function, to be called before exit - auto clean_up = [&svr, &ctx_server]() { + // setup clean up function, to be called before exit + auto clean_up = [&ctx_http, &ctx_server]() { SRV_INF("%s: cleaning up before exit...\n", __func__); - svr->stop(); + ctx_http.stop(); ctx_server.queue_results.terminate(); llama_backend_free(); }; - bool was_bound = false; - bool is_sock = false; - if (string_ends_with(std::string(params.hostname), ".sock")) { - is_sock = true; - LOG_INF("%s: setting address family to AF_UNIX\n", __func__); - svr->set_address_family(AF_UNIX); - // bind_to_port requires a second arg, any value other than 0 should - // simply get ignored - was_bound = svr->bind_to_port(params.hostname, 8080); - } else { - LOG_INF("%s: binding port with default address family\n", __func__); - // bind HTTP listen port - if (params.port == 0) { - int bound_port = svr->bind_to_any_port(params.hostname); - if ((was_bound = (bound_port >= 0))) { - params.port = bound_port; - } - } else { - was_bound = svr->bind_to_port(params.hostname, params.port); - } - } - - if (!was_bound) { - LOG_ERR("%s: couldn't bind HTTP server socket, 
hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port); + // start the HTTP server before loading the model to be able to serve /health requests + if (!ctx_http.start()) { clean_up(); + LOG_ERR("%s: exiting due to HTTP server error\n", __func__); return 1; } - // run the HTTP server in a thread - std::thread t([&]() { svr->listen_after_bind(); }); - svr->wait_until_ready(); - - LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http); - // load the model LOG_INF("%s: loading model\n", __func__); if (!ctx_server.load_model(params)) { clean_up(); - t.join(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } LOG_ERR("%s: exiting due to model loading error\n", __func__); return 1; } ctx_server.init(); - state.store(SERVER_STATE_READY); + ctx_http.is_ready.store(true); LOG_INF("%s: model loaded\n", __func__); - // print sample chat example to make it clear which template is used - LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - common_chat_templates_source(ctx_server.chat_templates.get()), - common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, ctx_server.params_base.default_template_kwargs).c_str()); - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { ctx_server.process_single_task(std::move(task)); }); @@ -5845,15 +5700,15 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__, - is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() : - string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str()); - + LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); + LOG_INF("%s: starting the main loop...\n", __func__); // this call blocks the main thread until queue_tasks.terminate() is called ctx_server.queue_tasks.start_loop(); clean_up(); - t.join(); + if (ctx_http.thread.joinable()) { + ctx_http.thread.join(); + } llama_memory_breakdown_print(ctx_server.ctx); return 0; diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index b1ecc5af5e..bf21726051 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -9,8 +9,6 @@ #include "mtmd-helper.h" #include "chat.h" -#include - #define JSON_ASSERT GGML_ASSERT #include @@ -426,6 +424,10 @@ static std::string gen_tool_call_id() { // other common utils // +static std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +} + // TODO: reuse llama_detokenize template static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { @@ -453,29 +455,25 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx, return out; } +// format server-sent event (SSE), return the formatted string to send // note: if data is a json array, it will be sent as multiple events, one per item -static bool server_sent_event(httplib::DataSink & sink, const json & data) { - static auto send_single = [](httplib::DataSink & sink, const json & data) -> bool { - const std::string str = - "data: " + - data.dump(-1, ' ', false, json::error_handler_t::replace) + +static std::string format_sse(const json & data) { + std::ostringstream ss; + auto send_single = [&ss](const json & data) { + ss << "data: " << + safe_json_to_str(data) << "\n\n"; // 
required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). - - LOG_DBG("data stream, to_send: %s", str.c_str()); - return sink.write(str.c_str(), str.size()); }; if (data.is_array()) { for (const auto & item : data) { - if (!send_single(sink, item)) { - return false; - } + send_single(item); } } else { - return send_single(sink, data); + send_single(data); } - return true; + return ss.str(); } // @@ -954,10 +952,6 @@ static json format_logit_bias(const std::vector & logit_bias) return data; } -static std::string safe_json_to_str(const json & data) { - return data.dump(-1, ' ', false, json::error_handler_t::replace); -} - static std::vector get_token_probabilities(llama_context * ctx, int idx) { std::vector cur; const auto * logits = llama_get_logits_ith(ctx, idx); diff --git a/tools/server/webui/.gitignore b/tools/server/webui/.gitignore index cc54bb717f..051d884b08 100644 --- a/tools/server/webui/.gitignore +++ b/tools/server/webui/.gitignore @@ -25,3 +25,4 @@ vite.config.ts.timestamp-* *storybook.log storybook-static +*.code-workspace \ No newline at end of file diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index a11b87ad50..4af5e86ab9 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -2109,9 +2109,9 @@ } }, "node_modules/@sveltejs/kit": { - "version": "2.48.4", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.4.tgz", - "integrity": "sha512-TGFX1pZUt9qqY20Cv5NyYvy0iLWHf2jXi8s+eCGsig7jQMdwZWKUFMR6TbvFNhfDSUpc1sH/Y5EHv20g3HHA3g==", + "version": "2.48.5", + "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.48.5.tgz", + "integrity": "sha512-/rnwfSWS3qwUSzvHynUTORF9xSJi7PCR9yXkxUOnRrNqyKmCmh3FPHH+E9BbgqxXfTevGXBqgnlh9kMb+9T5XA==", "dev": true, "license": "MIT", "dependencies": { @@ -5087,9 +5087,9 @@ "license": "MIT" }, "node_modules/js-yaml": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", - "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz", + "integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte index e47a5a7dba..ae0dc2ed9f 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte @@ -10,6 +10,7 @@ class?: string; message: DatabaseMessage; onCopy?: (message: DatabaseMessage) => void; + onContinueAssistantMessage?: (message: DatabaseMessage) => void; onDelete?: (message: DatabaseMessage) => void; onEditWithBranching?: (message: DatabaseMessage, newContent: string) => void; onEditWithReplacement?: ( @@ -17,6 +18,7 @@ newContent: string, shouldBranch: boolean ) => void; + onEditUserMessagePreserveResponses?: (message: DatabaseMessage, newContent: string) => void; onNavigateToSibling?: (siblingId: string) => void; onRegenerateWithBranching?: (message: DatabaseMessage) => void; siblingInfo?: ChatMessageSiblingInfo | null; @@ -26,9 +28,11 @@ class: className = '', message, onCopy, + onContinueAssistantMessage, onDelete, onEditWithBranching, 
onEditWithReplacement, + onEditUserMessagePreserveResponses, onNavigateToSibling, onRegenerateWithBranching, siblingInfo = null @@ -133,17 +137,33 @@ onRegenerateWithBranching?.(message); } + function handleContinue() { + onContinueAssistantMessage?.(message); + } + function handleSaveEdit() { if (message.role === 'user') { + // For user messages, trim to avoid accidental whitespace onEditWithBranching?.(message, editedContent.trim()); } else { - onEditWithReplacement?.(message, editedContent.trim(), shouldBranchAfterEdit); + // For assistant messages, preserve exact content including trailing whitespace + // This is important for the Continue feature to work properly + onEditWithReplacement?.(message, editedContent, shouldBranchAfterEdit); } isEditing = false; shouldBranchAfterEdit = false; } + function handleSaveEditOnly() { + if (message.role === 'user') { + // For user messages, trim to avoid accidental whitespace + onEditUserMessagePreserveResponses?.(message, editedContent.trim()); + } + + isEditing = false; + } + function handleShowDeleteDialogChange(show: boolean) { showDeleteDialog = show; } @@ -166,6 +186,7 @@ onEditedContentChange={handleEditedContentChange} {onNavigateToSibling} onSaveEdit={handleSaveEdit} + onSaveEditOnly={handleSaveEditOnly} onShowDeleteDialogChange={handleShowDeleteDialogChange} {showDeleteDialog} {siblingInfo} @@ -181,6 +202,7 @@ messageContent={message.content} onCancelEdit={handleCancelEdit} onConfirmDelete={handleConfirmDelete} + onContinue={handleContinue} onCopy={handleCopy} onDelete={handleDelete} onEdit={handleEdit} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte index c16a3105cb..d37d806514 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte @@ -1,5 +1,5 @@