diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 36201281f0..c2c6ea12ae 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -291,6 +291,7 @@ jobs: -DGGML_RVV=ON \ -DGGML_RV_ZFH=ON \ -DGGML_RV_ZICBOP=ON \ + -DGGML_RV_ZIHINTPAUSE=ON \ -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \ -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3668e4e2c9..77aec20c11 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -546,6 +546,8 @@ jobs: cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl-ls.exe" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libsycl-fallback-bfloat16.spv" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libsycl-native-bfloat16.spv" ./build/bin cp "${{ env.ONEAPI_ROOT }}/dnnl/latest/bin/dnnl.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e4f05258db..4545ff8f9a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,6 +15,7 @@ The project differentiates between 3 levels of contributors: - If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends) - If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops` - Create separate PRs for each feature or fix. Avoid combining unrelated changes in a single PR +- When adding support for a new model or feature, focus on **CPU support only** in the initial PR unless you have a good reason not to. Add support for other backends like CUDA in follow-up PRs - Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly - If your PR becomes stale, rebase it on top of latest `master` to get maintainers attention - Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR diff --git a/README.md b/README.md index eac8d66cc2..7dd2bfd8a1 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ range of hardware - locally and in the cloud. - Plain C/C++ implementation without any dependencies - Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks - AVX, AVX2, AVX512 and AMX support for x86 architectures -- RVV, ZVFH, ZFH and ZICBOP support for RISC-V architectures +- RVV, ZVFH, ZFH, ZICBOP and ZIHINTPAUSE support for RISC-V architectures - 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use - Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA) - Vulkan and SYCL backend support diff --git a/common/chat-parser-xml-toolcall.cpp b/common/chat-parser-xml-toolcall.cpp index 7349895550..a80900ff8d 100644 --- a/common/chat-parser-xml-toolcall.cpp +++ b/common/chat-parser-xml-toolcall.cpp @@ -724,16 +724,10 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons if (reasoning_unclosed) { if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) { unclosed_reasoning_content += content; - if (form.allow_toolcall_in_think) { - builder.move_to(tc->groups[0].begin); - if (!builder.try_consume_xml_tool_calls(form)) { - unclosed_reasoning_content += tool_call_start; - builder.move_to(tc->groups[0].end); - } - } else { + if (!(form.allow_toolcall_in_think && tc)) { unclosed_reasoning_content += tool_call_start; + continue; } - continue; } else { reasoning_unclosed = false; std::string reasoning_content; @@ -781,8 +775,12 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons } } else { // This start is in thinking block, skip this tool call - auto pos = think_start + start_think.size(); - unclosed_reasoning_content = content.substr(pos) + tool_call_start; + // This start is in thinking block + if (form.allow_toolcall_in_think) { + unclosed_reasoning_content = content.substr(think_start + start_think.size()); + } else { + unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start; + } reasoning_unclosed = true; content.resize(think_start); toolcall_in_think = true; @@ -805,14 +803,35 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons } // remove potential partial suffix - if (content.size() > 0 && builder.pos() == builder.input().size() && unclosed_reasoning_content.empty()) { - rstrip(content); - trim_potential_partial_word(content); - rstrip(content); + if (builder.pos() == builder.input().size()) { + if (unclosed_reasoning_content.empty()) { + rstrip(content); + trim_potential_partial_word(content); + rstrip(content); + } else { + rstrip(unclosed_reasoning_content); + trim_potential_partial_word(unclosed_reasoning_content); + rstrip(unclosed_reasoning_content); + } + } + + // consume unclosed_reasoning_content if allow_toolcall_in_think is set + if (form.allow_toolcall_in_think && !unclosed_reasoning_content.empty()) { + if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { + builder.add_reasoning_content(unclosed_reasoning_content); + } else { + if (content.empty()) { + content = start_think + unclosed_reasoning_content; + } else { + content += "\n\n" + start_think; + content += unclosed_reasoning_content; + } + } + unclosed_reasoning_content.clear(); } // Add content - if (content.size() != 0) { + if (!content.empty()) { // If there are multiple content blocks if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) { builder.add_content("\n\n"); @@ -820,7 +839,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons builder.add_content(content); } - // This start is in thinking block, skip this tool call + // This start is in thinking block and toolcall_in_think not set, skip this tool call if (toolcall_in_think && !form.allow_toolcall_in_think) { continue; } @@ -829,7 +848,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons if (!tc) { GGML_ASSERT(builder.pos() == builder.input().size()); GGML_ASSERT(unclosed_reasoning_content.empty()); - GGML_ASSERT(!reasoning_unclosed); + if (!form.allow_toolcall_in_think) GGML_ASSERT(!reasoning_unclosed); break; } @@ -854,7 +873,6 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons /** * Parse content uses reasoning and XML-Style tool call - * TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed. */ void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) { parse_msg_with_xml_tool_calls(*this, form, start_think, end_think); diff --git a/common/chat-parser-xml-toolcall.h b/common/chat-parser-xml-toolcall.h index 67face2b94..b309fb6670 100644 --- a/common/chat-parser-xml-toolcall.h +++ b/common/chat-parser-xml-toolcall.h @@ -31,7 +31,7 @@ struct xml_tool_call_format { std::optional last_val_end = std::nullopt; std::optional last_tool_end = std::nullopt; bool trim_raw_argval = false; - bool allow_toolcall_in_think = false; // TODO: UNTESTED!!! + bool allow_toolcall_in_think = false; }; // make a GBNF that accept any strings except those containing any of the forbidden strings. diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index fe3e80037f..d740dac065 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -917,12 +917,13 @@ static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) { form.tool_start = "<|tool_call_begin|>"; form.tool_sep = "<|tool_call_argument_begin|>{"; form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; + form.key_val_sep = "\":"; + form.val_end = ","; form.tool_end = "}<|tool_call_end|>"; form.scope_end = "<|tool_calls_section_end|>"; form.raw_argval = false; form.last_val_end = ""; + form.allow_toolcall_in_think = true; return form; })(); builder.consume_reasoning_with_xml_tool_calls(form, "", ""); diff --git a/common/chat.cpp b/common/chat.cpp index 41a5bb42d5..c371edaa5a 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,5 +1,6 @@ #include "chat.h" #include "chat-parser.h" +#include "chat-peg-parser.h" #include "common.h" #include "json-partial.h" #include "json-schema-to-grammar.h" @@ -150,6 +151,7 @@ struct templates_params { common_chat_tool_choice tool_choice; json json_schema; bool parallel_tool_calls; + common_reasoning_format reasoning_format; bool stream; std::string grammar; bool add_generation_prompt = true; @@ -589,6 +591,16 @@ common_chat_templates_ptr common_chat_templates_init( "{%- if false %}"); } + // TODO @aldehir : this is a temporary fix, pending Minja changes + // Ref: https://github.com/ggml-org/llama.cpp/pull/17713#issuecomment-3631342664 + if (default_template_src.find("[TOOL_CALLS]") != std::string::npos + // search for the error message and patch it + && default_template_src.find("if (message['content'] is none or") != std::string::npos) { + string_replace_all(default_template_src, + "{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}", + "{%- if false %}"); + } + std::string token_bos = bos_token_override; std::string token_eos = eos_token_override; bool add_bos = false; @@ -987,6 +999,118 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat return data; } +static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja + auto adjusted_messages = json::array(); + for (const auto & msg : inputs.messages) { + auto role = msg.value("role", ""); + if (role != "system" && role != "assistant") { + // Only adjust system and assistant messages. Interestingly, the system message may contain thinking. + adjusted_messages.push_back(msg); + continue; + } + + auto content = json::array(); + + // If message contains `reasoning_content`, add it as a block of type `thinking` + if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { + content.push_back({ + {"type", "thinking"}, + {"thinking", msg.at("reasoning_content").get()}, + }); + } + + // If message contains `content`, add it as a block of type `text` + if (msg.contains("content")) { + if (msg.at("content").is_string()) { + content.push_back({ + {"type", "text"}, + {"text", msg.at("content").get()}, + }); + } else if (msg.at("content").is_array()) { + auto blocks = msg.at("content"); + content.insert(content.end(), blocks.begin(), blocks.end()); + } + } + + auto adjusted = msg; + adjusted["content"] = content; + adjusted.erase("reasoning_content"); + adjusted_messages.push_back(adjusted); + } + + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + auto include_grammar = true; + + data.prompt = apply(tmpl, inputs, /* messages_override = */ adjusted_messages); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = { + "[THINK]", + "[/THINK]", + "[TOOL_CALLS]", + "[ARGS]", + }; + + auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); + + // Response format parser + if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { + // Ministral wants to emit json surrounded by code fences + return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```"; + } + + // Tool call parser + if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { + auto tool_choice = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + tool_choice |= p.rule("tool-" + name, + p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]") + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) + ); + }); + + auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; + auto max_calls = inputs.parallel_tool_calls ? -1 : 1; + auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); + + return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls; + } + + // Content only parser + include_grammar = false; + return reasoning << p.content(p.rest()); + }); + + data.parser = parser.save(); + + if (include_grammar) { + data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + auto schema = function.at("parameters"); + builder.resolve_refs(schema); + }); + parser.build_grammar(builder, data.grammar_lazy); + }); + + data.grammar_triggers = { + {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"} + }; + } + + return data; +} + static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2341,6 +2465,7 @@ static common_chat_params common_chat_templates_apply_jinja( params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); params.add_generation_prompt = inputs.add_generation_prompt; params.tool_choice = inputs.tool_choice; + params.reasoning_format = inputs.reasoning_format; params.enable_thinking = inputs.enable_thinking; params.grammar = inputs.grammar; params.now = inputs.now; @@ -2504,6 +2629,13 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools); } + // Ministral/Mistral Large 3 + if (src.find("[SYSTEM_PROMPT]") != std::string::npos && + src.find("[TOOL_CALLS]") != std::string::npos && + src.find("[ARGS]") != std::string::npos) { + return common_chat_params_init_ministral_3(tmpl, params); + } + if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) { return common_chat_params_init_magistral(tmpl, params); } diff --git a/common/console.cpp b/common/console.cpp index 078a8d678d..5e9901e4a2 100644 --- a/common/console.cpp +++ b/common/console.cpp @@ -1,6 +1,11 @@ #include "console.h" #include #include +#include +#include +#include +#include +#include #if defined(_WIN32) #define WIN32_LEAN_AND_MEAN @@ -35,9 +40,26 @@ namespace console { +#if defined (_WIN32) + namespace { + // Use private-use unicode values to represent special keys that are not reported + // as characters (e.g. arrows on Windows). These values should never clash with + // real input and let the rest of the code handle navigation uniformly. + static constexpr char32_t KEY_ARROW_LEFT = 0xE000; + static constexpr char32_t KEY_ARROW_RIGHT = 0xE001; + static constexpr char32_t KEY_ARROW_UP = 0xE002; + static constexpr char32_t KEY_ARROW_DOWN = 0xE003; + static constexpr char32_t KEY_HOME = 0xE004; + static constexpr char32_t KEY_END = 0xE005; + static constexpr char32_t KEY_CTRL_ARROW_LEFT = 0xE006; + static constexpr char32_t KEY_CTRL_ARROW_RIGHT = 0xE007; + static constexpr char32_t KEY_DELETE = 0xE008; + } + // // Console state // +#endif static bool advanced_display = false; static bool simple_io = true; @@ -176,7 +198,18 @@ namespace console { if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) { wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; if (wc == 0) { - continue; + const DWORD ctrl_mask = LEFT_CTRL_PRESSED | RIGHT_CTRL_PRESSED; + const bool ctrl_pressed = (record.Event.KeyEvent.dwControlKeyState & ctrl_mask) != 0; + switch (record.Event.KeyEvent.wVirtualKeyCode) { + case VK_LEFT: return ctrl_pressed ? KEY_CTRL_ARROW_LEFT : KEY_ARROW_LEFT; + case VK_RIGHT: return ctrl_pressed ? KEY_CTRL_ARROW_RIGHT : KEY_ARROW_RIGHT; + case VK_UP: return KEY_ARROW_UP; + case VK_DOWN: return KEY_ARROW_DOWN; + case VK_HOME: return KEY_HOME; + case VK_END: return KEY_END; + case VK_DELETE: return KEY_DELETE; + default: continue; + } } if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate @@ -315,6 +348,52 @@ namespace console { #endif } + static char32_t decode_utf8(const std::string & input, size_t pos, size_t & advance) { + unsigned char c = static_cast(input[pos]); + if ((c & 0x80u) == 0u) { + advance = 1; + return c; + } + if ((c & 0xE0u) == 0xC0u && pos + 1 < input.size()) { + unsigned char c1 = static_cast(input[pos + 1]); + if ((c1 & 0xC0u) != 0x80u) { + advance = 1; + return 0xFFFD; + } + advance = 2; + return ((c & 0x1Fu) << 6) | (static_cast(input[pos + 1]) & 0x3Fu); + } + if ((c & 0xF0u) == 0xE0u && pos + 2 < input.size()) { + unsigned char c1 = static_cast(input[pos + 1]); + unsigned char c2 = static_cast(input[pos + 2]); + if ((c1 & 0xC0u) != 0x80u || (c2 & 0xC0u) != 0x80u) { + advance = 1; + return 0xFFFD; + } + advance = 3; + return ((c & 0x0Fu) << 12) | + ((static_cast(input[pos + 1]) & 0x3Fu) << 6) | + (static_cast(input[pos + 2]) & 0x3Fu); + } + if ((c & 0xF8u) == 0xF0u && pos + 3 < input.size()) { + unsigned char c1 = static_cast(input[pos + 1]); + unsigned char c2 = static_cast(input[pos + 2]); + unsigned char c3 = static_cast(input[pos + 3]); + if ((c1 & 0xC0u) != 0x80u || (c2 & 0xC0u) != 0x80u || (c3 & 0xC0u) != 0x80u) { + advance = 1; + return 0xFFFD; + } + advance = 4; + return ((c & 0x07u) << 18) | + ((static_cast(input[pos + 1]) & 0x3Fu) << 12) | + ((static_cast(input[pos + 2]) & 0x3Fu) << 6) | + (static_cast(input[pos + 3]) & 0x3Fu); + } + + advance = 1; + return 0xFFFD; // replacement character for invalid input + } + static void append_utf8(char32_t ch, std::string & out) { if (ch <= 0x7F) { out.push_back(static_cast(ch)); @@ -336,22 +415,319 @@ namespace console { } // Helper function to remove the last UTF-8 character from a string - static void pop_back_utf8_char(std::string & line) { - if (line.empty()) { + static size_t prev_utf8_char_pos(const std::string & line, size_t pos) { + if (pos == 0) return 0; + pos--; + while (pos > 0 && (line[pos] & 0xC0) == 0x80) { + pos--; + } + return pos; + } + + static size_t next_utf8_char_pos(const std::string & line, size_t pos) { + if (pos >= line.length()) return line.length(); + pos++; + while (pos < line.length() && (line[pos] & 0xC0) == 0x80) { + pos++; + } + return pos; + } + + static void move_cursor(int delta); + static void move_word_left(size_t & char_pos, size_t & byte_pos, const std::vector & widths, const std::string & line); + static void move_word_right(size_t & char_pos, size_t & byte_pos, const std::vector & widths, const std::string & line); + static void move_to_line_start(size_t & char_pos, size_t & byte_pos, const std::vector & widths); + static void move_to_line_end(size_t & char_pos, size_t & byte_pos, const std::vector & widths, const std::string & line); + + static void delete_at_cursor(std::string & line, std::vector & widths, size_t & char_pos, size_t & byte_pos) { + if (char_pos >= widths.size()) { return; } - size_t pos = line.length() - 1; + size_t next_pos = next_utf8_char_pos(line, byte_pos); + int w = widths[char_pos]; + size_t char_len = next_pos - byte_pos; - // Find the start of the last UTF-8 character (checking up to 4 bytes back) - for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { - if ((line[pos] & 0xC0) != 0x80) { - break; // Found the start of the character - } + line.erase(byte_pos, char_len); + widths.erase(widths.begin() + char_pos); + + size_t p = byte_pos; + int tail_width = 0; + for (size_t i = char_pos; i < widths.size(); ++i) { + size_t following = next_utf8_char_pos(line, p); + put_codepoint(line.c_str() + p, following - p, widths[i]); + tail_width += widths[i]; + p = following; } - line.erase(pos); + + for (int i = 0; i < w; ++i) { + fputc(' ', out); + } + + move_cursor(-(tail_width + w)); } + static void clear_current_line(const std::vector & widths) { + int total_width = 0; + for (int w : widths) { + total_width += (w > 0 ? w : 1); + } + + if (total_width > 0) { + std::string spaces(total_width, ' '); + fwrite(spaces.c_str(), 1, total_width, out); + move_cursor(-total_width); + } + } + + static void set_line_contents(std::string new_line, std::string & line, std::vector & widths, size_t & char_pos, + size_t & byte_pos) { + move_to_line_start(char_pos, byte_pos, widths); + clear_current_line(widths); + + line = std::move(new_line); + widths.clear(); + byte_pos = 0; + char_pos = 0; + + size_t idx = 0; + while (idx < line.size()) { + size_t advance = 0; + char32_t cp = decode_utf8(line, idx, advance); + int expected_width = estimateWidth(cp); + int real_width = put_codepoint(line.c_str() + idx, advance, expected_width); + if (real_width < 0) real_width = 0; + widths.push_back(real_width); + idx += advance; + ++char_pos; + byte_pos = idx; + } + } + + static void move_to_line_start(size_t & char_pos, size_t & byte_pos, const std::vector & widths) { + int back_width = 0; + for (size_t i = 0; i < char_pos; ++i) { + back_width += widths[i]; + } + move_cursor(-back_width); + char_pos = 0; + byte_pos = 0; + } + + static void move_to_line_end(size_t & char_pos, size_t & byte_pos, const std::vector & widths, const std::string & line) { + int forward_width = 0; + for (size_t i = char_pos; i < widths.size(); ++i) { + forward_width += widths[i]; + } + move_cursor(forward_width); + char_pos = widths.size(); + byte_pos = line.length(); + } + + static bool has_ctrl_modifier(const std::string & params) { + size_t start = 0; + while (start < params.size()) { + size_t end = params.find(';', start); + size_t len = (end == std::string::npos) ? params.size() - start : end - start; + if (len > 0) { + int value = 0; + for (size_t i = 0; i < len; ++i) { + char ch = params[start + i]; + if (!std::isdigit(static_cast(ch))) { + value = -1; + break; + } + value = value * 10 + (ch - '0'); + } + if (value == 5) { + return true; + } + } + + if (end == std::string::npos) { + break; + } + start = end + 1; + } + return false; + } + + static bool is_space_codepoint(char32_t cp) { + return std::iswspace(static_cast(cp)) != 0; + } + + static void move_word_left(size_t & char_pos, size_t & byte_pos, const std::vector & widths, const std::string & line) { + if (char_pos == 0) { + return; + } + + size_t new_char_pos = char_pos; + size_t new_byte_pos = byte_pos; + int move_width = 0; + + while (new_char_pos > 0) { + size_t prev_byte = prev_utf8_char_pos(line, new_byte_pos); + size_t advance = 0; + char32_t cp = decode_utf8(line, prev_byte, advance); + if (!is_space_codepoint(cp)) { + break; + } + move_width += widths[new_char_pos - 1]; + new_char_pos--; + new_byte_pos = prev_byte; + } + + while (new_char_pos > 0) { + size_t prev_byte = prev_utf8_char_pos(line, new_byte_pos); + size_t advance = 0; + char32_t cp = decode_utf8(line, prev_byte, advance); + if (is_space_codepoint(cp)) { + break; + } + move_width += widths[new_char_pos - 1]; + new_char_pos--; + new_byte_pos = prev_byte; + } + + move_cursor(-move_width); + char_pos = new_char_pos; + byte_pos = new_byte_pos; + } + + static void move_word_right(size_t & char_pos, size_t & byte_pos, const std::vector & widths, const std::string & line) { + if (char_pos >= widths.size()) { + return; + } + + size_t new_char_pos = char_pos; + size_t new_byte_pos = byte_pos; + int move_width = 0; + + while (new_char_pos < widths.size()) { + size_t advance = 0; + char32_t cp = decode_utf8(line, new_byte_pos, advance); + if (!is_space_codepoint(cp)) { + break; + } + move_width += widths[new_char_pos]; + new_char_pos++; + new_byte_pos += advance; + } + + while (new_char_pos < widths.size()) { + size_t advance = 0; + char32_t cp = decode_utf8(line, new_byte_pos, advance); + if (is_space_codepoint(cp)) { + break; + } + move_width += widths[new_char_pos]; + new_char_pos++; + new_byte_pos += advance; + } + + while (new_char_pos < widths.size()) { + size_t advance = 0; + char32_t cp = decode_utf8(line, new_byte_pos, advance); + if (!is_space_codepoint(cp)) { + break; + } + move_width += widths[new_char_pos]; + new_char_pos++; + new_byte_pos += advance; + } + + move_cursor(move_width); + char_pos = new_char_pos; + byte_pos = new_byte_pos; + } + + static void move_cursor(int delta) { + if (delta == 0) return; +#if defined(_WIN32) + if (hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(hConsole, &bufferInfo); + COORD newCursorPosition = bufferInfo.dwCursorPosition; + int width = bufferInfo.dwSize.X; + int newX = newCursorPosition.X + delta; + int newY = newCursorPosition.Y; + + while (newX >= width) { + newX -= width; + newY++; + } + while (newX < 0) { + newX += width; + newY--; + } + + newCursorPosition.X = newX; + newCursorPosition.Y = newY; + SetConsoleCursorPosition(hConsole, newCursorPosition); + } +#else + if (delta < 0) { + for (int i = 0; i < -delta; i++) fprintf(out, "\b"); + } else { + for (int i = 0; i < delta; i++) fprintf(out, "\033[C"); + } +#endif + } + + struct history_t { + std::vector entries; + size_t viewing_idx = SIZE_MAX; + std::string backup_line; // current line before viewing history + void add(const std::string & line) { + if (line.empty()) { + return; + } + // avoid duplicates with the last entry + if (entries.empty() || entries.back() != line) { + entries.push_back(line); + } + // also clear viewing state + end_viewing(); + } + bool prev(std::string & cur_line) { + if (entries.empty()) { + return false; + } + if (viewing_idx == SIZE_MAX) { + return false; + } + if (viewing_idx > 0) { + viewing_idx--; + } + cur_line = entries[viewing_idx]; + return true; + } + bool next(std::string & cur_line) { + if (entries.empty() || viewing_idx == SIZE_MAX) { + return false; + } + viewing_idx++; + if (viewing_idx >= entries.size()) { + cur_line = backup_line; + end_viewing(); + } else { + cur_line = entries[viewing_idx]; + } + return true; + } + void begin_viewing(const std::string & line) { + backup_line = line; + viewing_idx = entries.size(); + } + void end_viewing() { + viewing_idx = SIZE_MAX; + backup_line.clear(); + } + bool is_viewing() const { + return viewing_idx != SIZE_MAX; + } + } history; + static bool readline_advanced(std::string & line, bool multiline_input) { if (out != stdout) { fflush(stdout); @@ -362,8 +738,33 @@ namespace console { bool is_special_char = false; bool end_of_stream = false; + size_t byte_pos = 0; // current byte index + size_t char_pos = 0; // current character index (one char can be multiple bytes) + char32_t input_char; while (true) { + assert(char_pos <= byte_pos); + assert(char_pos <= widths.size()); + auto history_prev = [&]() { + if (!history.is_viewing()) { + history.begin_viewing(line); + } + std::string new_line; + if (!history.prev(new_line)) { + return; + } + set_line_contents(new_line, line, widths, char_pos, byte_pos); + }; + auto history_next = [&]() { + if (history.is_viewing()) { + std::string new_line; + if (!history.next(new_line)) { + return; + } + set_line_contents(new_line, line, widths, char_pos, byte_pos); + } + }; + fflush(out); // Ensure all output is displayed before waiting for input input_char = getchar32(); @@ -371,7 +772,7 @@ namespace console { break; } - if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { + if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D */) { end_of_stream = true; break; } @@ -384,7 +785,71 @@ namespace console { if (input_char == '\033') { // Escape sequence char32_t code = getchar32(); - if (code == '[' || code == 0x1B) { + if (code == '[') { + std::string params; + while (true) { + code = getchar32(); + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~' || code == (char32_t) WEOF) { + break; + } + params.push_back(static_cast(code)); + } + + const bool ctrl_modifier = has_ctrl_modifier(params); + + if (code == 'D') { // left + if (ctrl_modifier) { + move_word_left(char_pos, byte_pos, widths, line); + } else if (char_pos > 0) { + int w = widths[char_pos - 1]; + move_cursor(-w); + char_pos--; + byte_pos = prev_utf8_char_pos(line, byte_pos); + } + } else if (code == 'C') { // right + if (ctrl_modifier) { + move_word_right(char_pos, byte_pos, widths, line); + } else if (char_pos < widths.size()) { + int w = widths[char_pos]; + move_cursor(w); + char_pos++; + byte_pos = next_utf8_char_pos(line, byte_pos); + } + } else if (code == 'H') { // home + move_to_line_start(char_pos, byte_pos, widths); + } else if (code == 'F') { // end + move_to_line_end(char_pos, byte_pos, widths, line); + } else if (code == 'A' || code == 'B') { + // up/down + if (code == 'A') { + history_prev(); + is_special_char = false; + } else if (code == 'B') { + history_next(); + is_special_char = false; + } + } else if ((code == '~' || (code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z')) && !params.empty()) { + std::string digits; + for (char ch : params) { + if (ch == ';') { + break; + } + if (std::isdigit(static_cast(ch))) { + digits.push_back(ch); + } + } + + if (code == '~') { + if (digits == "1" || digits == "7") { // home + move_to_line_start(char_pos, byte_pos, widths); + } else if (digits == "4" || digits == "8") { // end + move_to_line_end(char_pos, byte_pos, widths, line); + } else if (digits == "3") { // delete + delete_at_cursor(line, widths, char_pos, byte_pos); + } + } + } + } else if (code == 0x1B) { // Discard the rest of the escape sequence while ((code = getchar32()) != (char32_t) WEOF) { if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { @@ -392,28 +857,107 @@ namespace console { } } } +#if defined(_WIN32) + } else if (input_char == KEY_ARROW_LEFT) { + if (char_pos > 0) { + int w = widths[char_pos - 1]; + move_cursor(-w); + char_pos--; + byte_pos = prev_utf8_char_pos(line, byte_pos); + } + } else if (input_char == KEY_ARROW_RIGHT) { + if (char_pos < widths.size()) { + int w = widths[char_pos]; + move_cursor(w); + char_pos++; + byte_pos = next_utf8_char_pos(line, byte_pos); + } + } else if (input_char == KEY_CTRL_ARROW_LEFT) { + move_word_left(char_pos, byte_pos, widths, line); + } else if (input_char == KEY_CTRL_ARROW_RIGHT) { + move_word_right(char_pos, byte_pos, widths, line); + } else if (input_char == KEY_HOME) { + move_to_line_start(char_pos, byte_pos, widths); + } else if (input_char == KEY_END) { + move_to_line_end(char_pos, byte_pos, widths, line); + } else if (input_char == KEY_DELETE) { + delete_at_cursor(line, widths, char_pos, byte_pos); + } else if (input_char == KEY_ARROW_UP || input_char == KEY_ARROW_DOWN) { + if (input_char == KEY_ARROW_UP) { + history_prev(); + is_special_char = false; + } else if (input_char == KEY_ARROW_DOWN) { + history_next(); + is_special_char = false; + } +#endif } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace - if (!widths.empty()) { - int count; - do { - count = widths.back(); - widths.pop_back(); - // Move cursor back, print space, and move cursor back again - for (int i = 0; i < count; i++) { - replace_last(' '); - pop_cursor(); - } - pop_back_utf8_char(line); - } while (count == 0 && !widths.empty()); + if (char_pos > 0) { + int w = widths[char_pos - 1]; + move_cursor(-w); + char_pos--; + size_t prev_pos = prev_utf8_char_pos(line, byte_pos); + size_t char_len = byte_pos - prev_pos; + byte_pos = prev_pos; + + // remove the character + line.erase(byte_pos, char_len); + widths.erase(widths.begin() + char_pos); + + // redraw tail + size_t p = byte_pos; + int tail_width = 0; + for (size_t i = char_pos; i < widths.size(); ++i) { + size_t next_p = next_utf8_char_pos(line, p); + put_codepoint(line.c_str() + p, next_p - p, widths[i]); + tail_width += widths[i]; + p = next_p; + } + + // clear display + for (int i = 0; i < w; ++i) { + fputc(' ', out); + } + move_cursor(-(tail_width + w)); } } else { - int offset = line.length(); - append_utf8(input_char, line); - int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); - if (width < 0) { - width = 0; + // insert character + std::string new_char_str; + append_utf8(input_char, new_char_str); + int w = estimateWidth(input_char); + + if (char_pos == widths.size()) { + // insert at the end + line += new_char_str; + int real_w = put_codepoint(new_char_str.c_str(), new_char_str.length(), w); + if (real_w < 0) real_w = 0; + widths.push_back(real_w); + byte_pos += new_char_str.length(); + char_pos++; + } else { + // insert in middle + line.insert(byte_pos, new_char_str); + + int real_w = put_codepoint(new_char_str.c_str(), new_char_str.length(), w); + if (real_w < 0) real_w = 0; + + widths.insert(widths.begin() + char_pos, real_w); + + // print the tail + size_t p = byte_pos + new_char_str.length(); + int tail_width = 0; + for (size_t i = char_pos + 1; i < widths.size(); ++i) { + size_t next_p = next_utf8_char_pos(line, p); + put_codepoint(line.c_str() + p, next_p - p, widths[i]); + tail_width += widths[i]; + p = next_p; + } + + move_cursor(-tail_width); + + byte_pos += new_char_str.length(); + char_pos++; } - widths.push_back(width); } if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { @@ -451,6 +995,15 @@ namespace console { } } + if (!end_of_stream && !line.empty()) { + // remove the trailing newline for history storage + if (!line.empty() && line.back() == '\n') { + line.pop_back(); + } + // TODO: maybe support multiline history entries? + history.add(line); + } + fflush(out); return has_more; } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c641989c19..867bc90531 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -383,6 +383,17 @@ class ModelBase: s = self.model_tensors[name] self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs) tensors_to_remove.append(name) + if name.endswith(".activation_scale"): # unused + tensors_to_remove.append(name) + # mistral format + if name.endswith(".qscale_weight"): + weight_name = name.removesuffix("qscale_weight") + "weight" + w = self.model_tensors[weight_name] + s = self.model_tensors[name] + self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs) + tensors_to_remove.append(name) + if name.endswith(".qscale_act"): + tensors_to_remove.append(name) elif quant_method == "gptq": for name in self.model_tensors.keys(): if name.endswith(".qweight"): @@ -2854,13 +2865,10 @@ class Mistral3Model(LlamaModel): self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - # TODO: probably not worth supporting quantized weight, as official BF16 is also available - if name.endswith("weight_scale_inv"): - raise ValueError("This is a quantized weight, please use BF16 weight instead") - name = name.replace("language_model.", "") if "multi_modal_projector" in name or "vision_tower" in name: return [] + return super().modify_tensors(data_torch, name, bid) @@ -5825,9 +5833,11 @@ class Gemma3Model(TextModel): norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value def set_vocab(self): - self._set_vocab_sentencepiece() - - self.gguf_writer.add_add_space_prefix(False) + if (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + self.gguf_writer.add_add_space_prefix(False) + else: + self._set_vocab_gpt2() def set_gguf_parameters(self): hparams = self.hparams @@ -5845,13 +5855,24 @@ class Gemma3Model(TextModel): self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers # attn_logit_softcapping is removed in Gemma3 assert hparams.get("attn_logit_softcapping") is None - self.gguf_writer.add_sliding_window(hparams["sliding_window"]) + if (final_logit_softcap := hparams.get("final_logit_softcapping")): + self.gguf_writer.add_final_logit_softcapping(final_logit_softcap) + if hparams.get("sliding_window_pattern") != 1: + self.gguf_writer.add_sliding_window(hparams["sliding_window"]) self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) if hparams.get("rope_scaling") is not None: - assert hparams["rope_scaling"]["rope_type"] == "linear" - # important: this rope_scaling is only applied for global layers, and not used by 1B model - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) + rope_scaling = hparams["rope_scaling"] + if rope_scaling["rope_type"] == "linear": + # important: this rope_scaling is only applied for global layers, and not used by 1B model + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + elif rope_scaling["rope_type"] == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"]) + self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"]) + self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused @@ -5865,8 +5886,10 @@ class Gemma3Model(TextModel): # remove OOV (out-of-vocabulary) rows in token_embd if "embed_tokens.weight" in name: - vocab = self._create_vocab_sentencepiece() - tokens = vocab[0] + if (self.dir_model / "tokenizer.model").is_file(): + tokens = self._create_vocab_sentencepiece()[0] + else: + tokens = self.get_vocab_base()[0] data_torch = data_torch[:len(tokens)] # ref code in Gemma3RMSNorm @@ -9883,6 +9906,18 @@ class MistralModel(LlamaModel): self.gguf_writer.add_architecture() self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + def dequant_model(self): + # transform quantization config into HF format + quant_config = self.hparams.get("quantization") + if quant_config is not None: + assert quant_config["qformat_weight"] == "fp8_e4m3" + self.hparams["quantization_config"] = { + "activation_scheme": "static", + "quant_method": "fp8", + "weight_block_size": None, + } + return super().dequant_model() + @staticmethod def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool): assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg diff --git a/docs/build-riscv64-spacemit.md b/docs/build-riscv64-spacemit.md index eaa6532546..79bd4de63a 100644 --- a/docs/build-riscv64-spacemit.md +++ b/docs/build-riscv64-spacemit.md @@ -19,6 +19,7 @@ cmake -B build \ -DGGML_RVV=ON \ -DGGML_RV_ZFH=ON \ -DGGML_RV_ZICBOP=ON \ + -DGGML_RV_ZIHINTPAUSE=ON \ -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \ -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \ -DCMAKE_INSTALL_PREFIX=build/installed diff --git a/docs/ops.md b/docs/ops.md index ef1febccae..6ede4a4f64 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -36,14 +36,15 @@ Legend: | CPY | ❌ | 🟑 | 🟑 | 🟑 | 🟑 | 🟑 | 🟑 | 🟑 | 🟑 | ❌ | ❌ | | CROSS_ENTROPY_LOSS | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| CUMSUM | ❌ | ❌ | βœ… | ❌ | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | +| CUMSUM | ❌ | ❌ | βœ… | βœ… | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | +| DIAG | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | DIAG_MASK_INF | ❌ | βœ… | βœ… | βœ… | ❌ | 🟑 | βœ… | βœ… | ❌ | ❌ | ❌ | | DIV | ❌ | βœ… | βœ… | βœ… | 🟑 | 🟑 | βœ… | βœ… | βœ… | ❌ | ❌ | | DUP | ❌ | βœ… | βœ… | 🟑 | 🟑 | 🟑 | βœ… | βœ… | ❌ | ❌ | ❌ | | ELU | ❌ | βœ… | βœ… | 🟑 | 🟑 | ❌ | βœ… | ❌ | βœ… | ❌ | ❌ | | EXP | ❌ | βœ… | βœ… | 🟑 | 🟑 | ❌ | βœ… | 🟑 | βœ… | ❌ | ❌ | | EXPM1 | ❌ | ❌ | βœ… | 🟑 | 🟑 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| FILL | ❌ | ❌ | βœ… | ❌ | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | +| FILL | ❌ | ❌ | βœ… | βœ… | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟑 | βœ… | 🟑 | 🟑 | ❌ | ❌ | 🟑 | ❌ | ❌ | ❌ | | FLOOR | ❌ | ❌ | βœ… | 🟑 | ❌ | ❌ | 🟑 | 🟑 | ❌ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | ❌ | @@ -102,7 +103,7 @@ Legend: | SOFTPLUS | ❌ | ❌ | βœ… | 🟑 | 🟑 | ❌ | ❌ | 🟑 | ❌ | ❌ | ❌ | | SOFT_MAX | ❌ | 🟑 | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟑 | 🟑 | ❌ | ❌ | 🟑 | βœ… | ❌ | ❌ | ❌ | -| SOLVE_TRI | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | ❌ | 🟑 | ❌ | ❌ | ❌ | +| SOLVE_TRI | ❌ | ❌ | βœ… | 🟑 | ❌ | ❌ | ❌ | 🟑 | ❌ | ❌ | ❌ | | SQR | ❌ | βœ… | βœ… | βœ… | 🟑 | ❌ | 🟑 | 🟑 | ❌ | ❌ | ❌ | | SQRT | ❌ | βœ… | βœ… | βœ… | 🟑 | ❌ | 🟑 | 🟑 | ❌ | ❌ | ❌ | | SSM_CONV | ❌ | ❌ | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | @@ -115,8 +116,8 @@ Legend: | SWIGLU_OAI | ❌ | ❌ | βœ… | βœ… | βœ… | ❌ | ❌ | 🟑 | βœ… | ❌ | ❌ | | TANH | ❌ | βœ… | βœ… | 🟑 | 🟑 | βœ… | βœ… | 🟑 | βœ… | ❌ | ❌ | | TIMESTEP_EMBEDDING | ❌ | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | -| TOP_K | ❌ | ❌ | ❌ | ❌ | βœ… | ❌ | ❌ | 🟑 | ❌ | ❌ | ❌ | -| TRI | ❌ | ❌ | βœ… | ❌ | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | +| TOP_K | ❌ | ❌ | βœ… | ❌ | βœ… | ❌ | ❌ | 🟑 | ❌ | ❌ | ❌ | +| TRI | ❌ | ❌ | βœ… | βœ… | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | | TRUNC | ❌ | ❌ | βœ… | 🟑 | ❌ | ❌ | 🟑 | 🟑 | ❌ | ❌ | ❌ | | UPSCALE | ❌ | 🟑 | βœ… | βœ… | 🟑 | βœ… | 🟑 | 🟑 | ❌ | ❌ | ❌ | | XIELU | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | βœ… | ❌ | ❌ | diff --git a/docs/ops/CPU.csv b/docs/ops/CPU.csv index 7c223e279c..fef3bbce70 100644 --- a/docs/ops/CPU.csv +++ b/docs/ops/CPU.csv @@ -4964,6 +4964,7 @@ "CPU","CONV_TRANSPOSE_1D","ne_input=[2,1,1,1],ne_kernel=[3,1,1,1],s0=1,p0=0,d0=1","support","1","yes","CPU" "CPU","CONV_TRANSPOSE_2D","ne_input=[3,2,3,1],ne_kernel=[2,2,1,3],stride=1","support","1","yes","CPU" "CPU","CONV_TRANSPOSE_2D","ne_input=[10,10,9,1],ne_kernel=[3,3,1,9],stride=2","support","1","yes","CPU" +"CPU","CONV_TRANSPOSE_2D","ne_input=[129,63,35,1],ne_kernel=[3,3,48,35],stride=1","support","1","yes","CPU" "CPU","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","1","yes","CPU" "CPU","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","1","yes","CPU" "CPU","ARGMAX","type=f32,ne=[32,1,1,1]","support","1","yes","CPU" @@ -5419,17 +5420,45 @@ "CPU","CPY","type_src=f16,type_dst=f16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","CPU" "CPU","CPY","type_src=f32,type_dst=f32,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","CPU" "CPU","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","CPU" +"CPU","CPY","type_src=i32,type_dst=i32,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","CPU" +"CPU","CPY","type_src=i32,type_dst=i32,ne=[256,1,4,1],permute_src=[1,2,0,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","CPU" "CPU","CPY","type_src=f32,type_dst=f32,ne=[256,1,4,1],permute_src=[1,2,0,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","CPU" -"CPU","CONT","type=f32,ne=[10,10,10,1]","support","1","yes","CPU" -"CPU","CONT","type=f32,ne=[2,1,1,1]","support","1","yes","CPU" -"CPU","CONT","type=f32,ne=[2,1,3,5]","support","1","yes","CPU" -"CPU","CONT","type=f32,ne=[2,3,5,7]","support","1","yes","CPU" -"CPU","CONT","type=f16,ne=[2,1,1,1]","support","1","yes","CPU" -"CPU","CONT","type=f16,ne=[2,1,3,5]","support","1","yes","CPU" -"CPU","CONT","type=f16,ne=[2,3,5,7]","support","1","yes","CPU" -"CPU","CONT","type=bf16,ne=[2,1,1,1]","support","1","yes","CPU" -"CPU","CONT","type=bf16,ne=[2,1,3,5]","support","1","yes","CPU" -"CPU","CONT","type=bf16,ne=[2,3,5,7]","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[2,1,1,1],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[2,1,3,5],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[2,3,5,7],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[1,4,4,1],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[1,8,17,1],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[10,10,10,1],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[2,1,1,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[2,1,3,5],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[2,3,5,7],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[1,4,4,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[1,8,17,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f32,ne=[10,10,10,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[2,1,1,1],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[2,1,3,5],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[2,3,5,7],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[1,4,4,1],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[1,8,17,1],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[10,10,10,1],use_view_slice=1","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[2,1,1,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[2,1,3,5],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[2,3,5,7],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[1,4,4,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[1,8,17,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=i32,ne=[10,10,10,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f16,ne=[2,1,1,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f16,ne=[2,1,3,5],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f16,ne=[2,3,5,7],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f16,ne=[1,4,4,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f16,ne=[1,8,17,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=f16,ne=[10,10,10,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=bf16,ne=[2,1,1,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=bf16,ne=[2,1,3,5],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=bf16,ne=[2,3,5,7],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=bf16,ne=[1,4,4,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=bf16,ne=[1,8,17,1],use_view_slice=0","support","1","yes","CPU" +"CPU","CONT","type=bf16,ne=[10,10,10,1],use_view_slice=0","support","1","yes","CPU" "CPU","ADD","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","CPU" "CPU","SUB","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","CPU" "CPU","MUL","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","CPU" @@ -5655,6 +5684,7 @@ "CPU","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","CPU" "CPU","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","CPU" "CPU","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","CPU" +"CPU","ADD1","type=f32,ne=[1024,1024,1,1]","support","1","yes","CPU" "CPU","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","CPU" "CPU","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","CPU" "CPU","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","CPU" @@ -8644,9 +8674,13 @@ "CPU","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","CPU" "CPU","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","CPU" "CPU","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","CPU" +"CPU","FLOOR","type=f16,ne=[1024,1024,1,1]","support","1","yes","CPU" "CPU","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","CPU" +"CPU","CEIL","type=f16,ne=[1024,1024,1,1]","support","1","yes","CPU" "CPU","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","CPU" +"CPU","ROUND","type=f16,ne=[1024,1024,1,1]","support","1","yes","CPU" "CPU","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","CPU" +"CPU","TRUNC","type=f16,ne=[1024,1024,1,1]","support","1","yes","CPU" "CPU","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","CPU" "CPU","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","CPU" "CPU","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","CPU" @@ -8666,9 +8700,13 @@ "CPU","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","CPU" "CPU","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","CPU" "CPU","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","CPU" +"CPU","FLOOR","type=f32,ne=[1024,1024,1,1]","support","1","yes","CPU" "CPU","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","CPU" +"CPU","CEIL","type=f32,ne=[1024,1024,1,1]","support","1","yes","CPU" "CPU","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","CPU" +"CPU","ROUND","type=f32,ne=[1024,1024,1,1]","support","1","yes","CPU" "CPU","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","CPU" +"CPU","TRUNC","type=f32,ne=[1024,1024,1,1]","support","1","yes","CPU" "CPU","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","CPU" "CPU","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","CPU" "CPU","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","CPU" @@ -9411,18 +9449,405 @@ "CPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","CPU" "CPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","CPU" "CPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","CPU" "CPU","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[15,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[16,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[31,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[32,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[63,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[64,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[127,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[128,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[255,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[256,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[511,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[512,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1023,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2047,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2048,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[4095,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[4096,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[8191,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[8192,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[16383,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[32767,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[32768,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[65535,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[65536,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[131071,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[131072,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[262143,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[262144,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[524287,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[524288,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1048575,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1048576,1,1,1],order=0","support","1","yes","CPU" "CPU","ARGSORT","type=f32,ne=[16,10,10,10],order=0","support","1","yes","CPU" "CPU","ARGSORT","type=f32,ne=[60,10,10,10],order=0","support","1","yes","CPU" -"CPU","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","CPU" -"CPU","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1023,2,1,3],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1024,2,1,3],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1025,2,1,3],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2047,2,1,3],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2048,2,1,3],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2049,2,1,3],order=0","support","1","yes","CPU" "CPU","ARGSORT","type=f32,ne=[2,8,8192,1],order=0","support","1","yes","CPU" -"CPU","ARGSORT","type=f32,ne=[8,1,1,1],order=1","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[15,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[16,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[31,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[32,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[63,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[64,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[127,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[128,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[255,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[256,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[511,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[512,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1023,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2047,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2048,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[4095,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[4096,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[8191,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[8192,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[16383,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[32767,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[32768,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[65535,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[65536,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[131071,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[131072,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[262143,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[262144,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[524287,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[524288,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1048575,1,1,1],order=0","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1048576,1,1,1],order=0","support","1","yes","CPU" "CPU","ARGSORT","type=f32,ne=[16,10,10,10],order=1","support","1","yes","CPU" "CPU","ARGSORT","type=f32,ne=[60,10,10,10],order=1","support","1","yes","CPU" -"CPU","ARGSORT","type=f32,ne=[1024,1,1,1],order=1","support","1","yes","CPU" -"CPU","ARGSORT","type=f32,ne=[16384,1,1,1],order=1","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1023,2,1,3],order=1","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1024,2,1,3],order=1","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[1025,2,1,3],order=1","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2047,2,1,3],order=1","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","1","yes","CPU" +"CPU","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","1","yes","CPU" "CPU","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[12,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[13,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[13,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[15,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[15,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[15,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[19,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[19,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[19,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[19,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[27,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[27,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[27,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[27,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[27,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[43,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[43,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[43,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[43,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[43,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[64,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[75,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[64,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[75,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[64,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[75,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[64,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[75,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[64,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[75,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[128,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[139,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[128,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[139,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[128,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[139,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[128,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[139,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[128,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[139,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[128,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[139,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[256,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[267,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[256,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[267,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[256,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[267,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[256,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[267,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[256,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[267,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[256,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[267,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[512,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[523,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[512,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[523,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[512,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[523,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[512,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[523,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[512,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[523,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[512,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[523,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[512,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[523,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1035,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1035,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1035,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1035,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1035,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1035,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1035,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1035,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2059,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2059,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2059,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2059,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2059,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2059,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2059,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2059,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4096,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4107,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4096,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4107,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4096,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4107,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4096,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4107,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4096,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4107,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4096,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4107,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4096,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4107,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4096,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[4107,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8192,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8203,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8192,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8203,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8192,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8203,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8192,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8203,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8192,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8203,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8192,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8203,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8192,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8203,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8192,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[8203,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16395,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16395,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16395,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16395,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16395,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16395,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16395,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16395,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16395,1,2,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32768,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32779,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32768,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32779,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32768,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32779,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32768,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32779,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32768,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32779,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32768,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32779,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32768,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32779,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32768,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32779,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32768,1,1,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[32779,1,2,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65536,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65547,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65536,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65547,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65536,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65547,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65536,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65547,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65536,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65547,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65536,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65547,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65536,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65547,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65536,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65547,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65536,1,1,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[65547,1,2,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131072,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131083,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131072,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131083,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131072,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131083,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131072,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131083,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131072,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131083,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131072,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131083,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131072,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131083,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131072,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131083,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131072,1,1,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[131083,1,2,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262144,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262155,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262144,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262155,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262144,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262155,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262144,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262155,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262144,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262155,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262144,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262155,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262144,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262155,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262144,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262155,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262144,1,1,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[262155,1,2,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524288,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524299,1,2,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524288,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524299,1,2,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524288,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524299,1,2,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524288,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524299,1,2,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524288,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524299,1,2,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524288,1,1,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524299,1,2,1],k=100,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524288,1,1,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524299,1,2,1],k=500,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524288,1,1,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524299,1,2,1],k=1023,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524288,1,1,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[524299,1,2,1],k=9999,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,10,10,10],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[60,10,10,10],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1023,2,1,3],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,2,1,3],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1025,2,1,3],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2047,2,1,3],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,2,1,3],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2049,2,1,3],k=1,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,10,10,10],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[60,10,10,10],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1023,2,1,3],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,2,1,3],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1025,2,1,3],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2047,2,1,3],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,2,1,3],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2049,2,1,3],k=2,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,10,10,10],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[60,10,10,10],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1023,2,1,3],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,2,1,3],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1025,2,1,3],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2047,2,1,3],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,2,1,3],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2049,2,1,3],k=3,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,10,10,10],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[60,10,10,10],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1023,2,1,3],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,2,1,3],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1025,2,1,3],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2047,2,1,3],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,2,1,3],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2049,2,1,3],k=7,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16,10,10,10],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[60,10,10,10],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1023,2,1,3],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1024,2,1,3],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[1025,2,1,3],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2047,2,1,3],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2048,2,1,3],k=15,ties=0","support","1","yes","CPU" +"CPU","TOP_K","type=f32,ne=[2049,2,1,3],k=15,ties=0","support","1","yes","CPU" "CPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","CPU" "CPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","CPU" "CPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","CPU" @@ -9435,6 +9860,10 @@ "CPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=1","support","1","yes","CPU" "CPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic,flags=none","support","1","yes","CPU" "CPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic,flags=none","support","1","yes","CPU" +"CPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=0","support","1","yes","CPU" +"CPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=1","support","1","yes","CPU" +"CPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=none","support","1","yes","CPU" +"CPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear,flags=none","support","1","yes","CPU" "CPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=align_corners","support","1","yes","CPU" "CPU","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear,flags=align_corners","support","1","yes","CPU" "CPU","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear,flags=align_corners","support","1","yes","CPU" @@ -9463,15 +9892,30 @@ "CPU","GROUP_NORM","type=f32,ne=[64,64,320,1],num_groups=32,eps=0.000001","support","1","yes","CPU" "CPU","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","CPU" "CPU","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","CPU" -"CPU","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","CPU" -"CPU","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0,circular=0","support","1","yes","CPU" "CPU","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","1","yes","CPU" "CPU","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","1","yes","CPU" "CPU","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","CPU" "CPU","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","CPU" +"CPU","ARANGE","type=f32,start=0.000000,stop=1048576.000000,step=1.000000","support","1","yes","CPU" "CPU","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","CPU" "CPU","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","CPU" "CPU","CUMSUM","type=f32,ne=[10,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[127,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[128,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[128,128,4,4]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[255,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[256,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[511,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[512,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[1023,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[1024,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[2047,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[2048,5,4,3]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[242004,1,1,1]","support","1","yes","CPU" +"CPU","CUMSUM","type=f32,ne=[375960,1,1,1]","support","1","yes","CPU" "CPU","XIELU","type=f32,ne=[10,5,4,3]","support","1","yes","CPU" "CPU","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","CPU" "CPU","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","CPU" @@ -9480,6 +9924,10 @@ "CPU","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","CPU" "CPU","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","CPU" "CPU","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","CPU" +"CPU","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","CPU" +"CPU","DIAG","type=f32,ne=[10,1,4,3]","support","1","yes","CPU" +"CPU","DIAG","type=f32,ne=[79,1,19,13]","support","1","yes","CPU" +"CPU","DIAG","type=f32,ne=[256,1,8,16]","support","1","yes","CPU" "CPU","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","1","yes","CPU" "CPU","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","1","yes","CPU" "CPU","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","1","yes","CPU" @@ -9487,10 +9935,16 @@ "CPU","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","1","yes","CPU" "CPU","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","1","yes","CPU" "CPU","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","1","yes","CPU" -"CPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0","support","1","yes","CPU" -"CPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0","support","1","yes","CPU" -"CPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1","support","1","yes","CPU" -"CPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1","support","1","yes","CPU" +"CPU","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[31,128,4,4]","support","1","yes","CPU" +"CPU","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[300,64,4,4]","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=0","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=0","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=1","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=1","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=0","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=0","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=1","support","1","yes","CPU" +"CPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=1","support","1","yes","CPU" "CPU","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","1","yes","CPU" "CPU","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","CPU" "CPU","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","1","yes","CPU" diff --git a/docs/ops/CUDA.csv b/docs/ops/CUDA.csv index 56ba32273a..22c84dd143 100644 --- a/docs/ops/CUDA.csv +++ b/docs/ops/CUDA.csv @@ -4964,6 +4964,7 @@ "CUDA0","CONV_TRANSPOSE_1D","ne_input=[2,1,1,1],ne_kernel=[3,1,1,1],s0=1,p0=0,d0=1","support","1","yes","CUDA" "CUDA0","CONV_TRANSPOSE_2D","ne_input=[3,2,3,1],ne_kernel=[2,2,1,3],stride=1","support","1","yes","CUDA" "CUDA0","CONV_TRANSPOSE_2D","ne_input=[10,10,9,1],ne_kernel=[3,3,1,9],stride=2","support","1","yes","CUDA" +"CUDA0","CONV_TRANSPOSE_2D","ne_input=[129,63,35,1],ne_kernel=[3,3,48,35],stride=1","support","1","yes","CUDA" "CUDA0","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","1","yes","CUDA" "CUDA0","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","1","yes","CUDA" "CUDA0","ARGMAX","type=f32,ne=[32,1,1,1]","support","1","yes","CUDA" @@ -5419,17 +5420,45 @@ "CUDA0","CPY","type_src=f16,type_dst=f16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","CUDA" "CUDA0","CPY","type_src=f32,type_dst=f32,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","CUDA" "CUDA0","CPY","type_src=bf16,type_dst=bf16,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","CUDA" +"CUDA0","CPY","type_src=i32,type_dst=i32,ne=[256,4,1,1],permute_src=[0,0,0,0],permute_dst=[0,0,0,0],_src_transpose=1","support","1","yes","CUDA" +"CUDA0","CPY","type_src=i32,type_dst=i32,ne=[256,1,4,1],permute_src=[1,2,0,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","CUDA" "CUDA0","CPY","type_src=f32,type_dst=f32,ne=[256,1,4,1],permute_src=[1,2,0,3],permute_dst=[0,0,0,0],_src_transpose=0","support","1","yes","CUDA" -"CUDA0","CONT","type=f32,ne=[10,10,10,1]","support","1","yes","CUDA" -"CUDA0","CONT","type=f32,ne=[2,1,1,1]","support","1","yes","CUDA" -"CUDA0","CONT","type=f32,ne=[2,1,3,5]","support","1","yes","CUDA" -"CUDA0","CONT","type=f32,ne=[2,3,5,7]","support","1","yes","CUDA" -"CUDA0","CONT","type=f16,ne=[2,1,1,1]","support","1","yes","CUDA" -"CUDA0","CONT","type=f16,ne=[2,1,3,5]","support","1","yes","CUDA" -"CUDA0","CONT","type=f16,ne=[2,3,5,7]","support","1","yes","CUDA" -"CUDA0","CONT","type=bf16,ne=[2,1,1,1]","support","1","yes","CUDA" -"CUDA0","CONT","type=bf16,ne=[2,1,3,5]","support","1","yes","CUDA" -"CUDA0","CONT","type=bf16,ne=[2,3,5,7]","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[2,1,1,1],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[2,1,3,5],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[2,3,5,7],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[1,4,4,1],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[1,8,17,1],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[10,10,10,1],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[2,1,1,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[2,1,3,5],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[2,3,5,7],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[1,4,4,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[1,8,17,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f32,ne=[10,10,10,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[2,1,1,1],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[2,1,3,5],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[2,3,5,7],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[1,4,4,1],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[1,8,17,1],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[10,10,10,1],use_view_slice=1","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[2,1,1,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[2,1,3,5],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[2,3,5,7],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[1,4,4,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[1,8,17,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=i32,ne=[10,10,10,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f16,ne=[2,1,1,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f16,ne=[2,1,3,5],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f16,ne=[2,3,5,7],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f16,ne=[1,4,4,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f16,ne=[1,8,17,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=f16,ne=[10,10,10,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=bf16,ne=[2,1,1,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=bf16,ne=[2,1,3,5],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=bf16,ne=[2,3,5,7],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=bf16,ne=[1,4,4,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=bf16,ne=[1,8,17,1],use_view_slice=0","support","1","yes","CUDA" +"CUDA0","CONT","type=bf16,ne=[10,10,10,1],use_view_slice=0","support","1","yes","CUDA" "CUDA0","ADD","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","CUDA" "CUDA0","SUB","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","CUDA" "CUDA0","MUL","type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1","support","1","yes","CUDA" @@ -5655,6 +5684,7 @@ "CUDA0","MUL","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","CUDA" "CUDA0","DIV","type=f32,ne=[64,262144,1,1],nr=[1,1,1,1],nf=1","support","1","yes","CUDA" "CUDA0","ADD1","type=f32,ne=[10,5,4,3]","support","1","yes","CUDA" +"CUDA0","ADD1","type=f32,ne=[1024,1024,1,1]","support","1","yes","CUDA" "CUDA0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=0.000000,inplace=0","support","1","yes","CUDA" "CUDA0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=0","support","1","yes","CUDA" "CUDA0","SCALE","type=f32,ne=[10,10,10,10],scale=2.000000,bias=1.000000,inplace=1","support","1","yes","CUDA" @@ -8644,9 +8674,13 @@ "CUDA0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","CUDA" "CUDA0","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","CUDA" "CUDA0","FLOOR","type=f16,ne=[7,1,5,3]","support","1","yes","CUDA" +"CUDA0","FLOOR","type=f16,ne=[1024,1024,1,1]","support","1","yes","CUDA" "CUDA0","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","CUDA" +"CUDA0","CEIL","type=f16,ne=[1024,1024,1,1]","support","1","yes","CUDA" "CUDA0","ROUND","type=f16,ne=[7,1,5,3]","support","1","yes","CUDA" +"CUDA0","ROUND","type=f16,ne=[1024,1024,1,1]","support","1","yes","CUDA" "CUDA0","TRUNC","type=f16,ne=[7,1,5,3]","support","1","yes","CUDA" +"CUDA0","TRUNC","type=f16,ne=[1024,1024,1,1]","support","1","yes","CUDA" "CUDA0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","CUDA" "CUDA0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","CUDA" "CUDA0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","CUDA" @@ -8666,9 +8700,13 @@ "CUDA0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","CUDA" "CUDA0","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","1","yes","CUDA" "CUDA0","FLOOR","type=f32,ne=[7,1,5,3]","support","1","yes","CUDA" +"CUDA0","FLOOR","type=f32,ne=[1024,1024,1,1]","support","1","yes","CUDA" "CUDA0","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","CUDA" +"CUDA0","CEIL","type=f32,ne=[1024,1024,1,1]","support","1","yes","CUDA" "CUDA0","ROUND","type=f32,ne=[7,1,5,3]","support","1","yes","CUDA" +"CUDA0","ROUND","type=f32,ne=[1024,1024,1,1]","support","1","yes","CUDA" "CUDA0","TRUNC","type=f32,ne=[7,1,5,3]","support","1","yes","CUDA" +"CUDA0","TRUNC","type=f32,ne=[1024,1024,1,1]","support","1","yes","CUDA" "CUDA0","DIAG_MASK_INF","type=f32,ne=[10,10,1,1],n_past=5","support","1","yes","CUDA" "CUDA0","DIAG_MASK_INF","type=f32,ne=[10,10,3,1],n_past=5","support","1","yes","CUDA" "CUDA0","DIAG_MASK_INF","type=f32,ne=[10,10,3,2],n_past=5","support","1","yes","CUDA" @@ -9411,18 +9449,405 @@ "CUDA0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","CUDA" "CUDA0","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","CUDA" "CUDA0","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","CUDA" "CUDA0","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[15,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[16,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[31,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[32,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[63,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[64,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[127,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[128,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[255,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[256,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[511,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[512,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1023,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2047,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2048,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[4095,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[4096,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[8191,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[8192,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[16383,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[32767,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[32768,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[65535,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[65536,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[131071,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[131072,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[262143,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[262144,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[524287,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[524288,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1048575,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1048576,1,1,1],order=0","support","1","yes","CUDA" "CUDA0","ARGSORT","type=f32,ne=[16,10,10,10],order=0","support","1","yes","CUDA" "CUDA0","ARGSORT","type=f32,ne=[60,10,10,10],order=0","support","1","yes","CUDA" -"CUDA0","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","CUDA" -"CUDA0","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1023,2,1,3],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1024,2,1,3],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1025,2,1,3],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2047,2,1,3],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2048,2,1,3],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2049,2,1,3],order=0","support","1","yes","CUDA" "CUDA0","ARGSORT","type=f32,ne=[2,8,8192,1],order=0","support","1","yes","CUDA" -"CUDA0","ARGSORT","type=f32,ne=[8,1,1,1],order=1","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[8,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[15,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[16,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[31,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[32,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[63,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[64,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[127,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[128,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[255,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[256,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[511,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[512,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1023,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1024,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2047,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2048,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[4095,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[4096,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[8191,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[8192,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[16383,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[16384,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[32767,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[32768,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[65535,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[65536,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[131071,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[131072,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[262143,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[262144,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[524287,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[524288,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1048575,1,1,1],order=0","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1048576,1,1,1],order=0","support","1","yes","CUDA" "CUDA0","ARGSORT","type=f32,ne=[16,10,10,10],order=1","support","1","yes","CUDA" "CUDA0","ARGSORT","type=f32,ne=[60,10,10,10],order=1","support","1","yes","CUDA" -"CUDA0","ARGSORT","type=f32,ne=[1024,1,1,1],order=1","support","1","yes","CUDA" -"CUDA0","ARGSORT","type=f32,ne=[16384,1,1,1],order=1","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1023,2,1,3],order=1","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1024,2,1,3],order=1","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[1025,2,1,3],order=1","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2047,2,1,3],order=1","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","1","yes","CUDA" +"CUDA0","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","1","yes","CUDA" "CUDA0","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","1","yes","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[12,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[13,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[13,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[15,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[15,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[15,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[19,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[19,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[19,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[19,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[27,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[27,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[27,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[27,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[27,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[43,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[43,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[43,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[43,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[43,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[64,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[75,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[64,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[75,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[64,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[75,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[64,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[75,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[64,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[75,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[128,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[139,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[128,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[139,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[128,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[139,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[128,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[139,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[128,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[139,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[128,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[139,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[256,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[267,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[256,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[267,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[256,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[267,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[256,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[267,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[256,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[267,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[256,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[267,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[512,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[523,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[512,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[523,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[512,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[523,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[512,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[523,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[512,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[523,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[512,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[523,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[512,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[523,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1035,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1035,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1035,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1035,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1035,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1035,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1035,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1035,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2059,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2059,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2059,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2059,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2059,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2059,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2059,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2059,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4096,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4107,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4096,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4107,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4096,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4107,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4096,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4107,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4096,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4107,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4096,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4107,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4096,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4107,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4096,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[4107,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8192,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8203,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8192,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8203,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8192,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8203,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8192,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8203,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8192,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8203,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8192,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8203,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8192,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8203,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8192,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[8203,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16395,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16395,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16395,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16395,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16395,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16395,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16395,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16395,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16395,1,2,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32768,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32779,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32768,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32779,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32768,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32779,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32768,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32779,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32768,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32779,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32768,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32779,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32768,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32779,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32768,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32779,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32768,1,1,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[32779,1,2,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65536,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65547,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65536,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65547,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65536,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65547,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65536,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65547,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65536,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65547,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65536,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65547,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65536,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65547,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65536,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65547,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65536,1,1,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[65547,1,2,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131072,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131083,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131072,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131083,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131072,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131083,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131072,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131083,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131072,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131083,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131072,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131083,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131072,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131083,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131072,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131083,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131072,1,1,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[131083,1,2,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262144,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262155,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262144,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262155,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262144,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262155,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262144,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262155,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262144,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262155,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262144,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262155,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262144,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262155,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262144,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262155,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262144,1,1,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[262155,1,2,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524288,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524299,1,2,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524288,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524299,1,2,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524288,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524299,1,2,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524288,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524299,1,2,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524288,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524299,1,2,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524288,1,1,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524299,1,2,1],k=100,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524288,1,1,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524299,1,2,1],k=500,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524288,1,1,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524299,1,2,1],k=1023,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524288,1,1,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[524299,1,2,1],k=9999,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,10,10,10],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[60,10,10,10],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1023,2,1,3],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,2,1,3],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1025,2,1,3],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2047,2,1,3],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,2,1,3],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2049,2,1,3],k=1,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,10,10,10],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[60,10,10,10],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1023,2,1,3],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,2,1,3],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1025,2,1,3],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2047,2,1,3],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,2,1,3],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2049,2,1,3],k=2,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,10,10,10],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[60,10,10,10],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1023,2,1,3],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,2,1,3],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1025,2,1,3],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2047,2,1,3],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,2,1,3],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2049,2,1,3],k=3,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,10,10,10],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[60,10,10,10],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1023,2,1,3],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,2,1,3],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1025,2,1,3],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2047,2,1,3],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,2,1,3],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2049,2,1,3],k=7,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16,10,10,10],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[60,10,10,10],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1023,2,1,3],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1024,2,1,3],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[1025,2,1,3],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2047,2,1,3],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2048,2,1,3],k=15,ties=0","support","0","no","CUDA" +"CUDA0","TOP_K","type=f32,ne=[2049,2,1,3],k=15,ties=0","support","0","no","CUDA" "CUDA0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","CUDA" "CUDA0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","CUDA" "CUDA0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","CUDA" @@ -9435,6 +9860,10 @@ "CUDA0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=1","support","1","yes","CUDA" "CUDA0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic,flags=none","support","1","yes","CUDA" "CUDA0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic,flags=none","support","1","yes","CUDA" +"CUDA0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=0","support","1","yes","CUDA" +"CUDA0","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=1","support","1","yes","CUDA" +"CUDA0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=none","support","1","yes","CUDA" +"CUDA0","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear,flags=none","support","1","yes","CUDA" "CUDA0","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=align_corners","support","1","yes","CUDA" "CUDA0","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear,flags=align_corners","support","1","yes","CUDA" "CUDA0","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear,flags=align_corners","support","1","yes","CUDA" @@ -9463,34 +9892,59 @@ "CUDA0","GROUP_NORM","type=f32,ne=[64,64,320,1],num_groups=32,eps=0.000001","support","1","yes","CUDA" "CUDA0","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","CUDA" "CUDA0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","CUDA" -"CUDA0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","CUDA" -"CUDA0","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","1","yes","CUDA" +"CUDA0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","CUDA" +"CUDA0","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","1","yes","CUDA" +"CUDA0","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0,circular=0","support","1","yes","CUDA" "CUDA0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","1","yes","CUDA" "CUDA0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","1","yes","CUDA" "CUDA0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","1","yes","CUDA" "CUDA0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","1","yes","CUDA" +"CUDA0","ARANGE","type=f32,start=0.000000,stop=1048576.000000,step=1.000000","support","1","yes","CUDA" "CUDA0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","CUDA" "CUDA0","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","1","yes","CUDA" -"CUDA0","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[10,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[127,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[128,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[128,128,4,4]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[255,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[256,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[511,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[512,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[1023,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[1024,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[2047,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[2048,5,4,3]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[242004,1,1,1]","support","1","yes","CUDA" +"CUDA0","CUMSUM","type=f32,ne=[375960,1,1,1]","support","1","yes","CUDA" "CUDA0","XIELU","type=f32,ne=[10,5,4,3]","support","0","no","CUDA" -"CUDA0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","CUDA" -"CUDA0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","CUDA" -"CUDA0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","CUDA" -"CUDA0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","CUDA" -"CUDA0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","CUDA" -"CUDA0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","CUDA" -"CUDA0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","CUDA" -"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","CUDA" -"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","CUDA" -"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","CUDA" -"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","CUDA" -"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","CUDA" -"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","CUDA" +"CUDA0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","CUDA" +"CUDA0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","CUDA" +"CUDA0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","CUDA" +"CUDA0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","CUDA" +"CUDA0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","CUDA" +"CUDA0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","CUDA" +"CUDA0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","CUDA" +"CUDA0","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","CUDA" +"CUDA0","DIAG","type=f32,ne=[10,1,4,3]","support","1","yes","CUDA" +"CUDA0","DIAG","type=f32,ne=[79,1,19,13]","support","1","yes","CUDA" +"CUDA0","DIAG","type=f32,ne=[256,1,8,16]","support","1","yes","CUDA" +"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","1","yes","CUDA" +"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","1","yes","CUDA" +"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","1","yes","CUDA" +"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","1","yes","CUDA" +"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","1","yes","CUDA" +"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","1","yes","CUDA" "CUDA0","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","CUDA" -"CUDA0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0","support","1","yes","CUDA" -"CUDA0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0","support","1","yes","CUDA" -"CUDA0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1","support","0","no","CUDA" -"CUDA0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1","support","0","no","CUDA" +"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[31,128,4,4]","support","0","no","CUDA" +"CUDA0","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[300,64,4,4]","support","0","no","CUDA" +"CUDA0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=0","support","1","yes","CUDA" +"CUDA0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=0","support","1","yes","CUDA" +"CUDA0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=1","support","1","yes","CUDA" +"CUDA0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=1","support","1","yes","CUDA" +"CUDA0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=0","support","0","no","CUDA" +"CUDA0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=0","support","0","no","CUDA" +"CUDA0","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=1","support","0","no","CUDA" +"CUDA0","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=1","support","0","no","CUDA" "CUDA0","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","1","yes","CUDA" "CUDA0","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","CUDA" "CUDA0","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","CUDA" diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp index bbd095e603..5bcf063267 100644 --- a/examples/model-conversion/logits.cpp +++ b/examples/model-conversion/logits.cpp @@ -144,7 +144,7 @@ int main(int argc, char ** argv) { return 1; } std::string s(buf, n); - printf("%s", s.c_str()); + printf("%s (%d)", s.c_str(), id); } printf("\n"); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6b69ad8281..ab5b4760e2 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -168,6 +168,7 @@ option(GGML_RVV "ggml: enable rvv" ON) option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) +option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON) option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 218222ece8..a5995fdc2c 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -25,6 +25,7 @@ static bool ggml_is_view(const struct ggml_tensor * t) { // ops that return true for this function must not use restrict pointers for their backend implementations bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { + case GGML_OP_FILL: case GGML_OP_SCALE: case GGML_OP_DIAG_MASK_ZERO: case GGML_OP_DIAG_MASK_INF: diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 48f4b7db69..835b53f659 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2251,12 +2251,12 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, int sections[4], bool mrope_used, bool is_imrope, - bool indep_sects) { - ggml_tensor * src0 = dst->src[0]; // input + bool indep_sects, + int64_t rope_dims) { ggml_tensor * src1 = dst->src[1]; // position ggml_tensor * src2 = dst->src[2]; // freq_factors - int64_t theta_scale_length = src0->ne[0] / 2; + int64_t theta_scale_length = rope_dims / 2; int64_t position_length = dst->ne[2]; // TODO: check theta_scale_length and position_length. @@ -2331,18 +2331,17 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream())); - - acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), - theta_scale_ne, theta_scale_nb, 1); } + acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, 1); // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor. + // TODO: acl_yarn_ramp_tensor use rope cache. bool yarn_ramp_tensor_updated = false; ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool()); acl_tensor_ptr acl_yarn_ramp_tensor; - if (ext_factor != 0 && - // TODO: check more parameter. - (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) { + if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length || + ctx.rope_cache.freq_scale != freq_scale)) { yarn_ramp_tensor_updated = true; // -rope_yarn_ramp @@ -2590,7 +2589,7 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true); } - int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 }; + int64_t sin_reshape_ne[4] = { rope_dims, 1, dst->ne[2], 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2645,7 +2644,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - int sections[4]; + int sections[4]; // const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -2654,44 +2653,60 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_TENSOR_UNARY_OP_LOCALS - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4); + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 4); - // TODO: n_dims <= ne0 - GGML_ASSERT(n_dims == ne0); GGML_ASSERT(n_dims % 2 == 0); + GGML_ASSERT(n_dims <= ne00); const float theta_scale = powf(freq_base, -2.0f / n_dims); float corr_dims[2]; ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope - const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope - const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope + // mrope_used means the GGML_ROPE_TYPE_MROPE bit is set. + // Note: this bit is also set for imrope and some vision modes, + // so mrope_used does NOT exclusively indicate pure mrope. + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (mrope_used) { GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); } if (is_vision) { - GGML_ASSERT(n_dims == ne0/2); + GGML_ASSERT(n_dims == ne0 / 2); } if (is_imrope || mrope_used) { is_neox = true; } - // init ctx.rope_cos/rope_sin cache - aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision); + int64_t rope_dims = n_dims; - int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 }; + //Our current RotaryPositionEmbedding does not support the VISION mode, + //but essentially it only modifies theta_base in mrope, + //then repeats it at the end in the same way as is_neox. + //In fact, RoPE is still applied across all dimensions. + if (is_vision) { + rope_dims = src0->ne[0]; + } + int64_t tail_dims = ne00 - rope_dims; + bool has_tail = tail_dims > 0; + + // init ctx.rope_cos/rope_sin cache + aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, + mrope_used, is_imrope, is_vision, rope_dims); + + // Cache is generated with ne00 dimensions, so we use ne00 for reshape + int64_t sin_reshape_ne[4] = { rope_dims, 1, ne02, 1 }; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2704,7 +2719,6 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0); acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); - #ifdef ASCEND_310P // Special ROPE operation for 310P @@ -2844,46 +2858,124 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } return; #endif - int64_t acl_mode = is_neox ? 0 : 1; - switch (src0->type) { - case GGML_TYPE_F32: - { - GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(), - acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get()); - break; - } - case GGML_TYPE_F16: - { - ggml_cann_pool_alloc src_trans_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(float)); - void * src_trans_buffer = src_trans_allocator.get(); - ggml_cann_pool_alloc dst_trans_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(float)); - void * dst_trans_buffer = dst_trans_allocator.get(); + // Pre-define head and tail dimensions for reuse + int64_t head_ne[GGML_MAX_DIMS] = { rope_dims, ne01, ne02, ne03 }; + int64_t tail_ne[GGML_MAX_DIMS] = { tail_dims, ne01, ne02, ne03 }; - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(float); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; - } + // Step 1: Prepare trans tensors for F16 type conversion to F32 if needed + bool src_dst_need_trans = false; + ggml_cann_pool_alloc src_trans_allocator(ctx.pool()); + ggml_cann_pool_alloc dst_trans_allocator(ctx.pool()); + acl_tensor_ptr acl_src_trans_tensor; + acl_tensor_ptr acl_dst_trans_tensor; + void * src_trans_buffer = nullptr; + void * dst_trans_buffer = nullptr; + size_t src_dst_trans_nb[GGML_MAX_DIMS]; + if (src0->type == GGML_TYPE_F16) { + src_dst_need_trans = true; + src_trans_buffer = src_trans_allocator.alloc(ggml_nelements(src0) * sizeof(float)); + dst_trans_buffer = dst_trans_allocator.alloc(ggml_nelements(dst) * sizeof(float)); - acl_tensor_ptr acl_src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, GGML_MAX_DIMS); - acl_tensor_ptr acl_dst_trans_tensor = ggml_cann_create_tensor( - dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, GGML_MAX_DIMS); + src_dst_trans_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_dst_trans_nb[i] = src_dst_trans_nb[i - 1] * src0->ne[i - 1]; + } + acl_src_trans_tensor = ggml_cann_create_tensor(src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, + src_dst_trans_nb, GGML_MAX_DIMS); + acl_dst_trans_tensor = ggml_cann_create_tensor(dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, + src_dst_trans_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT); + } - aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT); + // Step 2: Prepare head tensors for tail splitting if needed + acl_tensor_ptr acl_src_head; + acl_tensor_ptr acl_dst_head; + if (has_tail) { + // Create head views for RotaryPositionEmbedding (only first rope_dims dimensions) + // RotaryPositionEmbedding requires contiguous dst tensor, so we use a temporary buffer + if (src_dst_need_trans) { + // Use F32 trans tensor strides + acl_src_head = ggml_cann_create_tensor((char *) src_trans_buffer, ACL_FLOAT, sizeof(float), head_ne, + src_dst_trans_nb, GGML_MAX_DIMS); + } else { + // Use original F32 tensor strides + acl_src_head = ggml_cann_create_tensor((char *) src0->data, ACL_FLOAT, sizeof(float), head_ne, src0->nb, + GGML_MAX_DIMS); + } - GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), - acl_cos_reshape_tensor.get(), acl_sin_reshape_tensor.get(), acl_mode, - acl_dst_trans_tensor.get()); + int64_t head_elements = rope_dims * ne01 * ne02 * ne03; + ggml_cann_pool_alloc dst_head_contiguous_allocator(ctx.pool(), head_elements * sizeof(float)); + void * dst_head_contiguous_buffer = dst_head_contiguous_allocator.get(); - aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16); - break; - } - default: - GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE"); - break; + size_t head_contiguous_nb[GGML_MAX_DIMS]; + head_contiguous_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + head_contiguous_nb[i] = head_contiguous_nb[i - 1] * head_ne[i - 1]; + } + acl_dst_head = ggml_cann_create_tensor(dst_head_contiguous_buffer, ACL_FLOAT, sizeof(float), head_ne, + head_contiguous_nb, GGML_MAX_DIMS); + } + + // Step 3: Execute RotaryPositionEmbedding + if (has_tail) { + // Rotate only the head portion (first rope_dims dimensions) + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_head.get(), acl_cos_reshape_tensor.get(), + acl_sin_reshape_tensor.get(), acl_mode, acl_dst_head.get()); + + // Copy head result from contiguous buffer back to destination tensor + if (src_dst_need_trans) { + acl_tensor_ptr acl_dst_head_target = ggml_cann_create_tensor( + (char *) dst_trans_buffer, ACL_FLOAT, sizeof(float), head_ne, src_dst_trans_nb, GGML_MAX_DIMS); + cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get()); + } else { + acl_tensor_ptr acl_dst_head_target = + ggml_cann_create_tensor((char *) dst->data, ACL_FLOAT, sizeof(float), head_ne, dst->nb, GGML_MAX_DIMS); + cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get()); + } + } else if (src_dst_need_trans) { + // Rotate full tensor (no tail), using trans tensors + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(), + acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get()); + } else { + // Rotate full tensor (no tail), using original tensors + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(), + acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get()); + } + + // Step 4: Copy unrotated tail portion from source to destination + if (has_tail) { + size_t src_tail_offset; + size_t dst_tail_offset; + + auto copy_tail_device = [&](void * src_ptr, void * dst_ptr, aclDataType dtype, size_t elem_size, + size_t * nb_src_arr, size_t * nb_dst_arr) { + acl_tensor_ptr acl_src_tail = + ggml_cann_create_tensor(src_ptr, dtype, elem_size, tail_ne, nb_src_arr, GGML_MAX_DIMS); + acl_tensor_ptr acl_dst_tail = + ggml_cann_create_tensor(dst_ptr, dtype, elem_size, tail_ne, nb_dst_arr, GGML_MAX_DIMS); + cann_copy(ctx, acl_src_tail.get(), acl_dst_tail.get()); + }; + + if (src_dst_need_trans) { + // Use F32 trans tensor strides and offsets + src_tail_offset = rope_dims * src_dst_trans_nb[0]; + dst_tail_offset = rope_dims * src_dst_trans_nb[0]; + copy_tail_device((char *) src_trans_buffer + src_tail_offset, (char *) dst_trans_buffer + dst_tail_offset, + ACL_FLOAT, sizeof(float), src_dst_trans_nb, src_dst_trans_nb); + } else { + // Use original tensor strides and offsets + src_tail_offset = rope_dims * nb00; + dst_tail_offset = rope_dims * nb0; + copy_tail_device((char *) src0->data + src_tail_offset, (char *) dst->data + dst_tail_offset, + ggml_cann_type_mapping(dst->type), ggml_element_size(dst), src0->nb, dst->nb); + } + } + + // Step 5: Cast back to F16 if needed + if (src_dst_need_trans) { + aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16); } } diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index b17445bb9a..45c7294e68 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -315,7 +315,7 @@ struct ggml_cann_rope_cache { if (theta_scale_exp_host) { free(theta_scale_exp_host); } - if(position_select_index_host) { + if (position_select_index_host) { free(position_select_index_host); } } @@ -340,7 +340,7 @@ struct ggml_cann_rope_cache { void set(int64_t theta_scale_length, int64_t position_length, - float ext_factor, + float ext_factor, float theta_scale, float freq_scale, float attn_factor, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 23dc0433c1..81288464c7 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2308,7 +2308,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, bool cann_graph_update_required = false; #ifdef USE_ACL_GRAPH - bool use_cann_graph = true; + bool use_cann_graph = true; static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); if (!prefill_use_graph) { @@ -2338,7 +2338,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, } } #else - bool use_cann_graph = false; + bool use_cann_graph = false; #endif // USE_ACL_GRAPH evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required); @@ -2474,16 +2474,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten } case GGML_OP_ROPE: { - // TODO: with ops-test v == 1 - // TODO: n_dims <= ne0 - if (op->src[0]->ne[0] != op->op_params[1]) { - return false; - } - if (op->src[0]->ne[0] > 896) { return false; } #ifdef ASCEND_310P + // TODO: Support rope_dim < ne00(dim) + if (op->src[0]->ne[0] != op->op_params[1]) { + return false; + } if (!ggml_is_contiguous(op->src[0])) { return false; } diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 7e53a57b7b..fc31089f3e 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -469,6 +469,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZICBOP) string(APPEND MARCH_STR "_zicbop") endif() + if (GGML_RV_ZIHINTPAUSE) + string(APPEND MARCH_STR "_zihintpause") + endif() list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) else() # Begin with the lowest baseline diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8507557267..b468b115a1 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -490,6 +490,15 @@ static inline void ggml_thread_cpu_relax(void) { static inline void ggml_thread_cpu_relax(void) { _mm_pause(); } +#elif defined(__riscv) +static inline void ggml_thread_cpu_relax(void) { + #ifdef __riscv_zihintpause + __asm__ __volatile__ ("pause"); + #else + /* Encoding of the pause instruction */ + __asm__ __volatile__ (".4byte 0x100000F"); + #endif +} #else static inline void ggml_thread_cpu_relax(void) {;} #endif diff --git a/ggml/src/ggml-cuda/diag.cu b/ggml/src/ggml-cuda/diag.cu new file mode 100644 index 0000000000..5cea210517 --- /dev/null +++ b/ggml/src/ggml-cuda/diag.cu @@ -0,0 +1,77 @@ +#include "convert.cuh" +#include "diag.cuh" +#include "ggml.h" + +template +static __global__ void diag_kernel(T * __restrict__ dst, + const T * __restrict__ src, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + const int64_t total_elements) { + const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx >= total_elements) { + return; + } + + const int64_t i0 = global_idx % ne0; + const int64_t i1 = (global_idx / ne0) % ne1; + const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2; + const int64_t i3 = global_idx / (ne0 * ne1 * ne2); + + const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0; + + if (i0 == i1) { + const int64_t batch_idx = i3 * ne2 + i2; + const int64_t src_idx = batch_idx * ne0 + i0; + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = ggml_cuda_cast(0); + } + GGML_UNUSED_VARS(ne3); +} + +void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + void * dst_d = dst->data; + const void * src0_d = src0->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne01 == 1); + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne3); + + const int64_t n_elems = ggml_nelements(dst); + const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + diag_kernel<<>>((float *) dst_d, (const float *) src0_d, ne0, + ne1, ne2, ne3, n_elems); + break; + case GGML_TYPE_F16: + diag_kernel<<>>((half *) dst_d, (const half *) src0_d, ne0, + ne1, ne2, ne3, n_elems); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/diag.cuh b/ggml/src/ggml-cuda/diag.cuh new file mode 100644 index 0000000000..7d73e6a8eb --- /dev/null +++ b/ggml/src/ggml-cuda/diag.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_DIAG_BLOCK_SIZE 256 + +void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index ade0773dad..d51537f7d0 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -955,22 +955,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); } - for (; kb0 < kb0_stop-1; ++kb0) { - constexpr bool last_iter = false; - constexpr bool oob_check = false; - constexpr int k_VKQ_sup = nbatch_fa; - flash_attn_ext_f16_iter - - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, - KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); - } // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. if constexpr (ncols2 == 1) { - if (ne11 % nbatch_fa == 0) { - constexpr bool last_iter = true; - constexpr bool oob_check = false; + constexpr bool oob_check = true; + for (; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } else { + constexpr bool oob_check = false; + for (; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter @@ -989,9 +988,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); } - } else { constexpr bool last_iter = true; - constexpr bool oob_check = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter ( &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)], diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index dec01ff8ad..0155406665 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -36,12 +36,26 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; const ggml_tensor * mask = dst->src[3]; float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers + // are put into the template specialization without GQA optimizations. + bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + for (const ggml_tensor * t : {Q, K, V, mask}) { + if (t == nullptr) { + continue; + } + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (t->nb[i] % 16 != 0) { + use_gqa_opt = false; + break; + } + } + } GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; diff --git a/ggml/src/ggml-cuda/fill.cu b/ggml/src/ggml-cuda/fill.cu new file mode 100644 index 0000000000..739062c405 --- /dev/null +++ b/ggml/src/ggml-cuda/fill.cu @@ -0,0 +1,37 @@ +#include "fill.cuh" +#include "convert.cuh" + +#define CUDA_FILL_BLOCK_SIZE 256 + +template +static __global__ void fill_kernel(T * dst, const int64_t k, const T value) { + const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = value; +} + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(dst)); + + float value; + memcpy(&value, dst->op_params, sizeof(float)); + + const int64_t k = ggml_nelements(dst); + const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + fill_kernel<<>>((float *)dst_d, k, value); + break; + case GGML_TYPE_F16: + fill_kernel<<>>((half *)dst_d, k, ggml_cuda_cast(value)); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/fill.cuh b/ggml/src/ggml-cuda/fill.cuh new file mode 100644 index 0000000000..8443c83620 --- /dev/null +++ b/ggml/src/ggml-cuda/fill.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 235d94d500..279679a4ea 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -20,6 +20,7 @@ #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" +#include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" @@ -56,6 +57,7 @@ #include "ggml-cuda/solve_tri.cuh" #include "ggml-cuda/tri.cuh" #include "ggml-cuda/cumsum.cuh" +#include "ggml-cuda/fill.cuh" #include "ggml.h" #include @@ -2640,6 +2642,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: break; + case GGML_OP_DIAG: + ggml_cuda_op_diag(ctx, dst); + break; case GGML_OP_DIAG_MASK_INF: ggml_cuda_op_diag_mask_inf(ctx, dst); break; @@ -2730,6 +2735,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOLVE_TRI: ggml_cuda_op_solve_tri(ctx, dst); break; + case GGML_OP_FILL: + ggml_cuda_op_fill(ctx, dst); + break; default: return false; } @@ -4617,8 +4625,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: + case GGML_OP_FILL: case GGML_OP_CUMSUM: case GGML_OP_TRI: + case GGML_OP_DIAG: return true; case GGML_OP_SOLVE_TRI: return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32; diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu index 2e2b39720f..e161d4dc43 100644 --- a/ggml/src/ggml-cuda/solve_tri.cu +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -3,7 +3,6 @@ #include "solve_tri.cuh" #define MAX_N_FAST 64 -#define MAX_K_FAST 32 // ====================== // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction @@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A, float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; - __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; const int offset = threadIdx.x + threadIdx.y * blockDim.x; #pragma unroll for (int i = 0; i < n * n; i += k * WARP_SIZE) { - int i0 = i + offset; + const int i0 = i + offset; if (i0 < n * n) { sA[i0] = A_batch[i0]; } } - const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; - -#pragma unroll - for (int i = 0; i < rows_per_warp; i++) { - const int i0 = lane + i * WARP_SIZE; - if (i0 < n) { - sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; - } - } - __syncthreads(); + float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; + float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; + + const int half = WARP_SIZE; + const int nrows_low = (n < half) ? n : half; + #pragma unroll - for (int row = 0; row < n; ++row) { + for (int row = 0; row < nrows_low; ++row) { float sum = 0.0f; - - { - int j = lane; - if (j < row) { - sum += sA[row * n + j] * sXt[col_idx * n + j]; - } + if (lane < row) { + sum += sA[row * n + lane] * x_low; } - if (row >= WARP_SIZE) { - int j = WARP_SIZE + lane; - if (j < row) { - sum += sA[row * n + j] * sXt[col_idx * n + j]; - } - } - sum = warp_reduce_sum(sum); - if (lane == 0) { - const float b_val = sXt[col_idx * n + row]; - const float a_diag = sA[row * n + row]; - // no safeguards for division by zero because that indicates corrupt - // data anyway - sXt[col_idx * n + row] = (b_val - sum) / a_diag; + if (lane == row) { + x_low = (x_low - sum) / sA[row * n + row]; } } - __syncthreads(); +#pragma unroll + for (int row = half; row < n; ++row) { + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; + if (j < row) { + sum += sA[row * n + j] * x_high; + } + sum = warp_reduce_sum(sum); + + if (lane == row - half) { + x_high = (x_high - sum) / sA[row * n + row]; + } + } #pragma unroll - for (int i = 0; i < rows_per_warp; i++) { - const int i0 = lane + i * WARP_SIZE; - if (i0 < n) { - X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; + for (int rr = 0; rr < 2; ++rr) { + const int row = rr * WARP_SIZE + lane; + if (row < n) { + const float val = (row < half) ? x_low : x_high; + X_batch[row * k + col_idx] = val; } } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ba3c342751..680904d132 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -411,6 +411,38 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_me return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) { + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); + + char base[256]; + char name[256]; + + const char * suffix = ""; + if (op->src[1]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); + snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); @@ -427,7 +459,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_me res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } - res.smem = 32*sizeof(float)*nsg; + // Shared memory layout: + // - sgptg * NW floats for partial sums (nsg * 32) + // - sgptg floats for shared_x_dt (nsg) + // - sgptg floats for shared_dA (nsg) + // Total: nsg * (32 + 2) floats + res.smem = (32 + 2)*sizeof(float)*nsg; return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 77f2e98cfe..0a8b9211a7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -117,6 +117,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 30109f83e1..8944b07e90 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -77,6 +77,7 @@ #define FC_MUL_MV 600 #define FC_MUL_MM 700 #define FC_ROPE 800 +#define FC_SSM_CONV 900 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPTG 8 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 9efd51abba..e99c1763f6 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -221,7 +221,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { } if (ctx->debug_graph > 0) { - GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), is_concurrent ? "(concurrent)" : ""); + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : ""); } if (ctx->debug_graph > 1) { GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne); @@ -1365,15 +1365,43 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { /*.nb2 =*/ nb2, }; - auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead + const bool use_batched = (ne1 > 1); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + if (use_batched) { + // Determine the smallest power of 2 that's >= ne1, but <= 256 + int BATCH_SIZE; + if (ne1 > 128) BATCH_SIZE = 256; + else if (ne1 > 64 ) BATCH_SIZE = 128; + else if (ne1 > 32 ) BATCH_SIZE = 64; + else if (ne1 > 16 ) BATCH_SIZE = 32; + else if (ne1 > 8 ) BATCH_SIZE = 16; + else if (ne1 > 4 ) BATCH_SIZE = 8; + else BATCH_SIZE = 2; - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences + // Each threadgroup has BATCH_SIZE threads, each handling one token + const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE; + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1); + } else { + auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); + } return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4b78d5a2ba..51bcbae309 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2343,7 +2343,102 @@ kernel void kernel_ssm_conv_f32_f32_4( x[0] = sumf; } +constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]]; + +// Batched version: each threadgroup processes multiple tokens for better efficiency +// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens +kernel void kernel_ssm_conv_f32_f32_batched( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // tgpig.x = row index (ir) + // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) + // tgpig.z = sequence index (i3) + // tpitg.x = thread within batch (0..BATCH_SIZE-1) + const short BATCH_SIZE = FC_ssm_conv_bs; + + const int64_t ir = tgpig.x; + const int64_t i2_base = tgpig.y * BATCH_SIZE; + const int64_t i3 = tgpig.z; + const int64_t i2_off = tpitg.x; + const int64_t i2 = i2_base + i2_off; + + const int64_t nc = args.ne10; // conv kernel size (typically 4) + const int64_t n_t = args.ne1; // number of tokens + + // Bounds check for partial batches at the end + if (i2 >= n_t) { + return; + } + + // Load conv weights (shared across all tokens for this row) + device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); + + // Load source for this specific token + device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + + // Output location for this token + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +kernel void kernel_ssm_conv_f32_f32_batched_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + // tgpig.x = row index (ir) + // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) + // tgpig.z = sequence index (i3) + // tpitg.x = thread within batch (0..BATCH_SIZE-1) + const short BATCH_SIZE = FC_ssm_conv_bs; + + const int64_t ir = tgpig.x; + const int64_t i2_base = tgpig.y * BATCH_SIZE; + const int64_t i3 = tgpig.z; + const int64_t i2_off = tpitg.x; + const int64_t i2 = i2_base + i2_off; + + const int64_t nc = args.ne10; // conv kernel size (typically 4) + const int64_t n_t = args.ne1; // number of tokens + + // Bounds check for partial batches at the end + if (i2 >= n_t) { + return; + } + + // Load conv weights (shared across all tokens for this row) + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + + // Load source for this specific token + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + + // Output location for this token + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + x[0] = sumf; +} + // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// Optimized version: reduces redundant memory loads by having one thread load shared values kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -2363,7 +2458,15 @@ kernel void kernel_ssm_scan_f32( uint3 tgpg[[threadgroups_per_grid]]) { constexpr short NW = N_SIMDWIDTH; - shared[tpitg.x] = 0.0f; + // Shared memory layout: + // [0..sgptg*NW-1]: partial sums for reduction (existing) + // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch + // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch + threadgroup float * shared_sums = shared; + threadgroup float * shared_x_dt = shared + sgptg * NW; + threadgroup float * shared_dA = shared + sgptg * NW + sgptg; + + shared_sums[tpitg.x] = 0.0f; const int32_t i0 = tpitg.x; const int32_t i1 = tgpig.x; @@ -2403,32 +2506,47 @@ kernel void kernel_ssm_scan_f32( for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - for (int t = 0; t < sgptg && i2 + t < n_t; t++) { - const float dt0 = dt[0]; + // Pre-compute x_dt and dA for this batch of tokens + // Only first sgptg threads do the loads and expensive math + if (i0 < sgptg && i2 + i0 < n_t) { + // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20) + device const float * x_t = x + i0 * args.ns12; + device const float * dt_t = dt + i0 * args.ns21; + + const float dt0 = dt_t[0]; const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; - const float x_dt = x[0] * dtsp; - const float dA = exp(dtsp * A0); + shared_x_dt[i0] = x_t[0] * dtsp; + shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float x_dt = shared_x_dt[t]; + const float dA = exp(shared_dA[t] * A0); s = (s0 * dA) + (B[i0] * x_dt); const float sumf = simd_sum(s * C[i0]); if (tiisg == 0) { - shared[t*NW + sgitg] = sumf; + shared_sums[t*NW + sgitg] = sumf; } // recurse s0 = s; - x += args.ns12; - dt += args.ns21; B += args.ns42; C += args.ns52; } + // Advance pointers for next batch + x += sgptg * args.ns12; + dt += sgptg * args.ns21; + threadgroup_barrier(mem_flags::mem_threadgroup); - const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); if (tiisg == 0 && i2 + sgitg < n_t) { y[sgitg*nh*nr] = sumf; diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 83b7c71b66..b41124acc1 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -277,7 +277,7 @@ static void soft_max_f32_sycl(const float *x, const T *mask, const int id = get_current_device_id(); const size_t smpbo = ggml_sycl_info().devices[id].smpbo; - if (nbytes_shared <= smpbo) { + if (nbytes_shared <= smpbo && ncols_x <= max_block_size) { launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>( x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9f5cdc1398..530ff7b953 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -124,6 +124,13 @@ static void ggml_print_backtrace_symbols(void) { int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); } +#elif defined(__APPLE__) +#include +static void ggml_print_backtrace_symbols(void) { + void * trace[100]; + int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); + backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); +} #else static void ggml_print_backtrace_symbols(void) { // platform not supported @@ -135,6 +142,20 @@ void ggml_print_backtrace(void) { if (GGML_NO_BACKTRACE) { return; } +#if defined(__APPLE__) + // On macOS, fork+debugger attachment is problematic due to: + // 1. libdispatch "poisons" forked child processes + // 2. lldb has issues attaching to parent from forked child + // Use simple backtrace() instead to avoid Terminal.app crashes + const char * GGML_BACKTRACE_LLDB = getenv("GGML_BACKTRACE_LLDB"); + if (!GGML_BACKTRACE_LLDB) { + fprintf(stderr, "WARNING: Using native backtrace. Set GGML_BACKTRACE_LLDB for more info.\n"); + fprintf(stderr, "WARNING: GGML_BACKTRACE_LLDB may cause native MacOS Terminal.app to crash.\n"); + fprintf(stderr, "See: https://github.com/ggml-org/llama.cpp/pull/17869\n"); + ggml_print_backtrace_symbols(); + return; + } +#endif #if defined(__linux__) FILE * f = fopen("/proc/self/status", "r"); size_t size = 0; diff --git a/grammars/README.md b/grammars/README.md index a63198b5ae..11e3b6dd90 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -67,6 +67,30 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte - `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included) - `{0,n}` repeats the precedent symbol or sequence at most `n` times (included) +## Tokens + +Tokens allow grammars to match specific tokenizer tokens rather than character sequences. This is useful for constraining outputs based on special tokens (like `` or ``). + +Tokens can be specified in two ways: + +1. **Token ID**: Use angle brackets with the token ID in square brackets: `<[token-id]>`. For example, `<[1000]>` matches the token with ID 1000. + +2. **Token string**: Use angle brackets with the token text directly: ``. For example, `` will match the token whose text is exactly ``. This only works if the string tokenizes to exactly one token in the vocabulary, otherwise the grammar will fail to parse. + +You can negate token matches using the `!` prefix: `!<[1000]>` or `!` matches any token *except* the specified one. + +``` +# Match a thinking block: ... +# Using token strings (requires these to be single tokens in the vocab) +root ::= thinking .* +thinking ::= !* + +# Equivalent grammar using explicit token IDs +# Assumes token 1000 = , token 1001 = +root ::= <[1000]> thinking <[1001]> .* +thinking ::= !<[1001]>* +``` + ## Comments and newlines Comments can be specified with `#`: diff --git a/models/templates/Kimi-K2-Instruct.jinja b/models/templates/Kimi-K2-Instruct.jinja index a9439135ba..6204fb3960 100644 --- a/models/templates/Kimi-K2-Instruct.jinja +++ b/models/templates/Kimi-K2-Instruct.jinja @@ -14,7 +14,7 @@ {%- endmacro %} {%- set tool_response_queue = namespace(ids=[]) -%} -{%- set tool_call_counter = namespace(value=1) -%} +{%- set tool_call_counter = namespace(value=0) -%} {%- if tools -%} <|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|> @@ -36,12 +36,8 @@ {%- if message['role'] == 'assistant' and message.get('tool_calls') -%} {{render_content(message)}}<|tool_calls_section_begin|> {%- for tool_call in message['tool_calls'] -%} - {%- if tool_call['id'] is defined -%} - {%- set formatted_id = tool_call['id'] -%} - {%- else -%} - {%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%} - {%- set tool_call_counter.value = tool_call_counter.value + 1 -%} - {%- endif -%} + {%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%} + {%- set tool_call_counter.value = tool_call_counter.value + 1 -%} {%- set _ = tool_response_queue.ids.append(formatted_id) -%} <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|> {%- endfor -%} diff --git a/models/templates/Kimi-K2-Thinking.jinja b/models/templates/Kimi-K2-Thinking.jinja index 4c2af6a783..5641429f53 100644 --- a/models/templates/Kimi-K2-Thinking.jinja +++ b/models/templates/Kimi-K2-Thinking.jinja @@ -25,17 +25,13 @@ {%- endmacro -%} {%- set tool_response_queue = namespace(ids=[]) -%} -{%- set tool_call_counter = namespace(value=1) -%} +{%- set tool_call_counter = namespace(value=0) -%} {%- macro render_toolcalls(message) -%} <|tool_calls_section_begin|> {%- for tool_call in message['tool_calls'] -%} - {%- if tool_call['id'] is defined -%} - {%- set formatted_id = tool_call['id'] -%} - {%- else -%} - {%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%} - {%- set tool_call_counter.value = tool_call_counter.value + 1 -%} - {%- endif -%} + {%- set formatted_id = 'functions.' + tool_call['function']['name'] + ':' + (tool_call_counter.value | string) -%} + {%- set tool_call_counter.value = tool_call_counter.value + 1 -%} {%- set _ = tool_response_queue.ids.append(formatted_id) -%} <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{% if tool_call['function']['arguments'] is string %}{{ tool_call['function']['arguments'] }}{% else %}{{ tool_call['function']['arguments'] | tojson }}{% endif %}<|tool_call_end|> {%- endfor -%} diff --git a/models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja b/models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja new file mode 100644 index 0000000000..beb4d612c7 --- /dev/null +++ b/models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja @@ -0,0 +1,126 @@ +{#- Default system message if no system prompt is passed. #} +{%- set default_system_message = '# HOW YOU SHOULD THINK AND ANSWER\n\nFirst draft your thinking process (inner monologue) until you arrive at a response. Format your response using Markdown, and use LaTeX for any mathematical equations. Write both your thoughts and the response in the same language as the input.\n\nYour thinking process must follow the template below:[THINK]Your thoughts or/and draft, like working through an exercise on scratch paper. Be as casual and as long as you want until you are confident to generate the response to the user.[/THINK]Here, provide a self-contained response.' %} + +{#- Begin of sequence token. #} +{{- bos_token }} + +{#- Handle system prompt if it exists. #} +{#- System prompt supports text content or text and thinking chunks. #} +{%- if messages[0]['role'] == 'system' %} + {{- '[SYSTEM_PROMPT]' -}} + {%- if messages[0]['content'] is string %} + {{- messages[0]['content'] -}} + {%- else %} + {%- for block in messages[0]['content'] %} + {%- if block['type'] == 'text' %} + {{- block['text'] }} + {%- elif block['type'] == 'thinking' %} + {{- '[THINK]' + block['thinking'] + '[/THINK]' }} + {%- else %} + {{- raise_exception('Only text and thinking chunks are supported in system message contents.') }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '[/SYSTEM_PROMPT]' -}} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} + {%- if default_system_message != '' %} + {{- '[SYSTEM_PROMPT]' + default_system_message + '[/SYSTEM_PROMPT]' }} + {%- endif %} +{%- endif %} + + +{#- Tools definition #} +{%- set tools_definition = '' %} +{%- set has_tools = false %} +{%- if tools is defined and tools is not none and tools|length > 0 %} + {%- set has_tools = true %} + {%- set tools_definition = '[AVAILABLE_TOOLS]' + (tools| tojson) + '[/AVAILABLE_TOOLS]' %} + {{- tools_definition }} +{%- endif %} + +{#- Checks for alternating user/assistant messages. #} +{%- set ns = namespace(index=0) %} +{%- for message in loop_messages %} + {%- if message.role == 'user' or (message.role == 'assistant' and (message.tool_calls is not defined or message.tool_calls is none or message.tool_calls | length == 0)) %} + {%- if (message['role'] == 'user') != (ns.index % 2 == 0) %} + {{- raise_exception('After the optional system message, conversation roles must alternate user and assistant roles except for tool calls and results.') }} + {%- endif %} + {%- set ns.index = ns.index + 1 %} + {%- endif %} +{%- endfor %} + +{#- Handle conversation messages. #} +{%- for message in loop_messages %} + + {#- User messages supports text content or text and image chunks. #} + {%- if message['role'] == 'user' %} + {%- if message['content'] is string %} + {{- '[INST]' + message['content'] + '[/INST]' }} + {%- elif message['content'] | length > 0 %} + {{- '[INST]' }} + {%- if message['content'] | length == 2 %} + {%- set blocks = message['content'] | sort(attribute='type') %} + {%- else %} + {%- set blocks = message['content'] %} + {%- endif %} + {%- for block in blocks %} + {%- if block['type'] == 'text' %} + {{- block['text'] }} + {%- elif block['type'] in ['image', 'image_url'] %} + {{- '[IMG]' }} + {%- else %} + {{- raise_exception('Only text, image and image_url chunks are supported in user message content.') }} + {%- endif %} + {%- endfor %} + {{- '[/INST]' }} + {%- else %} + {{- raise_exception('User message must have a string or a list of chunks in content') }} + {%- endif %} + + {#- Assistant messages supports text content or text, image and thinking chunks. #} + {%- elif message['role'] == 'assistant' %} + {%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %} + {{- raise_exception('Assistant message must have a string or a list of chunks in content or a list of tool calls.') }} + {%- endif %} + + {%- if message['content'] is string and message['content'] != '' %} + {{- message['content'] }} + {%- elif message['content'] | length > 0 %} + {%- for block in message['content'] %} + {%- if block['type'] == 'text' %} + {{- block['text'] }} + {%- elif block['type'] == 'thinking' %} + {{- '[THINK]' + block['thinking'] + '[/THINK]' }} + {%- else %} + {{- raise_exception('Only text and thinking chunks are supported in assistant message contents.') }} + {%- endif %} + {%- endfor %} + {%- endif %} + + {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 %} + {%- for tool in message['tool_calls'] %} + {{- '[TOOL_CALLS]' }} + {%- set name = tool['function']['name'] %} + {%- set arguments = tool['function']['arguments'] %} + {%- if arguments is not string %} + {%- set arguments = arguments|tojson|safe %} + {%- elif arguments == '' %} + {%- set arguments = '{}' %} + {%- endif %} + {{- name + '[ARGS]' + arguments }} + {%- endfor %} + {%- endif %} + + {{- eos_token }} + + {#- Tool messages only supports text content. #} + {%- elif message['role'] == 'tool' %} + {{- '[TOOL_RESULTS]' + message['content']|string + '[/TOOL_RESULTS]' }} + + {#- Raise exception for unsupported roles. #} + {%- else %} + {{- raise_exception('Only user, assistant and tool roles are supported, got ' + message['role'] + '.') }} + {%- endif %} +{%- endfor %} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fbd538109b..4192af7c0c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -67,7 +67,7 @@ add_library(llama models/gemma-embedding.cpp models/gemma.cpp models/gemma2-iswa.cpp - models/gemma3-iswa.cpp + models/gemma3.cpp models/gemma3n-iswa.cpp models/glm4-moe.cpp models/glm4.cpp @@ -139,6 +139,7 @@ add_library(llama set_target_properties(llama PROPERTIES VERSION ${LLAMA_INSTALL_VERSION} SOVERSION 0 + MACHO_CURRENT_VERSION 0 # keep macOS linker from seeing oversized version number ) target_include_directories(llama PRIVATE .) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e04f0fc4f9..4171400713 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -248,7 +248,10 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const size_t max_nodes = this->graph_max_nodes(); + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + const size_t max_nodes = this->graph_max_nodes(n_tokens); LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); @@ -300,9 +303,6 @@ llama_context::llama_context( cross.v_embd.clear(); - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - // avoid reserving graphs with zero outputs - assume one output per sequence n_outputs = n_seqs; @@ -1386,9 +1386,9 @@ void llama_context::output_reorder() { // graph // -uint32_t llama_context::graph_max_nodes() const { +uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { if (model.arch == LLM_ARCH_QWEN3NEXT) { - return std::max(8192u, 32u*model.n_tensors()); + return std::max(n_tokens * 40, 32u * model.n_tensors()); } return std::max(1024u, 8u*model.n_tensors()); } diff --git a/src/llama-context.h b/src/llama-context.h index 20cbd78955..cd26eafe18 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -197,7 +197,7 @@ private: // public: - uint32_t graph_max_nodes() const; + uint32_t graph_max_nodes(uint32_t n_tokens) const; // can reuse the llm_graph_result instance of the context (for example to update a memory module) llm_graph_result * get_gf_res_reserve() const; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index b3c5eb5717..75d5d750c3 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -181,6 +181,52 @@ static std::pair parse_char(const char * src) { throw std::runtime_error("unexpected end of input"); } +static std::pair parse_token(const llama_vocab * vocab, const char * src) { + const char * pos = src; + if (*pos != '<') { + throw std::runtime_error(std::string("expecting '<' at ") + pos); + } + pos++; + + // Parse <[id]> + if (*pos == '[') { + pos++; + const char * int_end = parse_int(pos); + uint32_t token_id = std::stoul(std::string(pos, int_end - pos)); + pos = int_end; + if (*pos != ']') { + throw std::runtime_error(std::string("expecting ']' at ") + pos); + } + pos++; + if (*pos != '>') { + throw std::runtime_error(std::string("expecting '>' at ") + pos); + } + pos++; + return std::make_pair(token_id, pos); + } + + if (vocab == nullptr) { + throw std::runtime_error(std::string("no vocab to parse token at ") + src); + } + + // Parse and tokenize to obtain the token id + while (*pos != 0 && *pos != '>') { + pos++; + } + if (*pos != '>') { + throw std::runtime_error(std::string("expecting '>' at ") + pos); + } + pos++; + + llama_token tokens[2]; + int32_t n_tokens = vocab->tokenize(src, static_cast(pos - src), tokens, 2, false, true); + if (n_tokens != 1) { + // must tokenize to exactly 1 token + throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'"); + } + return std::make_pair(tokens[0], pos); +} + static void print_grammar_char(FILE * file, uint32_t c) { if (0x20 <= c && c <= 0x7f) { fprintf(file, "%c", static_cast(c)); @@ -212,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; + case LLAMA_GRETYPE_TOKEN: fprintf(file, "TOKEN"); break; + case LLAMA_GRETYPE_TOKEN_NOT: fprintf(file, "TOKEN_NOT"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -228,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { print_grammar_char(file, elem.value); fprintf(file, "\") "); break; + case LLAMA_GRETYPE_TOKEN: + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; + case LLAMA_GRETYPE_TOKEN_NOT: + fprintf(file, "!"); + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; } } fprintf(file, "\n"); @@ -284,6 +343,17 @@ static void print_rule( case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "."); break; + case LLAMA_GRETYPE_TOKEN: + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; + case LLAMA_GRETYPE_TOKEN_NOT: + fprintf(file, "!"); + fprintf(file, "<["); + fprintf(file, "%u", elem.value); + fprintf(file, "]> "); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { @@ -444,6 +514,17 @@ const char * llama_grammar_parser::parse_sequence( } } pos = parse_space(pos + 1, is_nested); + } else if (*pos == '<' || *pos == '!') { // token + auto type = LLAMA_GRETYPE_TOKEN; + if (*pos == '!') { // token inverse + type = LLAMA_GRETYPE_TOKEN_NOT; + pos++; + } + auto token_pair = parse_token(vocab, pos); + const char * token_end = token_pair.second; + last_sym_start = rule.size(); + rule.push_back({type, token_pair.first}); + pos = parse_space(token_end, is_nested); } else if (is_word_char(*pos)) { // rule reference const char * name_end = parse_name(pos); uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); @@ -691,6 +772,21 @@ static bool llama_grammar_match_partial_char( return !is_positive_char; } +// returns true iff token matches the rule at pos (regular or inverse) +// asserts that pos is pointing to a token element +static bool llama_grammar_match_token( + const llama_grammar_element * pos, + const llama_token token) { + GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT); + if (pos->type == LLAMA_GRETYPE_TOKEN) { + return pos->value == static_cast(token); + } + if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + return pos->value != static_cast(token); + } + return false; +} + // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( @@ -738,6 +834,8 @@ static void llama_grammar_advance_stack( case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_ANY: + case LLAMA_GRETYPE_TOKEN: + case LLAMA_GRETYPE_TOKEN_NOT: if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have new_stacks.emplace_back(stack); @@ -831,26 +929,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } +static void llama_grammar_accept_chr( + struct llama_grammar & grammar, + const llama_grammar_stack & stack, + uint32_t chr, + llama_grammar_stacks & new_stacks) { + if (stack.empty()) { + return; + } + + const llama_grammar_element * pos = stack.back(); + + // ignore if this turns into a token + if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + return; + } + + auto match = llama_grammar_match_char(pos, chr); + if (match.first) { + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(match.second)) { + new_stack.push_back(match.second); + } + llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks); + } +} + void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { llama_grammar_stacks stacks_new; stacks_new.reserve(grammar->stacks.size()); for (const auto & stack : grammar->stacks) { - if (stack.empty()) { - continue; - } - - auto match = llama_grammar_match_char(stack.back(), chr); - if (match.first) { - const llama_grammar_element * pos = match.second; - - // update top of stack to next element, if any - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos)) { - new_stack.push_back(pos); - } - llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new); - } + llama_grammar_accept_chr(*grammar, stack, chr, stacks_new); } grammar->stacks = std::move(stacks_new); @@ -875,6 +985,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( const llama_grammar_element * stack_pos = stack.back(); + // if the top of the stack is a token rule, then we only need to check the token id + if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + for (const auto & tok : candidates) { + if (*tok.code_points == 0) { + // reached the end of a token consumed by char rules, reject iff it ended + // in a partial response + if (tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } else if (!llama_grammar_match_token(stack_pos, tok.id)) { + rejects.push_back(tok); + } + } + return rejects; + } + llama_grammar_candidates next_candidates; next_candidates.reserve(candidates.size()); @@ -887,7 +1013,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( rejects.push_back(tok); } } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { - next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id }); } else { rejects.push_back(tok); } @@ -905,7 +1031,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); for (const auto & tok : next_rejects) { - rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id }); } return rejects; @@ -972,12 +1098,13 @@ struct llama_grammar * llama_grammar_init_impl( vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, - /* .lazy =*/ false, - /* .awaiting_trigger = */ false, - /* .trigger_buffer = */ "", - /* .trigger_tokens = */ {}, - /* .trigger_patterns = */ {}, + /* .partial_utf8 = */ {}, + /* .lazy = */ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_buffer_positions = */ {}, + /* .trigger_tokens = */ {}, + /* .trigger_patterns = */ {}, }; } @@ -990,7 +1117,7 @@ struct llama_grammar * llama_grammar_init_impl( size_t num_trigger_patterns, const llama_token * trigger_tokens, size_t num_trigger_tokens) { - llama_grammar_parser parser; + llama_grammar_parser parser(vocab); // if there is a grammar, parse it // rules will be empty (default) if there are parse errors @@ -1077,10 +1204,11 @@ struct llama_grammar * llama_grammar_init_impl( vocab, std::move(vec_rules), std::move(stacks), - /* .partial_utf8 = */ {}, - /* .lazy = */ lazy, - /* .awaiting_trigger = */ lazy, - /* .trigger_buffer = */ "", + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + /* .trigger_buffer_positions = */ {}, std::move(vec_trigger_tokens), std::move(vec_trigger_patterns), }; @@ -1103,6 +1231,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.lazy, grammar.awaiting_trigger, grammar.trigger_buffer, + grammar.trigger_buffer_positions, grammar.trigger_tokens, grammar.trigger_patterns, }; @@ -1156,7 +1285,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ cur_p->data[i].logit = -INFINITY; } else { candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); - candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id }); } } @@ -1175,10 +1304,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { grammar.awaiting_trigger = false; grammar.trigger_buffer.clear(); - llama_grammar_accept_str(grammar, piece); + llama_grammar_accept_token(grammar, token, piece); LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); return; } else { + auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size()); + grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); grammar.trigger_buffer += piece; std::smatch match; @@ -1196,10 +1327,23 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token if (start == std::string::npos) { start = match.position(0); } + + // replay tokens that overlap with [start, end) + for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) { + auto [tok_start, tok_end] = tok_pos; + if (tok_end <= start) { + continue; + } + + size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces + size_t piece_len = tok_end - piece_start; + auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len); + llama_grammar_accept_token(grammar, tok, tok_piece); + } + auto constrained_str = grammar.trigger_buffer.substr(start); - // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); grammar.trigger_buffer.clear(); - llama_grammar_accept_str(grammar, constrained_str); + grammar.trigger_buffer_positions.clear(); LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str()); return; } @@ -1218,7 +1362,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token GGML_ABORT("fatal error"); } - llama_grammar_accept_str(grammar, piece); + llama_grammar_accept_token(grammar, token, piece); } void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { @@ -1235,3 +1379,59 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); } } + +void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) { + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(piece, grammar.partial_utf8); + const auto & code_points = decoded.first; + + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar.stacks.size()); + + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + continue; + } + + const llama_grammar_element * pos = stack.back(); + + if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) { + if (llama_grammar_match_token(pos, token)) { + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + new_stack.push_back(pos + 1); + } + llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new); + } + } else { + llama_grammar_stacks current_stacks = {stack}; + + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + llama_grammar_stacks next_stacks; + + for (const auto & cur_stack : current_stacks) { + llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks); + } + + current_stacks = std::move(next_stacks); + if (current_stacks.empty()) { + break; + } + } + + for (auto & surviving_stack : current_stacks) { + if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) { + stacks_new.emplace_back(surviving_stack); + } + } + } + } + + grammar.stacks = std::move(stacks_new); + grammar.partial_utf8 = decoded.second; + + if (grammar.stacks.empty()) { + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")"); + } +} + diff --git a/src/llama-grammar.h b/src/llama-grammar.h index f8c291de99..a4c978ac11 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -36,11 +36,17 @@ enum llama_gretype { // any character (.) LLAMA_GRETYPE_CHAR_ANY = 7, + + // terminal element: token (<[token-id]>) + LLAMA_GRETYPE_TOKEN = 8, + + // inverse token (!<[token-id]>) + LLAMA_GRETYPE_TOKEN_NOT = 9, }; typedef struct llama_grammar_element { enum llama_gretype type; - uint32_t value; // Unicode code point or rule ID + uint32_t value; // Unicode code point, rule ID, or token ID } llama_grammar_element; struct llama_partial_utf8 { @@ -52,6 +58,7 @@ struct llama_grammar_candidate { size_t index; const uint32_t * code_points; llama_partial_utf8 partial_utf8; + llama_token id; }; using llama_grammar_rule = std::vector< llama_grammar_element>; @@ -77,10 +84,13 @@ std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_candidates & candidates); struct llama_grammar_parser { + const llama_vocab * vocab; std::map symbol_ids; llama_grammar_rules rules; + llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {} + llama_grammar_stack c_rules() const; uint32_t get_symbol_id(const char * src, size_t len); @@ -112,6 +122,9 @@ struct llama_grammar_trigger_pattern { }; struct llama_grammar { + // maintain a list of llama_tokens and their positions in the trigger_buffer + using token_pos = std::pair>; + // note: allow null vocab for testing (not great) const llama_vocab * vocab; @@ -127,6 +140,7 @@ struct llama_grammar { bool lazy = false; bool awaiting_trigger = false; // Initialized to true for lazy grammars only std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). std::vector trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated @@ -171,3 +185,8 @@ void llama_grammar_accept_impl( void llama_grammar_accept_str( struct llama_grammar & grammar, const std::string & piece); + +void llama_grammar_accept_token( + struct llama_grammar & grammar, + llama_token token, + const std::string & piece); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 42ccb5b76a..43620df780 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -973,7 +973,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // mask out the other groups selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens] - selection_probs = ggml_set_rows(ctx0, ggml_scale_bias(ctx0, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens] + selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens] selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens] cb(selection_probs, "ffn_moe_probs_masked", il); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7d09d7abd5..04fccc9793 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1264,18 +1264,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA3: { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(6); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(6); - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + hparams.f_final_logit_softcapping = 0.0f; + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_8B; break; // Rnj-1 case 34: type = LLM_TYPE_4B; break; case 48: type = LLM_TYPE_12B; break; case 62: type = LLM_TYPE_27B; break; @@ -1599,8 +1606,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - switch (hparams.n_layer) { - case 28: type = LLM_TYPE_20B; break; + switch (hparams.n_ff_exp) { + case 1408: type = LLM_TYPE_16B; break; + case 1792: type = LLM_TYPE_20B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -7304,7 +7312,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_GEMMA3: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_GEMMA3N: { diff --git a/src/models/gemma3-iswa.cpp b/src/models/gemma3.cpp similarity index 78% rename from src/models/gemma3-iswa.cpp rename to src/models/gemma3.cpp index 839ff6d3d9..ae60ef4790 100644 --- a/src/models/gemma3-iswa.cpp +++ b/src/models/gemma3.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template +llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -17,13 +18,28 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll ggml_tensor * inp_pos = build_inp_pos(); // TODO: is causal == true correct? might need some changes - auto * inp_attn = build_attn_inp_kv_iswa(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - const float freq_base_l = model.get_rope_freq_base (cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + float freq_base_l = 0.0f; + float freq_scale_l = 0.0f; + + if constexpr (iswa) { + freq_base_l = model.get_rope_freq_base (cparams, il); + freq_scale_l = model.get_rope_freq_scale(cparams, il); + } else { + freq_base_l = freq_base; + freq_scale_l = freq_scale; + } // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); @@ -102,7 +118,7 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "ffn_post_norm", -1); + cb(cur, "ffn_post_norm", il); cur = ggml_add(ctx0, cur, sa_out); @@ -124,8 +140,17 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll // lm_head cur = build_lora_mm(model.output, cur); + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); } + +template struct llm_build_gemma3; +template struct llm_build_gemma3; diff --git a/src/models/models.h b/src/models/models.h index d93601ad06..6494f54501 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -179,8 +179,9 @@ struct llm_build_gemma2_iswa : public llm_graph_context { llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_gemma3_iswa : public llm_graph_context { - llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params); +template +struct llm_build_gemma3 : public llm_graph_context { + llm_build_gemma3(const llama_model & model, const llm_graph_params & params); }; struct llm_build_gemma3n_iswa : public llm_graph_context { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2e94a53da2..a6f266601f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6253,6 +6253,31 @@ struct test_solve_tri : public test_case { } }; +// GGML_OP_DIAG +struct test_diag : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { return VARS_TO_STR2(type, ne); } + + test_diag(ggml_type type = GGML_TYPE_F32, + std::array ne = { 10, 1, 4, 3 }) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + GGML_ASSERT(ne[1] == 1); + ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); + ggml_set_param(a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_diag(ctx, a); + ggml_set_name(out, "out"); + + return out; + } +}; + + enum llm_norm_type { LLM_NORM, LLM_NORM_RMS, @@ -7826,6 +7851,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_fill(-152.0f, GGML_TYPE_F32, { 800, 600, 4, 4 })); test_cases.emplace_back(new test_fill(3.5f, GGML_TYPE_F32, { 2048, 512, 2, 2 })); + test_cases.emplace_back(new test_diag()); + test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 79, 1, 19, 13 })); + test_cases.emplace_back(new test_diag(GGML_TYPE_F32, { 256, 1, 8, 16 })); + test_cases.emplace_back(new test_solve_tri()); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 11, 11, 1, 1 }, { 5, 11, 1, 1 })); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 17, 17, 2, 4 }, { 9, 17, 2, 4 })); @@ -8164,6 +8193,13 @@ static std::vector> make_test_cases_perf() { } } + // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate + + return test_cases; } diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 62dd1583fa..007929f517 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -428,10 +428,38 @@ static void test_templates(const struct common_chat_templates * tmpls, const std */ template static void test_parser_with_streaming(const common_chat_msg & expected, const std::string & raw_message, T parse_msg) { + constexpr auto utf8_truncate_safe_len = [](const std::string_view s) -> size_t { + auto len = s.size(); + if (len == 0) return 0; + auto i = len; + for (size_t back = 0; back < 4 && i > 0; ++back) { + --i; + unsigned char c = s[i]; + if ((c & 0x80) == 0) { + return len; + } else if ((c & 0xC0) == 0xC0) { + size_t expected_len = 0; + if ((c & 0xE0) == 0xC0) expected_len = 2; + else if ((c & 0xF0) == 0xE0) expected_len = 3; + else if ((c & 0xF8) == 0xF0) expected_len = 4; + else return i; + if (len - i >= expected_len) { + return len; + } else { + return i; + } + } + } + return len - std::min(len, size_t(3)); + }; + constexpr auto utf8_truncate_safe_view = [utf8_truncate_safe_len](const std::string_view s) { + return s.substr(0, utf8_truncate_safe_len(s)); + }; + auto merged = simple_assist_msg(""); auto last_msg = parse_msg(""); for (size_t i = 1; i <= raw_message.size(); ++i) { - auto curr_msg = parse_msg(raw_message.substr(0, i)); + auto curr_msg = parse_msg(std::string(utf8_truncate_safe_view(std::string_view(raw_message).substr(0, i)))); if (curr_msg == simple_assist_msg("")) continue; LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({curr_msg}).dump().c_str()); for (auto diff: common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) { @@ -511,6 +539,71 @@ const common_chat_msg message_assist_call_python_lines = simple_assist const common_chat_msg message_assist_call_python_lines_unclosed = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')"); const common_chat_msg message_assist_call_code_interpreter = simple_assist_msg("", "", "code_interpreter", "{\"code\":\"print('hey')\"}"); +// Use for PEG parser implementations +struct peg_test_case { + common_chat_templates_inputs params; + std::string input; + common_chat_msg expect; +}; + +struct make_peg_parser { + common_chat_params params_; + common_peg_arena arena_; + + make_peg_parser(common_chat_templates * tmpls, const common_chat_templates_inputs & inputs) { + params_ = common_chat_templates_apply(tmpls, inputs); + arena_.load(params_.parser); + } + + common_chat_msg parse(const std::string & msg, bool is_partial) { + return common_chat_peg_parse(arena_, msg, is_partial, /* syntax = */ {params_.format}); + } +}; + +static void test_peg_parser(common_chat_templates * tmpls, const std::function & init) { + peg_test_case tc; + init(tc); + if (tc.params.messages.empty()) { + tc.params.messages = {message_user}; + } + if (tc.expect.role.empty()) { + tc.expect.role = "assistant"; + } + + auto parser = make_peg_parser(tmpls, tc.params); + + common_chat_msg msg_accum; + common_chat_msg msg_prev; + msg_accum.role = msg_prev.role = "assistant"; + + for (size_t i = 1; i <= tc.input.size(); ++i) { + auto is_partial = i < tc.input.size(); + common_chat_msg msg_current = parser.parse(tc.input.substr(0, i), is_partial); + + for (const auto & diff : common_chat_msg_diff::compute_diffs(msg_prev, msg_current)) { + if (!diff.reasoning_content_delta.empty()) { + msg_accum.reasoning_content += diff.reasoning_content_delta; + } + if (!diff.content_delta.empty()) { + msg_accum.content += diff.content_delta; + } + if (diff.tool_call_index != std::string::npos) { + if (!diff.tool_call_delta.name.empty()) { + msg_accum.tool_calls.push_back({diff.tool_call_delta.name, "", ""}); + } + if (!diff.tool_call_delta.arguments.empty()) { + msg_accum.tool_calls.back().arguments += diff.tool_call_delta.arguments; + } + } + } + assert_msg_equals(msg_current, msg_accum, true); + msg_prev = msg_current; + } + + assert_msg_equals(tc.expect, parser.parse(tc.input, false), true); + assert_msg_equals(tc.expect, msg_accum, true); +} + static void test_msgs_oaicompat_json_conversion() { printf("[%s]\n", __func__); std::vector msgs{ @@ -2659,14 +2752,14 @@ Hey there!<|im_end|> // Test parsing tool calls assert_msg_equals(message_assist_call, common_chat_parse( - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", /* is_partial= */ false, {COMMON_CHAT_FORMAT_KIMI_K2})); // Test parsing tool calls with thinking assert_msg_equals(message_assist_call_thoughts, common_chat_parse( - "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", /* is_partial= */ false, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, @@ -2676,7 +2769,7 @@ Hey there!<|im_end|> // Test tool calls with extra content assert_msg_equals(message_assist_call_content, common_chat_parse( - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", /* is_partial= */ false, {COMMON_CHAT_FORMAT_KIMI_K2} )); @@ -2684,7 +2777,7 @@ Hey there!<|im_end|> // Test tool calls with extra content AND thinking assert_msg_equals(message_assist_call_thoughts_content, common_chat_parse( - "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", + "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", /* is_partial= */ false, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, @@ -2693,47 +2786,152 @@ Hey there!<|im_end|> // Test streaming test_parser_with_streaming(message_assist_call_thoughts_content, - "I'm\nthinking\nHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + "I'm\nthinking\nHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(message_assist_call_thoughts_unparsed, - "I'm\nthinking\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + "I'm\nthinking\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE }); }); test_parser_with_streaming(message_assist_call_thoughts_content, - "I'm\nthinking\n\n\nHello, world!\nWhat's up?\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>\n", + "I'm\nthinking\n\n\nHello, world!\nWhat's up?\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>\n", [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(message_assist_call_withopt, - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:1<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 2}<|tool_call_end|><|tool_calls_section_end|>", + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:0<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 2}<|tool_call_end|><|tool_calls_section_end|>", [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE }); }); test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": \"123456\"}"), - "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": \"123456\"}<|tool_call_end|><|tool_calls_section_end|>", + "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": \"123456\"}<|tool_call_end|><|tool_calls_section_end|>", [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": [1, 2, \"345\", 6]}"), - "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": [1, 2, \"345\", 6]}<|tool_call_end|><|tool_calls_section_end|>", + "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": [1, 2, \"345\", 6]}<|tool_call_end|><|tool_calls_section_end|>", [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": {\"12\": 34, \"5\": [67, 8], \"9\": \"10\"}}"), - "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": {\"12\": 34, \"5\": [67, 8], \"9\": \"10\"}}<|tool_call_end|><|tool_calls_section_end|>", + "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": {\"12\": 34, \"5\": [67, 8], \"9\": \"10\"}}<|tool_call_end|><|tool_calls_section_end|>", [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK }); }); + test_parser_with_streaming( + simple_assist_msg("", "", "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}"), + "<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function:0<|tool_call_argument_begin|>" + "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" + "<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); + test_parser_with_streaming( + simple_assist_msg("", "", "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}"), + "<|tool_calls_section_begin|><|tool_call_begin|>functions.web_search:0<|tool_call_argument_begin|>" + "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}" + "<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); + test_parser_with_streaming( + simple_assist_msg("", "", "read_file", "{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}"), + "<|tool_calls_section_begin|><|tool_call_begin|>functions.read_file:0<|tool_call_argument_begin|>" + "{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}" + "<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); + test_parser_with_streaming( + simple_assist_msg( + "Let me start by examining the relevant files to understand the current implementation.", "", + "read_file", + "{\"files\": [{\"path\": \"src/app/Partners.tsx\", \"line_ranges\": [\"1-100\"]}]}"), + "Let me start by examining the relevant files to understand the current implementation." + "<|tool_calls_section_begin|><|tool_call_begin|>functions.read_file:0<|tool_call_argument_begin|>" + "{\"files\":[{\"path\":\"src/app/Partners.tsx\",\"line_ranges\":[\"1-100\"]}]}" + "<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); + auto multi_tool_msg = simple_assist_msg("Let me call multiple tools.", "I'm thinking."); + multi_tool_msg.tool_calls.push_back({ "read_file", "{\"files\": [{\"path\": \"src/app/Partners.tsx\", \"line_ranges\": [\"1-100\"]}]}", "" }); + multi_tool_msg.tool_calls.push_back({ "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}", "" }); + multi_tool_msg.tool_calls.push_back({ "complex_function", "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}", "" }); + multi_tool_msg.tool_calls.push_back({ "emoji_function", "{\"message\":\"Hello! πŸ‘‹ 🌟 πŸš€ Testing emojis: πŸ˜€πŸ˜ƒπŸ˜„πŸ˜ and symbols: βˆ‘βˆβˆ†βˆ‡\"}", "" }); + test_parser_with_streaming(multi_tool_msg, + "I'm thinking.Let me call multiple tools." + "<|tool_calls_section_begin|>" + "<|tool_call_begin|>functions.read_file:0<|tool_call_argument_begin|>" + "{\"files\":[{\"path\":\"src/app/Partners.tsx\",\"line_ranges\":[\"1-100\"]}]}" + "<|tool_call_end|>" + "<|tool_call_begin|>functions.web_search:1<|tool_call_argument_begin|>" + "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}" + "<|tool_call_end|>" + "<|tool_call_begin|>functions.complex_function:2<|tool_call_argument_begin|>" + "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" + "<|tool_call_end|>" + "<|tool_call_begin|>functions.emoji_function:3<|tool_call_argument_begin|>" + "{\"message\":\"Hello! πŸ‘‹ 🌟 πŸš€ Testing emojis: πŸ˜€πŸ˜ƒπŸ˜„πŸ˜ and symbols: βˆ‘βˆβˆ†βˆ‡\"}" + "<|tool_call_end|>" + "<|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + COMMON_CHAT_FORMAT_KIMI_K2, + COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming( + simple_assist_msg("", "I'm thinking", "complex_function_in_think", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}"), + "I'm thinking<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function_in_think:0<|tool_call_argument_begin|>" + "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" + "<|tool_call_end|><|tool_calls_section_end|>", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + COMMON_CHAT_FORMAT_KIMI_K2, + COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + test_parser_with_streaming( + simple_assist_msg("Hello", "I'm thinkingI'm still thinking", "complex_function_in_think", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}"), + "I'm thinking<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function_in_think:0<|tool_call_argument_begin|>" + "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" + "<|tool_call_end|><|tool_calls_section_end|>I'm still thinkingHello", + [&](const std::string &msg) { return common_chat_parse(msg, /* is_partial= */ true, { + COMMON_CHAT_FORMAT_KIMI_K2, + COMMON_REASONING_FORMAT_DEEPSEEK + }); }); + + // Test template rendering + common_chat_templates_inputs conversation_with_tools = inputs_tools; + conversation_with_tools.messages.push_back(simple_assist_msg("Let's do it", "Think first", "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}")); + conversation_with_tools.messages.push_back({ + "tool", + "Tool response 1", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "complex_function", + /* .tool_call_id = */ "", + }); + conversation_with_tools.messages.push_back(simple_assist_msg("Continue", "Think next", "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}")); + conversation_with_tools.messages.push_back({ + "tool", + "Tool response 2", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "web_search", + /* .tool_call_id = */ "", + }); + conversation_with_tools.messages.push_back(simple_assist_msg("CC", "Think last", "read_file", "{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}")); + conversation_with_tools.messages.push_back({ + "tool", + "Tool response 3", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "read_file", + /* .tool_call_id = */ "", + }); + assert_equals(common_chat_templates_apply(tmpls.get(), conversation_with_tools).prompt, std::string("<|im_system|>tool_declare<|im_middle|>[{\"type\": \"function\", \"function\": {\"name\": \"special_function\", \"description\": \"I'm special\", \"parameters\": {\"type\": \"object\", \"properties\": {\"arg1\": {\"type\": \"integer\", \"description\": \"The arg.\"}}, \"required\": [\"arg1\"]}}}]<|im_end|><|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|><|im_user|>user<|im_middle|>Hey there!<|im_end|><|im_assistant|>assistant<|im_middle|>Think firstLet's do it<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function:0<|tool_call_argument_begin|>{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}<|tool_call_end|><|tool_calls_section_end|><|im_end|><|im_system|>complex_function<|im_middle|>## Return of functions.complex_function:0\nTool response 1<|im_end|><|im_assistant|>assistant<|im_middle|>Think nextContinue<|tool_calls_section_begin|><|tool_call_begin|>functions.web_search:1<|tool_call_argument_begin|>{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}<|tool_call_end|><|tool_calls_section_end|><|im_end|><|im_system|>web_search<|im_middle|>## Return of functions.web_search:1\nTool response 2<|im_end|><|im_assistant|>assistant<|im_middle|>Think lastCC<|tool_calls_section_begin|><|tool_call_begin|>functions.read_file:2<|tool_call_argument_begin|>{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}<|tool_call_end|><|tool_calls_section_end|><|im_end|><|im_system|>read_file<|im_middle|>## Return of functions.read_file:2\nTool response 3<|im_end|><|im_assistant|>assistant<|im_middle|>")); // Test template generation for regular content test_templates(tmpls.get(), end_tokens, message_assist, tools, @@ -2742,7 +2940,7 @@ Hey there!<|im_end|> // Test template generation for tool calls test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", /* expect_grammar_triggered= */ true, /* test_grammar_if_triggered= */ true, /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, @@ -2751,14 +2949,14 @@ Hey there!<|im_end|> // Test template generation for tools with optional parameters test_templates(tmpls.get(), end_tokens, message_assist_call_noopt, tools, - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:1<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", /* expect_grammar_triggered= */ true, /* test_grammar_if_triggered= */ true, /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, /* ignore_whitespace_differences= */ true ); test_templates(tmpls.get(), end_tokens, message_assist_call_withopt, tools, - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:1<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 2}<|tool_call_end|><|tool_calls_section_end|>", + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:0<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 2}<|tool_call_end|><|tool_calls_section_end|>", /* expect_grammar_triggered= */ true, /* test_grammar_if_triggered= */ true, /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, @@ -3301,7 +3499,95 @@ Hey there!<|im_end|> auto grammar = build_grammar(params.grammar); GGML_ASSERT(grammar && "Failed to build Qwen3-Coder grammar with union types"); } +} +static void test_template_output_peg_parsers() { + printf("[%s]\n", __func__); + + // JSON schemas + const char * invoice_schema = R"({ + "type": "object", + "properties": { + "amount": {"type": "number"}, + "date": {"type": "string"} + } + })"; + + { + // Ministral-3-14B-Reasoning-2512 + auto tmpls = read_templates("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); + + // Test basic message + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = "Hello, world!\nWhat's up?"; + t.expect = message_assist; + }); + + // Test basic message and reasoning with reasoning_format = none + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = "[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?"; + t.expect.content = "[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?"; + }); + + // Test basic message and reasoning with reasoning_format = auto + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = "[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?"; + t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + + t.expect = message_assist_thoughts; + }); + + // Test tool call + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})"; + t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + t.params.tools = {special_function_tool}; + + t.expect = message_assist_call; + }); + + // Test tool call with reasoning + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = "[THINK]I'm\nthinking[/THINK]" + R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})"; + t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + t.params.tools = {special_function_tool}; + + t.expect = message_assist_call_thoughts; + }); + + // Test parallel tool calls + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = R"([TOOL_CALLS]special_function[ARGS]{"arg1": 1})" + R"([TOOL_CALLS]special_function_with_opt[ARGS]{"arg1": 1, "arg2": 2})"; + t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + t.params.parallel_tool_calls = true; + t.params.tools = {special_function_tool, special_function_tool_with_optional_param}; + + t.expect.tool_calls = {{ + /* .name = */ "special_function", + /* .arguments = */ R"({"arg1": 1})", + /* .id = */ {}, + }, { + /* .name = */ "special_function_with_opt", + /* .arguments = */ R"({"arg1": 1, "arg2": 2})", + /* .id = */ {}, + }}; + }); + + // Test response format + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = "[THINK]I need to output the invoice details in JSON[/THINK]" + "```json\n" + R"({"amount": 123.45, "date": "2025-12-03"})" + "\n```"; + t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; + t.params.json_schema = invoice_schema; + + t.expect.reasoning_content = "I need to output the invoice details in JSON"; + t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})"; + }); + } } static void test_msg_diffs_compute() { @@ -3427,6 +3713,7 @@ int main(int argc, char ** argv) { test_msgs_oaicompat_json_conversion(); test_tools_oaicompat_json_conversion(); test_template_output_parsers(); + test_template_output_peg_parsers(); std::cout << "\n[chat] All tests passed!" << '\n'; } return 0; diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 82fae671ed..7aa7e58a5c 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -32,13 +32,66 @@ static bool test_build_grammar_fails(const std::string & grammar_str) { return grammar_fails; } +struct token_and_piece { + llama_token token; + std::string piece; +}; + +// token() encodes a 32-bit ID as 5 bytes: a 0xff marker followed by the ID in big-endian order. +static std::string token(llama_token id) { + return std::string{ + static_cast(0xff), + static_cast((id >> 24) & 0xff), + static_cast((id >> 16) & 0xff), + static_cast((id >> 8) & 0xff), + static_cast(id & 0xff) + }; +} + +// parse_tokens() parses the token encodes above and UTF-8 text. +static std::vector parse_tokens(const std::string & input) { + std::vector result; + result.reserve(input.size()); + size_t offset = 0; + while (offset < input.size()) { + try { + if (static_cast(input[offset]) == 0xff) { + if (offset + 5 > input.size()) { + throw std::runtime_error("not enough bytes for token id"); + } + uint32_t val = + (static_cast(input[offset + 1]) << 24) | + (static_cast(input[offset + 2]) << 16) | + (static_cast(input[offset + 3]) << 8) | + (static_cast(input[offset + 4])); + auto piece = "<[" + std::to_string(val) + "]>"; + result.push_back({static_cast(val), piece}); + offset += 5; + } else { + uint32_t cpt = unicode_cpt_from_utf8(input, offset); + result.push_back({0, unicode_cpt_to_utf8(cpt)}); + } + } catch (const std::invalid_argument & /*ex*/) { + // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize + ++offset; + result.push_back({0, unicode_cpt_to_utf8(0xFFFD)}); // replacement character + } + } + return result; +} + static bool match_string(const std::string & input, llama_grammar * grammar) { - const auto cpts = unicode_cpts_from_utf8(input); + const auto parsed = parse_tokens(input); auto & stacks_cur = llama_grammar_get_stacks(grammar); - for (const auto & cpt : cpts) { - llama_grammar_accept(grammar, cpt); + for (const auto & in : parsed) { + try { + llama_grammar_accept_token(*grammar, in.token, in.piece); + } catch (const std::runtime_error & /*e*/) { + // normally this shouldn't get hit because of llama_grammar_apply + return false; + } if (stacks_cur.empty()) { // no stacks means that the grammar failed to match at this point @@ -426,6 +479,30 @@ static void test_simple_grammar() { "12a45", } ); + + // Test case for a simple grammar with tokens + test_grammar( + "simple grammar with tokens", + R"""( + root ::= <[10]> content <[11]> + content ::= (!<[11]>)*)""", + // Passing strings + { + token(10) + "hello world" + token(11), + token(10) + "text with " + token(12) + " other tokens " + token(13) + " mixed in" + token(11), + token(10) + token(11), + token(10) + token(12) + token(13) + token(14) + token(15) + token(11), + token(10) + "a" + token(11), + }, + // Failing strings + { + token(10) + "missing end token", + token(10), + "missing start token" + token(11), + token(10) + token(11) + token(11), // double end token + token(11) + "wrong order" + token(10), + } + ); } static void test_complex_grammar() { @@ -487,6 +564,34 @@ static void test_complex_grammar() { "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/", } ); + + // Test case for a more complex grammar with tokens + test_grammar( + "complex grammar with tokens", + R"""( + root ::= reasoning+ content tool-call* + reasoning ::= <[10]> (!<[11]>)* <[11]> + content ::= <[20]> (!<[21]>)* <[21]> + tool-call ::= <[12]> name <[13]> args <[14]> + name ::= (!<[13]>)+ + args ::= (!<[14]>)*)""", + // Passing strings + { + token(10) + "I am thinking" + token(11) + token(20) + "hello world!" + token(21) + token(12) + "search" + token(13) + "query=test" + token(14), + token(10) + "reasoning 1" + token(11) + token(10) + "reasoning 2" + token(11) + token(20) + token(21) + token(12) + "tool" + token(13) + token(14), + token(10) + token(11) + token(20) + "content" + token(21), + token(10) + "think" + token(12) + " nested" + token(11) + token(20) + token(10) + "more content" + token(21) + token(12) + "fn" + token(13) + "x=1,y=2" + token(14) + token(12) + "fn2" + token(13) + token(14), + token(10) + "reasoning" + token(11) + token(10) + "more" + token(11) + token(10) + "even more" + token(11) + token(20) + "text" + token(21) + token(12) + "a" + token(13) + "b" + token(14) + token(12) + "c" + token(13) + "d" + token(14), + }, + // Failing strings + { + token(20) + "content only" + token(21), + token(10) + "no closing reasoning", + token(10) + token(11) + token(20) + "no closing content", + token(10) + token(11) + token(20) + token(21) + token(12) + "incomplete tool", + token(10) + token(11) + token(11) + token(20) + token(21), + } + ); } static void test_special_chars() { diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 67821a2d5c..03ae78ff73 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -515,5 +515,19 @@ int main() {LLAMA_GRETYPE_END, 0}, }); + // <[1000]> = "" + // <[1001]> = "" + verify_parsing(R"""( + root ::= <[1000]> !<[1001]> <[1001]> + )""", { + {"root", 0} + }, { + // root (index 0) + {LLAMA_GRETYPE_TOKEN, 1000}, + {LLAMA_GRETYPE_TOKEN_NOT, 1001}, + {LLAMA_GRETYPE_TOKEN, 1001}, + {LLAMA_GRETYPE_END, 0}, + }); + return 0; } diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index cc198f3e3c..fd45d5ada8 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -202,7 +202,7 @@ int main() uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point cp[0] = 37 + i; cp[1] = 0; - next_candidates[i] = {i, cp, {}}; + next_candidates[i] = {i, cp, {}, 0}; } std::vector>> expected_reject = { diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index f640ae2a6e..13ab7c78f4 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -16,6 +16,7 @@ add_library(mtmd set_target_properties(mtmd PROPERTIES VERSION ${LLAMA_INSTALL_VERSION} SOVERSION 0 + MACHO_CURRENT_VERSION 0 # keep macOS linker from seeing oversized version number ) target_link_libraries (mtmd PUBLIC ggml llama) diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md new file mode 100644 index 0000000000..fbcd6bc1f9 --- /dev/null +++ b/tools/server/README-dev.md @@ -0,0 +1,177 @@ +# llama-server Development Documentation + +This document provides an in-depth technical overview of `llama-server`, intended for maintainers and contributors. + +If you are an end user consuming `llama-server` as a product, please refer to the main [README](./README.md) instead. + +## Backend + +### Overview + +The server supports two primary operating modes: + +- **Inference mode**: The default mode for performing inference with a single loaded GGUF model. +- **Router mode**: Enables management of multiple inference server instances behind a single API endpoint. Requests are automatically routed to the appropriate backend instance based on the requested model. + +The core architecture consists of the following components: + +- `server_context`: Holds the primary inference state, including the main `llama_context` and all active slots. +- `server_slot`: An abstraction over a single β€œsequence” in llama.cpp, responsible for managing individual parallel inference requests. +- `server_routes`: Middleware layer between `server_context` and the HTTP interface; handles JSON parsing/formatting and request routing logic. +- `server_http_context`: Implements the HTTP server using `cpp-httplib`. +- `server_queue`: Thread-safe queue used by HTTP workers to submit new tasks to `server_context`. +- `server_response`: Thread-safe queue used by `server_context` to return results to HTTP workers. +- `server_response_reader`: Higher-level wrapper around the two queues above for cleaner code. +- `server_task`: Unit of work pushed into `server_queue`. +- `server_task_result`: Unit of result pushed into `server_response`. +- `server_tokens`: Unified representation of token sequences (supports both text and multimodal tokens); used by `server_task` and `server_slot`. +- `server_prompt_checkpoint`: For recurrent (e.g., RWKV) and SWA models, stores snapshots of KV cache state. Enables reuse when subsequent requests share the same prompt prefix, saving redundant computation. +- `server_models`: Standalone component for managing multiple backend instances (used in router mode). It is completely independent of `server_context`. + +```mermaid +graph TD + API_User <--> server_http_context + server_http_context <-- router mode --> server_models + server_http_context <-- inference mode --> server_routes + server_routes -- server_task --> server_queue + subgraph server_context + server_queue --> server_slot + server_slot -- server_task_result --> server_response + server_slot[multiple server_slot] + end + server_response --> server_routes +``` + +### Batching + +The server context maintains a single batch shared across all slots. When `update_slots()` is invoked, the system iterates through all active slots to populate this batch. For each slot, either a generated token from the previous decoding step or available prompt tokens are added to the batch. + +Batching constraints apply: slots can only be batched together if they share compatible configurations. For instance, slots using a specific LoRA adapter can be batched with each other, but not with slots using a different LoRA adapter or no adapter at all. + +Once the batch reaches capacity or all slots have been processed, `llama_decode` is called to execute the inference. This operation represents the primary computational bottleneck in `update_slots()`. + +Following decoding, the system either retrieves embeddings or samples the next token using `common_sampler_sample`. If a slot has remaining prompt tokens to process, it yields until the next `update_slots()` iteration. + +### Thread Management + +`server_context` runs on a dedicated single thread. Because it is single-threaded, heavy post-processing (especially after token generation) should be avoided, as it directly impacts multi-sequence throughput. + +Each incoming HTTP request is handled by its own thread managed by the HTTP library. The following operations are performed in HTTP worker threads: + +- JSON request parsing +- Chat template application +- Tokenization +- Conversion of `server_task_result` into final JSON response +- Error formatting into JSON +- Tracking of partial/incremental responses (e.g., streaming tool calls or reasoning steps) + +**Best practices to follow:** + +- All JSON formatting and chat template logic must stay in the HTTP layer. +- Avoid passing raw JSON between the HTTP layer and `server_slot`. Instead, parse everything into native C++ types as early as possible. + +### Example trace of a request + +Here is an example trace of an API request for text completion: + +- A request arrives at the HTTP layer. +- The request is routed to the corresponding handler inside `server_routes`. In this case, `handle_completions_impl` is invoked. +- The handler parses the input request, constructs a new `server_task`, and passes it to `server_res_generator`. +- `server_res_generator` creates a new `task_result_state` for each task: + - `task_result_state` stays in the HTTP layer, responsible for keeping track of the current state of the response (e.g., parsing tool calls or thinking messages). + - `server_task` is moved into `server_queue` inside `server_context`. +- `server_context` launches the task by moving it into an available slot (see `launch_slot_with_task()`). +- `update_slot()` processes the task as described in the "Batching" section above. +- Results may be sent using `send_partial_response` or `send_final_response`, which creates a new `server_task_result` and pushes it to the response queue. +- At the same time, `server_res_generator` listens to the response queue and retrieves this response. +- As the response is stateless, `server_res_generator` calls `response->update()` to update the response with the current state. +- `server_res_generator` then calls `response->to_json()` and passes the response to the HTTP layer. + +### Testing + +`llama-server` includes an automated test suite based on `pytest`. + +The framework automatically starts a `llama-server` instance, sends requests, and validates responses. + +For detailed instructions, see the [test documentation](./tests/README.md). + +### Notable Related PRs + +- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443 +- Parallel decoding support: https://github.com/ggml-org/llama.cpp/pull/3228 +- Refactor introducing `server_queue` and `server_response`: https://github.com/ggml-org/llama.cpp/pull/5065 +- Reranking endpoint: https://github.com/ggml-org/llama.cpp/pull/9510 +- Multimodal model support (`libmtmd`): https://github.com/ggml-org/llama.cpp/pull/12898 +- Unified KV cache handling: https://github.com/ggml-org/llama.cpp/pull/16736 +- Separation of HTTP logic into dedicated files: https://github.com/ggml-org/llama.cpp/pull/17216 +- Large-scale code base split into smaller files: https://github.com/ggml-org/llama.cpp/pull/17362 +- Introduction of router mode: https://github.com/ggml-org/llama.cpp/pull/17470 +- Speculative decoding: https://github.com/ggml-org/llama.cpp/pull/17808 and rework in https://github.com/ggml-org/llama.cpp/pull/17808 + + + + +## Web UI + +The project includes a web-based user interface for interacting with `llama-server`. It supports both single-model (`MODEL` mode) and multi-model (`ROUTER` mode) operation. + +The SvelteKit-based Web UI is introduced in this PR: https://github.com/ggml-org/llama.cpp/pull/14839 + +### Features + +- **Chat interface** with streaming responses +- **Multi-model support** (ROUTER mode) - switch between models, auto-load on selection +- **Modality validation** - ensures selected model supports conversation's attachments (images, audio) +- **Conversation management** - branching, regeneration, editing with history preservation +- **Attachment support** - images, audio, PDFs (with vision/text fallback) +- **Configurable parameters** - temperature, top_p, etc. synced with server defaults +- **Dark/light theme** + +### Tech Stack + +- **SvelteKit** - frontend framework with Svelte 5 runes for reactive state +- **TailwindCSS** + **shadcn-svelte** - styling and UI components +- **Vite** - build tooling +- **IndexedDB** (Dexie) - local storage for conversations +- **LocalStorage** - user settings persistence + +### Architecture + +The WebUI follows a layered architecture: + +``` +Routes β†’ Components β†’ Hooks β†’ Stores β†’ Services β†’ Storage/API +``` + +- **Stores** - reactive state management (`chatStore`, `conversationsStore`, `modelsStore`, `serverStore`, `settingsStore`) +- **Services** - stateless API/database communication (`ChatService`, `ModelsService`, `PropsService`, `DatabaseService`) +- **Hooks** - reusable logic (`useModelChangeValidation`, `useProcessingState`) + +For detailed architecture diagrams, see [`tools/server/webui/docs/`](webui/docs/): + +- `high-level-architecture.mmd` - full architecture with all modules +- `high-level-architecture-simplified.mmd` - simplified overview +- `data-flow-simplified-model-mode.mmd` - data flow for single-model mode +- `data-flow-simplified-router-mode.mmd` - data flow for multi-model mode +- `flows/*.mmd` - detailed per-domain flows (chat, conversations, models, etc.) + +### Development + +```sh +# make sure you have Node.js installed +cd tools/server/webui +npm i + +# run dev server (with hot reload) +npm run dev + +# run tests +npm run test + +# build production bundle +npm run build +``` + +After `public/index.html.gz` has been generated, rebuild `llama-server` as described in the [build](#build) section to include the updated UI. + +**Note:** The Vite dev server automatically proxies API requests to `http://localhost:8080`. Make sure `llama-server` is running on that port during development. diff --git a/tools/server/README.md b/tools/server/README.md index bf274db79d..f98fb44c7b 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -2,7 +2,7 @@ Fast, lightweight, pure C/C++ HTTP server based on [httplib](https://github.com/yhirose/cpp-httplib), [nlohmann::json](https://github.com/nlohmann/json) and **llama.cpp**. -Set of LLM REST APIs and a simple web front end to interact with llama.cpp. +Set of LLM REST APIs and a web UI to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU @@ -19,7 +19,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. * Speculative decoding * Easy-to-use web UI -The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggml-org/llama.cpp/issues/4216). +For the ful list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291) ## Usage @@ -289,69 +289,6 @@ For more details, please refer to [multimodal documentation](../../docs/multimod cmake --build build --config Release -t llama-server ``` -## Web UI - -The project includes a web-based user interface for interacting with `llama-server`. It supports both single-model (`MODEL` mode) and multi-model (`ROUTER` mode) operation. - -### Features - -- **Chat interface** with streaming responses -- **Multi-model support** (ROUTER mode) - switch between models, auto-load on selection -- **Modality validation** - ensures selected model supports conversation's attachments (images, audio) -- **Conversation management** - branching, regeneration, editing with history preservation -- **Attachment support** - images, audio, PDFs (with vision/text fallback) -- **Configurable parameters** - temperature, top_p, etc. synced with server defaults -- **Dark/light theme** - -### Tech Stack - -- **SvelteKit** - frontend framework with Svelte 5 runes for reactive state -- **TailwindCSS** + **shadcn-svelte** - styling and UI components -- **Vite** - build tooling -- **IndexedDB** (Dexie) - local storage for conversations -- **LocalStorage** - user settings persistence - -### Architecture - -The WebUI follows a layered architecture: - -``` -Routes β†’ Components β†’ Hooks β†’ Stores β†’ Services β†’ Storage/API -``` - -- **Stores** - reactive state management (`chatStore`, `conversationsStore`, `modelsStore`, `serverStore`, `settingsStore`) -- **Services** - stateless API/database communication (`ChatService`, `ModelsService`, `PropsService`, `DatabaseService`) -- **Hooks** - reusable logic (`useModelChangeValidation`, `useProcessingState`) - -For detailed architecture diagrams, see [`tools/server/webui/docs/`](webui/docs/): - -- `high-level-architecture.mmd` - full architecture with all modules -- `high-level-architecture-simplified.mmd` - simplified overview -- `data-flow-simplified-model-mode.mmd` - data flow for single-model mode -- `data-flow-simplified-router-mode.mmd` - data flow for multi-model mode -- `flows/*.mmd` - detailed per-domain flows (chat, conversations, models, etc.) - -### Development - -```sh -# make sure you have Node.js installed -cd tools/server/webui -npm i - -# run dev server (with hot reload) -npm run dev - -# run tests -npm run test - -# build production bundle -npm run build -``` - -After `public/index.html.gz` has been generated, rebuild `llama-server` as described in the [build](#build) section to include the updated UI. - -**Note:** The Vite dev server automatically proxies API requests to `http://localhost:8080`. Make sure `llama-server` is running on that port during development. - ## Quick Start To get started right away, run the following command, making sure to use the correct path for the model you have: @@ -380,7 +317,7 @@ docker run -p 8080:8080 -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:se docker run -p 8080:8080 -v /path/to/models:/models --gpus all ghcr.io/ggml-org/llama.cpp:server-cuda -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 --n-gpu-layers 99 ``` -## Testing with CURL +## Using with CURL Using [curl](https://curl.se/). On Windows, `curl.exe` should be available in the base OS. @@ -391,46 +328,6 @@ curl --request POST \ --data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}' ``` -## Advanced testing - -We implemented a [server test framework](./tests/README.md) using human-readable scenario. - -*Before submitting an issue, please try to reproduce it with this format.* - -## Node JS Test - -You need to have [Node.js](https://nodejs.org/en) installed. - -```bash -mkdir llama-client -cd llama-client -``` - -Create an index.js file and put this inside: - -```javascript -const prompt = "Building a website can be done in 10 simple steps:" - -async function test() { - let response = await fetch("http://127.0.0.1:8080/completion", { - method: "POST", - body: JSON.stringify({ - prompt, - n_predict: 64, - }) - }) - console.log((await response.json()).content) -} - -test() -``` - -And run it: - -```bash -node index.js -``` - ## API Endpoints ### GET `/health`: Returns health check result @@ -495,6 +392,8 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re `n_cmpl`: Number of completions to generate from the current prompt. If input has multiple prompts, the output will have N prompts times `n_cmpl` entries. +`n_cache_reuse`: Min chunk size to attempt reusing from the cache via KV shifting. For more info, see `--cache-reuse` arg. Default: `0`, which is disabled. + `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`. `stop`: Specify a JSON array of stopping strings. @@ -1636,6 +1535,22 @@ Response: } ``` +## API errors + +`llama-server` returns errors in the same format as OAI: https://github.com/openai/openai-openapi + +Example of an error: + +```json +{ + "error": { + "code": 401, + "message": "Invalid API Key", + "type": "authentication_error" + } +} +``` + ## More examples ### Interactive mode @@ -1655,26 +1570,6 @@ Run with bash: bash chat.sh ``` -### OAI-like API - -The HTTP `llama-server` supports an OAI-like API: https://github.com/openai/openai-openapi - -### API errors - -`llama-server` returns errors in the same format as OAI: https://github.com/openai/openai-openapi - -Example of an error: - -```json -{ - "error": { - "code": 401, - "message": "Invalid API Key", - "type": "authentication_error" - } -} -``` - Apart from error types supported by OAI, we also have custom types that are specific to functionalities of llama.cpp: **When /metrics or /slots endpoint is disabled** diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index b403864e0e..ab6b3aa7ce 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -972,6 +972,9 @@ json oaicompat_chat_params_parse( inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); inputs.reasoning_format = opt.reasoning_format; + if (body.contains("reasoning_format")) { + inputs.reasoning_format = common_reasoning_format_from_name(body.at("reasoning_format").get()); + } inputs.enable_thinking = opt.enable_thinking; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { if (body.contains("grammar")) { diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 0c4d84ffa0..0629bb5edd 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -18,11 +18,13 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + " using json = nlohmann::ordered_json; #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__) #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) #define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) #define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__) #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 12a4e94e5d..4578f8d7a9 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -102,6 +102,11 @@ struct server_slot { std::string generated_text; llama_tokens generated_tokens; + // idx of draft tokens in the main batch + // non-empty if we went to evaluate draft tokens + // ref: https://github.com/ggml-org/llama.cpp/pull/17808 + std::vector i_batch_dft; + std::vector generated_token_probs; bool has_next_token = true; @@ -150,7 +155,8 @@ struct server_slot { struct common_sampler * smpl = nullptr; - llama_token sampled; + llama_token sampled; // in speculative mode, this is the last accepted token + llama_tokens drafted; // stats size_t n_sent_text = 0; // number of sent text character @@ -180,6 +186,8 @@ struct server_slot { stopping_word = ""; n_sent_text = 0; + drafted.clear(); + i_batch_dft.clear(); generated_tokens.clear(); generated_token_probs.clear(); json_schema = json(); @@ -255,6 +263,31 @@ struct server_slot { generated_token_probs.push_back(token); } + int get_n_draft_max() const { + if (!can_speculate()) { + return 0; + } + + // determine the max draft that fits the current slot state + int n_draft_max = task->params.speculative.n_max; + + // note: slot.prompt is not yet expanded with the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2); + + if (n_remaining > 0) { + n_draft_max = std::min(n_draft_max, n_remaining - 1); + } + + SLT_DBG(*this, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < task->params.speculative.n_min) { + SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min); + n_draft_max = 0; + } + return n_draft_max; + } + // note: a slot can also be either a parent or a child bool is_parent() const { return is_processing() && task->n_children > 0; @@ -353,8 +386,7 @@ struct server_slot { if (n_draft_total > 0) { const float draft_ratio = (float) n_draft_accepted / n_draft_total; - SLT_INF(*this, - "\n" + SLT_CNT(*this, "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", draft_ratio, n_draft_accepted, n_draft_total ); @@ -1774,14 +1806,57 @@ struct server_context_impl { continue; } - slot.i_batch = batch.n_tokens; + // generate draft tokens in speculative decoding mode + // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] + // perform the speculative drafting for all sequences at the same time in a single batch + int n_draft_max = slot.get_n_draft_max(); + if (n_draft_max > 0) { + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + params_spec.p_min = slot.task->params.speculative.p_min; + const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); - slot.prompt.tokens.push_back(slot.sampled); + // add the sampled token to the batch + slot.i_batch_dft.push_back(batch.n_tokens); + common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); + slot.prompt.tokens.push_back(slot.sampled); - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); + if (slot.task->params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); + // fallback to normal decoding + slot.i_batch = slot.i_batch_dft[0]; + slot.drafted.clear(); + slot.i_batch_dft.clear(); + } else { + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // add all drafted tokens to the batch + for (size_t i = 0; i < draft.size(); i++) { + slot.i_batch_dft.push_back(batch.n_tokens); + common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true); + slot.prompt.tokens.push_back(draft[i]); + } + slot.drafted = std::move(draft); + } + } else { + // no speculative decoding + slot.i_batch = batch.n_tokens; + + common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); + + slot.prompt.tokens.push_back(slot.sampled); + + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); + } } // process in chunks of params.n_batch @@ -1880,8 +1955,18 @@ struct server_context_impl { n_past = std::min(n_past, slot.alora_invocation_start - 1); } + const auto n_cache_reuse = slot.task->params.n_cache_reuse; + + const bool can_cache_reuse = + llama_memory_can_shift(llama_get_memory(ctx)) && + !slot.prompt.tokens.has_mtmd; + + if (!can_cache_reuse && n_cache_reuse > 0) { + SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse); + } + // reuse chunks from the cached prompt by shifting their KV cache in the new position - if (params_base.n_cache_reuse > 0) { + if (can_cache_reuse && n_cache_reuse > 0) { GGML_ASSERT(!slot.prompt.tokens.has_mtmd); size_t head_c = n_past; // cache @@ -1892,7 +1977,7 @@ struct server_context_impl { GGML_ABORT("not supported by multimodal"); } - SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); + SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past); while (head_c < slot.prompt.tokens.size() && head_p < input_tokens.size()) { @@ -1901,11 +1986,10 @@ struct server_context_impl { while (head_c + n_match < slot.prompt.tokens.size() && head_p + n_match < input_tokens.size() && slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { - n_match++; } - if (n_match >= (size_t) params_base.n_cache_reuse) { + if (n_match >= (size_t) n_cache_reuse) { SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); //for (size_t i = head_p; i < head_p + n_match; i++) { // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); @@ -2336,6 +2420,10 @@ struct server_context_impl { // on successful decode, restore the original batch size n_batch = llama_n_batch(ctx); + // technically, measuring the time here excludes the sampling time for the last batch + // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok + const int64_t t_current = ggml_time_us(); + for (auto & slot : slots) { // may need to copy state to other slots if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) { @@ -2390,6 +2478,10 @@ struct server_context_impl { continue; // continue loop of slots } + if (slot.i_batch_dft.size() > 0) { + continue; // sample using speculative decoding + } + const int tok_idx = slot.i_batch - i; llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); @@ -2400,8 +2492,6 @@ struct server_context_impl { slot.n_decoded += 1; - const int64_t t_current = ggml_time_us(); - if (slot.n_decoded == 1) { slot.t_start_generation = t_current; slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; @@ -2430,84 +2520,32 @@ struct server_context_impl { } } - // do speculative decoding - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch + // speculative decoding - main model sample and accept for (auto & slot : slots) { - if (!slot.is_processing() || !slot.can_speculate()) { + if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) { continue; } - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - // determine the max draft that fits the current slot state - int n_draft_max = slot.task->params.speculative.n_max; - - // note: slot.prompt is not yet expanded with the `id` token sampled above - // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2); - - if (slot.n_remaining > 0) { - n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); - } - - SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - - if (n_draft_max < slot.task->params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); - - continue; - } - - llama_token id = slot.sampled; - - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; - params_spec.p_min = slot.task->params.speculative.p_min; - - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); - - // ignore small drafts - if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - - continue; - } - - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // construct the speculation batch - common_batch_clear(slot.batch_spec); - common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true); - - for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true); - } - - SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); - - llama_decode(ctx, slot.batch_spec); + size_t n_draft = slot.drafted.size(); // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted); + slot.i_batch_dft.clear(); + slot.drafted.clear(); slot.n_decoded += ids.size(); + slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; + // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - slot.prompt.tokens.push_back(id); + // rollback to the state before sampling the draft tokens + slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); + + // add accepted tokens to the prompt slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + slot.sampled = ids.back(); // last accepted token llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); @@ -2530,7 +2568,7 @@ struct server_context_impl { } } - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens()); } } @@ -2551,6 +2589,10 @@ struct server_context_impl { int get_slot_n_ctx() { return slots.back().n_ctx; } + + server_response_reader get_response_reader() { + return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS); + } }; // @@ -2580,8 +2622,8 @@ llama_context * server_context::get_llama_context() const { return impl->ctx; } -std::pair server_context::get_queues() { - return { impl->queue_tasks, impl->queue_results }; +server_response_reader server_context::get_response_reader() { + return impl->get_response_reader(); } @@ -2590,7 +2632,7 @@ std::pair server_context::get_queues() { struct server_res_generator : server_http_res { server_response_reader rd; server_res_generator(server_context_impl & ctx_server) - : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {} + : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {} void ok(const json & response_data) { status = 200; data = safe_json_to_str(response_data); @@ -2623,9 +2665,6 @@ static std::unique_ptr handle_completions_impl( try { std::vector tasks; - // tracking generation state and partial tool calls - std::vector states; - const auto & prompt = data.at("prompt"); // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); @@ -2641,7 +2680,6 @@ static std::unique_ptr handle_completions_impl( inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); } tasks.reserve(inputs.size()); - states.reserve(inputs.size()); int idx = 0; for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); @@ -2660,7 +2698,6 @@ static std::unique_ptr handle_completions_impl( task.params.res_type = res_type; task.params.oaicompat_cmpl_id = completion_id; task.params.oaicompat_model = ctx_server.model_name; - states.push_back(task.params.oaicompat_chat_syntax); if (task.params.n_cmpl > 1) { task.n_children = task.params.n_cmpl - 1; @@ -2669,7 +2706,6 @@ static std::unique_ptr handle_completions_impl( task.id, ctx_server.queue_tasks.get_new_id(), idx++); - states.push_back(child.params.oaicompat_chat_syntax); tasks.push_back(std::move(child)); } } @@ -2677,7 +2713,6 @@ static std::unique_ptr handle_completions_impl( tasks.push_back(std::move(task)); } - rd.set_states(std::move(states)); rd.post_tasks(std::move(tasks)); } catch (const std::exception & e) { res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); @@ -3407,7 +3442,7 @@ void server_routes::init_routes() { // create and queue the task json responses = json::array(); - server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + server_response_reader rd = ctx_server.get_response_reader(); { std::vector tasks; tasks.reserve(documents.size()); @@ -3667,7 +3702,7 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons // create and queue the task json responses = json::array(); - server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + server_response_reader rd = ctx_server.get_response_reader(); { std::vector tasks; for (size_t i = 0; i < tokenized_prompts.size(); i++) { diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 05b4afaeeb..eaa1380877 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -31,9 +31,8 @@ struct server_context { // get the underlaying llama_context llama_context * get_llama_context() const; - // get the underlaying queue_tasks and queue_results - // used by CLI application - std::pair get_queues(); + // get a new response reader, used by CLI application + server_response_reader get_response_reader(); }; diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 10196128db..3cceb2bbe2 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -271,12 +271,21 @@ void server_response::terminate() { // server_response_reader // -void server_response_reader::set_states(std::vector && states) { - this->states = std::move(states); +void server_response_reader::post_task(server_task && task) { + GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader"); + id_tasks.insert(task.id); + states.push_back(task.create_state()); + queue_results.add_waiting_task_id(task.id); + queue_tasks.post(std::move(task)); } void server_response_reader::post_tasks(std::vector && tasks) { + GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader"); id_tasks = server_task::get_list_id(tasks); + states.reserve(tasks.size()); + for (size_t i = 0; i < tasks.size(); i++) { + states.push_back(tasks[i].create_state()); + } queue_results.add_waiting_tasks(tasks); queue_tasks.post(std::move(tasks)); } diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index a5c3179d8c..726eadf4ef 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -129,13 +129,13 @@ struct server_response_reader { std::vector states; // should_stop function will be called each polling_interval_seconds - server_response_reader(std::pair server_queues, int polling_interval_seconds) - : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {} + server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds) + : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {} ~server_response_reader() { stop(); } - void set_states(std::vector && states); + void post_task(server_task && tasks); void post_tasks(std::vector && tasks); bool has_next() const; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index c401f47a78..360826062b 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -155,11 +155,12 @@ task_params server_task::params_from_json_cmpl( // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) task_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; - defaults.n_predict = params_base.n_predict; - defaults.antiprompt = params_base.antiprompt; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; + defaults.n_cache_reuse = params_base.n_cache_reuse; + defaults.antiprompt = params_base.antiprompt; // enabling this will output extra debug information in the HTTP responses from the server params.verbose = params_base.verbosity > 9; @@ -176,6 +177,7 @@ task_params server_task::params_from_json_cmpl( params.n_keep = json_value(data, "n_keep", defaults.n_keep); params.n_discard = json_value(data, "n_discard", defaults.n_discard); params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1)); + params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse); //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); params.response_fields = json_value(data, "response_fields", std::vector()); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 4e4840fc83..9011ff944b 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -55,6 +55,8 @@ struct task_params { int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters int32_t n_cmpl = 1; // number of completions to generate from this prompt + int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) + int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit @@ -62,18 +64,19 @@ struct task_params { std::vector antiprompt; std::vector response_fields; - bool timings_per_token = false; + + bool timings_per_token = false; bool post_sampling_probs = false; struct common_params_sampling sampling; struct common_params_speculative speculative; // response formatting - bool verbose = false; - task_response_type res_type = TASK_RESPONSE_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; + bool verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_syntax oaicompat_chat_syntax; // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) @@ -82,6 +85,25 @@ struct task_params { json to_json(bool only_metrics = false) const; }; +// struct for tracking the state of a task (e.g., for streaming) +struct task_result_state { + // tracking diffs for partial tool calls + std::vector diffs; + common_chat_syntax oaicompat_chat_syntax; + common_chat_msg chat_msg; + std::string generated_text; // append new chunks of generated text here + std::vector generated_tool_call_ids; + + task_result_state(const common_chat_syntax & oaicompat_chat_syntax) + : oaicompat_chat_syntax(oaicompat_chat_syntax) {} + + // parse partial tool calls and update the internal state + common_chat_msg update_chat_msg( + const std::string & text_added, + bool is_partial, + std::vector & diffs); +}; + struct server_task { int id = -1; // to be filled by server_queue int index = -1; // used when there are multiple prompts (batch request) @@ -146,6 +168,12 @@ struct server_task { copy.tokens = tokens.clone(); return copy; } + + // the task will be moved into queue, then onto slots + // however, the state must be kept by caller (e.g., HTTP thread) + task_result_state create_state() const { + return task_result_state(params.oaicompat_chat_syntax); + } }; struct result_timings { @@ -177,25 +205,6 @@ struct result_prompt_progress { json to_json() const; }; -// struct for tracking the state of a task (e.g., for streaming) -struct task_result_state { - // tracking diffs for partial tool calls - std::vector diffs; - common_chat_syntax oaicompat_chat_syntax; - common_chat_msg chat_msg; - std::string generated_text; // append new chunks of generated text here - std::vector generated_tool_call_ids; - - task_result_state(const common_chat_syntax & oaicompat_chat_syntax) - : oaicompat_chat_syntax(oaicompat_chat_syntax) {} - - // parse partial tool calls and update the internal state - common_chat_msg update_chat_msg( - const std::string & text_added, - bool is_partial, - std::vector & diffs); -}; - struct server_task_result { int id = -1; int id_slot = -1;