diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 43553ac13b..1f226d6ea8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1063,7 +1063,17 @@ jobs: run: | git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 - - name: Install + - name: Cache ROCm Installation + id: cache-rocm + uses: actions/cache@v4 + with: + path: C:\Program Files\AMD\ROCm + key: rocm-6.1-${{ runner.os }}-v1 + restore-keys: | + rocm-6.1-${{ runner.os }}- + + - name: Install ROCm + if: steps.cache-rocm.outputs.cache-hit != 'true' id: depends run: | $ErrorActionPreference = "Stop" @@ -1071,13 +1081,28 @@ jobs: Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" write-host "Installing AMD HIP SDK" $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru - $proc.WaitForExit(600000) + $completed = $proc.WaitForExit(600000) + if (-not $completed) { + Write-Error "ROCm installation timed out after 10 minutes. Killing the process" + $proc.Kill() + exit 1 + } + if ($proc.ExitCode -ne 0) { + Write-Error "ROCm installation failed with exit code $($proc.ExitCode)" + exit 1 + } write-host "Completed AMD HIP SDK installation" - name: Verify ROCm id: verify run: | - & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version + # Find and test ROCm installation + $clangPath = Get-ChildItem 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | Select-Object -First 1 + if (-not $clangPath) { + Write-Error "ROCm installation not found" + exit 1 + } + & $clangPath.FullName --version - name: Install ccache uses: ggml-org/ccache-action@v1.2.16 diff --git a/.github/workflows/close-issue.yml b/.github/workflows/close-issue.yml index 19e7854745..cbfc4990db 100644 --- a/.github/workflows/close-issue.yml +++ b/.github/workflows/close-issue.yml @@ -17,7 +17,7 @@ jobs: steps: - uses: actions/stale@v5 with: - exempt-issue-labels: "refactoring,help wanted,good first issue,research,bug,roadmap" + exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap" days-before-issue-stale: 30 days-before-issue-close: 14 stale-issue-label: "stale" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5367637e42..701811eeb2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -544,13 +544,23 @@ jobs: run: | git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 + - name: Cache ROCm Installation + id: cache-rocm + uses: actions/cache@v4 + with: + path: C:\Program Files\AMD\ROCm + key: rocm-6.1-${{ runner.os }}-v1 + restore-keys: | + rocm-6.1-${{ runner.os }}- + - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: key: windows-latest-cmake-hip-${{ matrix.name }}-x64 evict-old-files: 1d - - name: Install + - name: Install ROCm + if: steps.cache-rocm.outputs.cache-hit != 'true' id: depends run: | $ErrorActionPreference = "Stop" @@ -558,13 +568,28 @@ jobs: Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" write-host "Installing AMD HIP SDK" $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru - $proc.WaitForExit(600000) + $completed = $proc.WaitForExit(600000) + if (-not $completed) { + Write-Error "ROCm installation timed out after 10 minutes. Killing the process" + $proc.Kill() + exit 1 + } + if ($proc.ExitCode -ne 0) { + Write-Error "ROCm installation failed with exit code $($proc.ExitCode)" + exit 1 + } write-host "Completed AMD HIP SDK installation" - name: Verify ROCm id: verify run: | - & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version + # Find and test ROCm installation + $clangPath = Get-ChildItem 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | Select-Object -First 1 + if (-not $clangPath) { + Write-Error "ROCm installation not found" + exit 1 + } + & $clangPath.FullName --version - name: Build id: cmake_build diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e68ff92445..6db74a0cf2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,6 +16,9 @@ - Use the following format for the squashed commit title: ` : (#)`. For example: `utils : fix typo in utils.py (#1234)` - Optionally pick a `` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules - Consider adding yourself to [CODEOWNERS](CODEOWNERS) +- Let authors, who are also collaborators, merge their own PRs +- When merging a PR by a contributor, make sure you have a good understanding of the changes +- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you) # Coding guidelines diff --git a/common/arg.cpp b/common/arg.cpp index 7507c81155..406fbc2f06 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1263,6 +1263,18 @@ static std::string list_builtin_chat_templates() { return msg.str(); } +static bool is_truthy(const std::string & value) { + return value == "on" || value == "enabled" || value == "1"; +} + +static bool is_falsey(const std::string & value) { + return value == "off" || value == "disabled" || value == "0"; +} + +static bool is_autoy(const std::string & value) { + return value == "auto" || value == "-1"; +} + common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { // load dynamic backends ggml_backend_load_all(); @@ -1544,21 +1556,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.n_chunks = value; } ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL})); - add_opt(common_arg( - {"-fa", "--flash-attn"}, "FA", - string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)), - [](common_params & params, const std::string & value) { - if (value == "on" || value == "enabled" || value == "1") { - params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; - } else if (value == "off" || value == "disabled" || value == "0") { - params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; - } else if (value == "auto" || value == "-1") { - params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; - } else { - throw std::runtime_error(string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str())); - } - } - ).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg({ "-fa", "--flash-attn" }, "[on|off|auto]", + string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", + llama_flash_attn_type_name(params.flash_attn_type)), + [](common_params & params, const std::string & value) { + if (is_truthy(value)) { + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + } else if (is_falsey(value)) { + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + } else if (is_autoy(value)) { + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; + } else { + throw std::runtime_error( + string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str())); + } + }).set_env("LLAMA_ARG_FLASH_ATTN")); add_opt(common_arg( {"-p", "--prompt"}, "PROMPT", "prompt to start generation with; for system message, use -sys", @@ -3134,13 +3146,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex common_log_set_file(common_log_main(), value.c_str()); } )); - add_opt(common_arg( - {"--log-colors"}, - "Enable colored logging", - [](common_params &) { - common_log_set_colors(common_log_main(), true); - } - ).set_env("LLAMA_LOG_COLORS")); + add_opt(common_arg({ "--log-colors" }, "[on|off|auto]", + "Set colored logging ('on', 'off', or 'auto', default: 'auto')\n" + "'auto' enables colors when output is to a terminal", + [](common_params &, const std::string & value) { + if (is_truthy(value)) { + common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED); + } else if (is_falsey(value)) { + common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED); + } else if (is_autoy(value)) { + common_log_set_colors(common_log_main(), LOG_COLORS_AUTO); + } else { + throw std::invalid_argument( + string_format("error: unkown value for --log-colors: '%s'\n", value.c_str())); + } + }).set_env("LLAMA_LOG_COLORS")); add_opt(common_arg( {"-v", "--verbose", "--log-verbose"}, "Set verbosity level to infinity (i.e. log all messages, useful for debugging)", diff --git a/common/chat.cpp b/common/chat.cpp index 955c42852a..4707c4fef4 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -163,6 +163,19 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin throw std::runtime_error("Invalid tool_choice: " + tool_choice); } +bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) { + common_chat_templates_inputs dummy_inputs; + common_chat_msg msg; + msg.role = "user"; + msg.content = "test"; + dummy_inputs.messages = {msg}; + dummy_inputs.enable_thinking = false; + const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs); + dummy_inputs.enable_thinking = true; + const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs); + return rendered_no_thinking.prompt != rendered_with_thinking.prompt; +} + template <> std::vector common_chat_msgs_parse_oaicompat(const json & messages) { std::vector msgs; @@ -618,11 +631,13 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; + case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: return "DeepSeek V3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; + case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; default: throw std::runtime_error("Unknown chat format"); } @@ -684,11 +699,13 @@ static void parse_json_tool_calls( size_t from = std::string::npos; auto first = true; while (true) { + auto start_pos = builder.pos(); auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) : function_regex ? builder.try_find_regex(*function_regex, from) : std::nullopt; + if (res) { std::string name; if (get_function_name) { @@ -723,6 +740,8 @@ static void parse_json_tool_calls( return; } throw common_chat_msg_partial_exception("incomplete tool call"); + } else { + builder.move_to(start_pos); } break; } @@ -1184,6 +1203,67 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te }); return data; } + +static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // Generate the prompt using the apply() function with the template + data.prompt = apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2; + + // Handle thinking tags appropriately based on inputs.enable_thinking + if (string_ends_with(data.prompt, "\n")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + // When tools are present, build grammar for the format, similar to CommandR, but without tool call ID + if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = true; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + { "type", "object" }, + { "properties", + { + { "name", + { + { "type", "string" }, + { "const", function.at("name") }, + } }, + { "arguments", function.at("parameters") }, + } }, + { "required", json::array({ "name", "arguments" }) }, + }); + }); + auto schema = json{ + { "type", "array" }, + { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } }, + { "minItems", 1 }, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + + "\"\" " + builder.add_schema("tool_calls", schema) + + " \"\""); + }); + data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? + "[\\s\\S]*?(\\s*)" : + "(?:[\\s\\S]*?\\s*)?") + + "()[\\s\\S]*" }); + } + return data; +} static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { if (!builder.syntax().parse_tool_calls) { builder.add_content(builder.consume_rest()); @@ -1313,6 +1393,71 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ } return data; } + +static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // Pass thinking context for DeepSeek V3.1 template + json additional_context = { + {"thinking", inputs.enable_thinking}, + }; + + auto prompt = apply(tmpl, inputs, + /* messages_override= */ inputs.messages, + /* tools_override= */ std::nullopt, + additional_context); + data.prompt = prompt; + data.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; + if (string_ends_with(data.prompt, "")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "( \"<|tool▁call▁begin|>\" )? \"" + name + "<|tool▁sep|>" + "\" " + builder.add_schema(name + "-args", parameters) + " " + "\"<|tool▁call▁end|>\"")); + }); + // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, + // so we accept common variants (then it's all constrained) + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + + "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) " + "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " + "\"<|tool▁calls▁end|>\"" + " space"); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + + "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*" + }); + data.preserved_tokens = { + "", + "", + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>", + "<|tool▁sep|>", + "<|tool▁call▁end|>", + "<|tool▁calls▁end|>", + }; + }); + } + return data; +} + static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { builder.try_parse_reasoning("", ""); if (!builder.syntax().parse_tool_calls) { @@ -1334,6 +1479,66 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { tool_calls_end); } +static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); + + static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + + if (!builder.syntax().parse_tool_calls) { + LOG_DBG("%s: not parse_tool_calls\n", __func__); + builder.add_content(builder.consume_rest()); + return; + } + + LOG_DBG("%s: parse_tool_calls\n", __func__); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { + // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content + // First try to parse using the standard reasoning parsing method + LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); + + auto start_pos = builder.pos(); + auto found_end_think = builder.try_find_literal(""); + builder.move_to(start_pos); + + if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { + LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + } else if (builder.try_parse_reasoning("", "")) { + // If reasoning was parsed successfully, the remaining content is regular content + LOG_DBG("%s: parsed reasoning, adding content\n", __func__); + // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } else { + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { + LOG_DBG("%s: reasoning_format none, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + return; + } + // If no reasoning tags found, check if we should treat everything as reasoning + if (builder.syntax().thinking_forced_open) { + // If thinking is forced open but no tags found, treat everything as reasoning + LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); + builder.add_reasoning_content(builder.consume_rest()); + } else { + LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); + // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } + } +} + static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; auto prompt = apply(tmpl, inputs); @@ -1830,7 +2035,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat // If thinking_forced_open, then we capture the tag in the grammar, // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + ( - "(\\s*" + "\\s*(" "(?:" "||||)?" @@ -2060,6 +2265,33 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) { } } +static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + if (!builder.try_consume_literal("")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + builder.add_tool_calls(tool_calls_data.json); + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { // Parse thinking tags first - this handles the main reasoning content builder.try_parse_reasoning("", ""); @@ -2263,6 +2495,12 @@ static common_chat_params common_chat_templates_apply_jinja( } } + // DeepSeek V3.1: detect based on specific patterns in the template + if (src.find("message['prefix'] is defined and message['prefix'] and thinking") != std::string::npos && + params.json_schema.is_null()) { + return common_chat_params_init_deepseek_v3_1(tmpl, params); + } + // DeepSeek R1: use handler in all cases except json schema (thinking / tools). if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { return common_chat_params_init_deepseek_r1(tmpl, params); @@ -2293,6 +2531,11 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_seed_oss(tmpl, params, inputs); } + // Nemotron v2 + if (src.find("") != std::string::npos) { + return common_chat_params_init_nemotron_v2(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -2430,6 +2673,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_DEEPSEEK_R1: common_chat_parse_deepseek_r1(builder); break; + case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: + common_chat_parse_deepseek_v3_1(builder); + break; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: common_chat_parse_functionary_v3_2(builder); break; @@ -2454,6 +2700,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_SEED_OSS: common_chat_parse_seed_oss(builder); break; + case COMMON_CHAT_FORMAT_NEMOTRON_V2: + common_chat_parse_nemotron_v2(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index b09ff3b126..5170fc14f4 100644 --- a/common/chat.h +++ b/common/chat.h @@ -107,11 +107,13 @@ enum common_chat_format { COMMON_CHAT_FORMAT_FIREFUNCTION_V2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, + COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, COMMON_CHAT_FORMAT_HERMES_2_PRO, COMMON_CHAT_FORMAT_COMMAND_R7B, COMMON_CHAT_FORMAT_GRANITE, COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_SEED_OSS, + COMMON_CHAT_FORMAT_NEMOTRON_V2, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; @@ -198,6 +200,8 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_p common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); +bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates); + // Parses a JSON array of messages in OpenAI's chat completion API format. // T can be std::string containing JSON or nlohmann::ordered_json template std::vector common_chat_msgs_parse_oaicompat(const T & messages); diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 637891f506..182c787544 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -843,9 +843,10 @@ public: _build_object_rule( properties, required, name, schema.contains("additionalProperties") ? schema["additionalProperties"] : json())); - } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) { + } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) { std::unordered_set required; std::vector> properties; + std::map enum_values; std::string hybrid_name = name; std::function add_component = [&](const json & comp_schema, bool is_required) { if (comp_schema.contains("$ref")) { @@ -857,6 +858,14 @@ public: required.insert(prop.key()); } } + } else if (comp_schema.contains("enum")) { + for (const auto & v : comp_schema["enum"]) { + const auto rule = _generate_constant_rule(v); + if (enum_values.find(rule) == enum_values.end()) { + enum_values[rule] = 0; + } + enum_values[rule] += 1; + } } else { // todo warning } @@ -870,6 +879,17 @@ public: add_component(t, true); } } + if (!enum_values.empty()) { + std::vector enum_intersection; + for (const auto & p : enum_values) { + if (p.second == schema["allOf"].size()) { + enum_intersection.push_back(p.first); + } + } + if (!enum_intersection.empty()) { + return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space"); + } + } return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; diff --git a/common/log.cpp b/common/log.cpp index 52b31470c4..4ccdbd17cd 100644 --- a/common/log.cpp +++ b/common/log.cpp @@ -4,17 +4,52 @@ #include #include #include +#include +#include #include #include #include #include +#if defined(_WIN32) +# include +# include +# define isatty _isatty +# define fileno _fileno +#else +# include +#endif // defined(_WIN32) + int common_log_verbosity_thold = LOG_DEFAULT_LLAMA; void common_log_set_verbosity_thold(int verbosity) { common_log_verbosity_thold = verbosity; } +// Auto-detect if colors should be enabled based on terminal and environment +static bool common_log_should_use_colors_auto() { + // Check NO_COLOR environment variable (https://no-color.org/) + if (const char * no_color = std::getenv("NO_COLOR")) { + if (no_color[0] != '\0') { + return false; + } + } + + // Check TERM environment variable + if (const char * term = std::getenv("TERM")) { + if (std::strcmp(term, "dumb") == 0) { + return false; + } + } + + // Check if stdout and stderr are connected to a terminal + // We check both because log messages can go to either + bool stdout_is_tty = isatty(fileno(stdout)); + bool stderr_is_tty = isatty(fileno(stderr)); + + return stdout_is_tty || stderr_is_tty; +} + static int64_t t_us() { return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); } @@ -353,6 +388,11 @@ struct common_log * common_log_init() { struct common_log * common_log_main() { static struct common_log log; + static std::once_flag init_flag; + std::call_once(init_flag, [&]() { + // Set default to auto-detect colors + log.set_colors(common_log_should_use_colors_auto()); + }); return &log; } @@ -380,8 +420,19 @@ void common_log_set_file(struct common_log * log, const char * file) { log->set_file(file); } -void common_log_set_colors(struct common_log * log, bool colors) { - log->set_colors(colors); +void common_log_set_colors(struct common_log * log, log_colors colors) { + if (colors == LOG_COLORS_AUTO) { + log->set_colors(common_log_should_use_colors_auto()); + return; + } + + if (colors == LOG_COLORS_DISABLED) { + log->set_colors(false); + return; + } + + GGML_ASSERT(colors == LOG_COLORS_ENABLED); + log->set_colors(true); } void common_log_set_prefix(struct common_log * log, bool prefix) { diff --git a/common/log.h b/common/log.h index c56bb50d95..f329b434c9 100644 --- a/common/log.h +++ b/common/log.h @@ -24,6 +24,12 @@ #define LOG_DEFAULT_DEBUG 1 #define LOG_DEFAULT_LLAMA 0 +enum log_colors { + LOG_COLORS_AUTO = -1, + LOG_COLORS_DISABLED = 0, + LOG_COLORS_ENABLED = 1, +}; + // needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower // set via common_log_set_verbosity() extern int common_log_verbosity_thold; @@ -65,10 +71,10 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch // D - debug (stderr, V = LOG_DEFAULT_DEBUG) // -void common_log_set_file (struct common_log * log, const char * file); // not thread-safe -void common_log_set_colors (struct common_log * log, bool colors); // not thread-safe -void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log -void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix +void common_log_set_file (struct common_log * log, const char * file); // not thread-safe +void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe +void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log +void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix // helper macros for logging // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 717b8a6595..bbc21813f8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -5128,6 +5128,20 @@ class EmbeddingGemma(Gemma3Model): def set_gguf_parameters(self): super().set_gguf_parameters() + + # Override the sliding window size as it gets adjusted by the Gemma3TextConfig + # constructor. We want to use the value from the original model's config.json. + # ref: https://github.com/huggingface/transformers/pull/40700 + with open(self.dir_model / "config.json", "r", encoding="utf-8") as f: + config = json.load(f) + orig_sliding_window = config.get("sliding_window") + if orig_sliding_window is None: + raise ValueError("sliding_window not found in model config - this is required for the model") + + logger.info(f"Using original sliding_window from config: {orig_sliding_window} " + f"instead of {self.hparams['sliding_window']}") + self.gguf_writer.add_sliding_window(orig_sliding_window) + self._try_set_pooling_type() @@ -6687,6 +6701,8 @@ class T5Model(TextModel): self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) self.gguf_writer.add_block_count(self.hparams["num_layers"]) + if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None: + self.gguf_writer.add_decoder_block_count(dec_n_layer) self.gguf_writer.add_head_count(self.hparams["num_heads"]) self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"]) diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index a67c0536a4..befe8ab9cc 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -12,7 +12,7 @@ import json from math import prod from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast -from transformers import AutoConfig +from transformers import AutoConfig, AutoTokenizer import torch @@ -26,6 +26,8 @@ import gguf # reuse model definitions from convert_hf_to_gguf.py from convert_hf_to_gguf import LazyTorchTensor, ModelBase +from gguf.constants import GGUFValueType + logger = logging.getLogger("lora-to-gguf") @@ -369,7 +371,31 @@ if __name__ == '__main__': self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") def set_gguf_parameters(self): + logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) + alora_invocation_tokens = lparams.get("alora_invocation_tokens") + invocation_string = lparams.get("invocation_string") + if invocation_string and not alora_invocation_tokens: + logger.debug("Tokenizing invocation_string -> alora_invocation_tokens") + base_model_path_or_id = hparams.get("_name_or_path") + try: + tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id) + except ValueError: + logger.error("Unable to load tokenizer from %s", base_model_path_or_id) + raise + # NOTE: There's an off-by-one with the older aLoRAs where + # the invocation string includes the "<|start_of_turn|>" + # token, but the adapters themselves were trained to + # activate _after_ that first token, so we drop it here. + alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:] + if alora_invocation_tokens: + logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens) + self.gguf_writer.add_key_value( + gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, + alora_invocation_tokens, + GGUFValueType.ARRAY, + GGUFValueType.UINT32, + ) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: # Never add extra tensors (e.g. rope_freqs) for LoRA adapters diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index 357253f43a..e45fc7dd28 100755 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -314,3 +314,11 @@ Converting the matmul weight format from ND to NZ to improve performance. Enable ### GGML_CANN_ACL_GRAPH Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default. + +### GGML_CANN_GRAPH_CACHE_CAPACITY + +Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. When the number of cached graphs exceeds this capacity, the least recently used graph will be evicted. + +### GGML_CANN_PREFILL_USE_GRAPH + +Enable ACL graph execution during the prefill stage, default is false. This option is only effective when FA is enabled. diff --git a/docs/build-s390x.md b/docs/build-s390x.md index f3cdd63be3..94f8ffdb74 100644 --- a/docs/build-s390x.md +++ b/docs/build-s390x.md @@ -42,18 +42,6 @@ cmake --build build --config Release -j $(nproc) cmake --build build --config Release -j $(nproc) ``` -- By default, NNPA is disabled by default. To enable it: - - ```bash - cmake -S . -B build \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_BLAS=ON \ - -DGGML_BLAS_VENDOR=OpenBLAS \ - -DGGML_NNPA=ON - - cmake --build build --config Release -j $(nproc) - ``` - - For debug builds: ```bash @@ -164,15 +152,11 @@ All models need to be converted to Big-Endian. You can achieve this in three cas Only available in IBM z15/LinuxONE 3 or later system with the `-DGGML_VXE=ON` (turned on by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z14/arch12. In such systems, the APIs can still run but will use a scalar implementation. -### 2. NNPA Vector Intrinsics Acceleration - -Only available in IBM z16/LinuxONE 4 or later system with the `-DGGML_NNPA=ON` (turned off by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs can still run but will use a scalar implementation. - -### 3. zDNN Accelerator (WIP) +### 2. zDNN Accelerator (WIP) Only available in IBM z17/LinuxONE 5 or later system with the `-DGGML_ZDNN=ON` compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs will default back to CPU routines. -### 4. Spyre Accelerator +### 3. Spyre Accelerator _Only available with IBM z17 / LinuxONE 5 or later system. No support currently available._ @@ -230,10 +214,6 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl CXXFLAGS="-include cstdint" pip3 install -r requirements.txt ``` -5. `-DGGML_NNPA=ON` generates gibberish output - - Answer: We are aware of this as detailed in [this issue](https://github.com/ggml-org/llama.cpp/issues/14877). Please either try reducing the number of threads, or disable the compile option using `-DGGML_NNPA=OFF`. - ## Getting Help on IBM Z & LinuxONE 1. **Bugs, Feature Requests** @@ -258,38 +238,38 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl ## Appendix B: SIMD Support Matrix -| | VX/VXE/VXE2 | NNPA | zDNN | Spyre | -| ---------- | ----------- | ---- | ---- | ----- | -| FP32 | ✅ | ✅ | ✅ | ❓ | -| FP16 | ✅ | ✅ | ❓ | ❓ | -| BF16 | 🚫 | 🚫 | ❓ | ❓ | -| Q4_0 | ✅ | ✅ | ❓ | ❓ | -| Q4_1 | ✅ | ✅ | ❓ | ❓ | -| MXFP4 | 🚫 | 🚫 | ❓ | ❓ | -| Q5_0 | ✅ | ✅ | ❓ | ❓ | -| Q5_1 | ✅ | ✅ | ❓ | ❓ | -| Q8_0 | ✅ | ✅ | ❓ | ❓ | -| Q2_K | 🚫 | 🚫 | ❓ | ❓ | -| Q3_K | ✅ | ✅ | ❓ | ❓ | -| Q4_K | ✅ | ✅ | ❓ | ❓ | -| Q5_K | ✅ | ✅ | ❓ | ❓ | -| Q6_K | ✅ | ✅ | ❓ | ❓ | -| TQ1_0 | 🚫 | 🚫 | ❓ | ❓ | -| TQ2_0 | 🚫 | 🚫 | ❓ | ❓ | -| IQ2_XXS | 🚫 | 🚫 | ❓ | ❓ | -| IQ2_XS | 🚫 | 🚫 | ❓ | ❓ | -| IQ2_S | 🚫 | 🚫 | ❓ | ❓ | -| IQ3_XXS | 🚫 | 🚫 | ❓ | ❓ | -| IQ3_S | 🚫 | 🚫 | ❓ | ❓ | -| IQ1_S | 🚫 | 🚫 | ❓ | ❓ | -| IQ1_M | 🚫 | 🚫 | ❓ | ❓ | -| IQ4_NL | ✅ | ✅ | ❓ | ❓ | -| IQ4_XS | ✅ | ✅ | ❓ | ❓ | -| FP32->FP16 | 🚫 | ✅ | ❓ | ❓ | -| FP16->FP32 | 🚫 | ✅ | ❓ | ❓ | +| | VX/VXE/VXE2 | zDNN | Spyre | +|------------|-------------|------|-------| +| FP32 | ✅ | ✅ | ❓ | +| FP16 | ✅ | ❓ | ❓ | +| BF16 | 🚫 | ❓ | ❓ | +| Q4_0 | ✅ | ❓ | ❓ | +| Q4_1 | ✅ | ❓ | ❓ | +| MXFP4 | 🚫 | ❓ | ❓ | +| Q5_0 | ✅ | ❓ | ❓ | +| Q5_1 | ✅ | ❓ | ❓ | +| Q8_0 | ✅ | ❓ | ❓ | +| Q2_K | 🚫 | ❓ | ❓ | +| Q3_K | ✅ | ❓ | ❓ | +| Q4_K | ✅ | ❓ | ❓ | +| Q5_K | ✅ | ❓ | ❓ | +| Q6_K | ✅ | ❓ | ❓ | +| TQ1_0 | 🚫 | ❓ | ❓ | +| TQ2_0 | 🚫 | ❓ | ❓ | +| IQ2_XXS | 🚫 | ❓ | ❓ | +| IQ2_XS | 🚫 | ❓ | ❓ | +| IQ2_S | 🚫 | ❓ | ❓ | +| IQ3_XXS | 🚫 | ❓ | ❓ | +| IQ3_S | 🚫 | ❓ | ❓ | +| IQ1_S | 🚫 | ❓ | ❓ | +| IQ1_M | 🚫 | ❓ | ❓ | +| IQ4_NL | ✅ | ❓ | ❓ | +| IQ4_XS | ✅ | ❓ | ❓ | +| FP32->FP16 | 🚫 | ❓ | ❓ | +| FP16->FP32 | 🚫 | ❓ | ❓ | - ✅ - acceleration available - 🚫 - acceleration unavailable, will still run using scalar implementation - ❓ - acceleration unknown, please contribute if you can test it yourself -Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Aug 22, 2025. +Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 6, 2025. diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index bdf0eed2a9..767198aafa 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -333,17 +333,17 @@ static void print_params(struct my_llama_hparams * params) { } static void print_tensor_info(const struct ggml_context * ctx) { - for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + for (auto * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { LOG_INF("%s: Allocating ", __func__); int64_t total = 1; int i = 0; for (; i < ggml_n_dims(t); ++i) { - if (i > 0) LOG("x "); - LOG("[%" PRId64 "] ", t->ne[i]); + if (i > 0) { LOG_INF("x "); } + LOG_INF("[%" PRId64 "] ", t->ne[i]); total *= t->ne[i]; } - if (i > 1) LOG("= [%" PRId64 "] ", total); - LOG("float space for %s\n", ggml_get_name(t)); + if (i > 1) { LOG_INF("= [%" PRId64 "] ", total); } + LOG_INF("float space for %s\n", ggml_get_name(t)); } } diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index d4ef751fbb..cefa39a57c 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -28,6 +28,15 @@ static std::string ggml_ne_string(const ggml_tensor * t) { return str; } +static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; float v; @@ -43,6 +52,8 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * v = (float) *(int16_t *) &data[i]; } else if (type == GGML_TYPE_I8) { v = (float) *(int8_t *) &data[i]; + } else if (type == GGML_TYPE_BF16) { + v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]); } else { GGML_ABORT("fatal error"); } diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index ed37958554..2d57549046 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -586,9 +586,10 @@ class SchemaConverter: properties = list(schema.get('properties', {}).items()) return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) - elif schema_type in (None, 'object') and 'allOf' in schema: + elif schema_type in (None, 'object', 'string') and 'allOf' in schema: required = set() properties = [] + enum_sets = [] hybrid_name = name def add_component(comp_schema, is_required): if (ref := comp_schema.get('$ref')) is not None: @@ -600,6 +601,9 @@ class SchemaConverter: if is_required: required.add(prop_name) + if 'enum' in comp_schema: + enum_sets.append(set(comp_schema['enum'])) + for t in schema['allOf']: if 'anyOf' in t: for tt in t['anyOf']: @@ -607,6 +611,15 @@ class SchemaConverter: else: add_component(t, is_required=True) + if enum_sets: + enum_intersection = enum_sets[0] + for s in enum_sets[1:]: + enum_intersection &= s + + if enum_intersection: + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space' + return self._add_rule(rule_name, rule) + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): diff --git a/examples/model-conversion/requirements.txt b/examples/model-conversion/requirements.txt index b8148b269a..ac9f69e10b 100644 --- a/examples/model-conversion/requirements.txt +++ b/examples/model-conversion/requirements.txt @@ -1,5 +1,6 @@ --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.6.0 -torchvision~=0.21.0 -transformers~=4.55.0 -huggingface-hub~=0.34.0 +torch +torchvision +transformers +huggingface-hub +accelerate diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index f6188ea6f3..78a54abf13 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -9,15 +9,134 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig import torch import numpy as np -unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') +### If you want to dump RoPE activations, apply this monkey patch to the model +### class from Transformers that you are running (replace apertus.modeling_apertus +### with the proper package and class for your model +### === START ROPE DEBUG === +# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb -parser = argparse.ArgumentParser(description='Process model with specified path') -parser.add_argument('--model-path', '-m', help='Path to the model') +# orig_rope = apply_rotary_pos_emb +# torch.set_printoptions(threshold=float('inf')) +# torch.set_printoptions(precision=6, sci_mode=False) + +# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +# # log inputs +# summarize(q, "RoPE.q_in") +# summarize(k, "RoPE.k_in") + +# # call original +# q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim) + +# # log outputs +# summarize(q_out, "RoPE.q_out") +# summarize(k_out, "RoPE.k_out") + +# return q_out, k_out + +# # Patch it +# import transformers.models.apertus.modeling_apertus as apertus_mod # noqa: E402 +# apertus_mod.apply_rotary_pos_emb = debug_rope +### == END ROPE DEBUG === + + +def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3): + """ + Print a tensor in llama.cpp debug style. + + Supports: + - 2D tensors (seq, hidden) + - 3D tensors (batch, seq, hidden) + - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head + + Shows first and last max_vals of each vector per sequence position. + """ + t = tensor.detach().to(torch.float32).cpu() + + # Determine dimensions + if t.ndim == 3: + _, s, _ = t.shape + elif t.ndim == 2: + _, s = 1, t.shape[0] + t = t.unsqueeze(0) + elif t.ndim == 4: + _, s, _, _ = t.shape + else: + print(f"Skipping tensor due to unsupported dimensions: {t.ndim}") + return + + ten_shape = t.shape + + print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}") + print(" [") + print(" [") + + # Determine indices for first and last sequences + first_indices = list(range(min(s, max_seq))) + last_indices = list(range(max(0, s - max_seq), s)) + + # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq + has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s) + + # Combine indices + if has_overlap: + # If there's overlap, just use the combined unique indices + indices = sorted(list(set(first_indices + last_indices))) + separator_index = None + else: + # If no overlap, we'll add a separator between first and last sequences + indices = first_indices + last_indices + separator_index = len(first_indices) + + for i, si in enumerate(indices): + # Add separator if needed + if separator_index is not None and i == separator_index: + print(" ...") + + # Extract appropriate slice + vec = t[0, si] + if vec.ndim == 2: # 4D case: flatten heads × dim_per_head + flat = vec.flatten().tolist() + else: # 2D or 3D case + flat = vec.tolist() + + # First and last slices + first = flat[:max_vals] + last = flat[-max_vals:] if len(flat) >= max_vals else flat + first_str = ", ".join(f"{v:12.4f}" for v in first) + last_str = ", ".join(f"{v:12.4f}" for v in last) + + print(f" [{first_str}, ..., {last_str}]") + + print(" ],") + print(" ]") + print(f" sum = {t.sum().item():.6f}\n") + + +def debug_hook(name): + def fn(_m, input, output): + if isinstance(input, torch.Tensor): + summarize(input, name + "_in") + elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor): + summarize(input[0], name + "_in") + if isinstance(output, torch.Tensor): + summarize(output, name + "_out") + elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor): + summarize(output[0], name + "_out") + + return fn + + +unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME") + +parser = argparse.ArgumentParser(description="Process model with specified path") +parser.add_argument("--model-path", "-m", help="Path to the model") args = parser.parse_args() -model_path = os.environ.get('MODEL_PATH', args.model_path) +model_path = os.environ.get("MODEL_PATH", args.model_path) if model_path is None: - parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable") + parser.error( + "Model path must be specified either via --model-path argument or MODEL_PATH environment variable" + ) config = AutoConfig.from_pretrained(model_path) @@ -34,18 +153,30 @@ config = AutoConfig.from_pretrained(model_path) if unreleased_model_name: model_name_lower = unreleased_model_name.lower() - unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + unreleased_module_path = ( + f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + ) class_name = f"{unreleased_model_name}ForCausalLM" print(f"Importing unreleased model module: {unreleased_module_path}") try: - model_class = getattr(importlib.import_module(unreleased_module_path), class_name) - model = model_class.from_pretrained(model_path) # Note: from_pretrained, not fromPretrained + model_class = getattr( + importlib.import_module(unreleased_module_path), class_name + ) + model = model_class.from_pretrained( + model_path + ) # Note: from_pretrained, not fromPretrained except (ImportError, AttributeError) as e: print(f"Failed to import or load model: {e}") exit(1) else: - model = AutoModelForCausalLM.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, device_map="auto", offload_folder="offload" + ) + +for name, module in model.named_modules(): + if len(list(module.children())) == 0: # only leaf modules + module.register_forward_hook(debug_hook(name)) model_name = os.path.basename(model_path) # Printing the Model class to allow for easier debugging. This can be useful diff --git a/examples/model-conversion/scripts/embedding/modelcard.template b/examples/model-conversion/scripts/embedding/modelcard.template index 75c580524f..9e63042b7b 100644 --- a/examples/model-conversion/scripts/embedding/modelcard.template +++ b/examples/model-conversion/scripts/embedding/modelcard.template @@ -7,7 +7,7 @@ base_model: Recommended way to run this model: ```sh -llama-server -hf {namespace}/{model_name}-GGUF +llama-server -hf {namespace}/{model_name}-GGUF --embeddings ``` Then the endpoint can be accessed at http://localhost:8080/embedding, for diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 9ef88c6fd0..d06464f5eb 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -134,7 +134,6 @@ option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_VXE "ggml: enable vxe" ON) -option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877 option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 4f246f6ccd..ab297e0c6f 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -132,6 +132,8 @@ extern "C" { GGML_BACKEND_DEVICE_TYPE_CPU, // GPU device using dedicated memory GGML_BACKEND_DEVICE_TYPE_GPU, + // integrated GPU device using host memory + GGML_BACKEND_DEVICE_TYPE_IGPU, // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) GGML_BACKEND_DEVICE_TYPE_ACCEL }; @@ -150,11 +152,21 @@ extern "C" { // all the device properties struct ggml_backend_dev_props { + // device name const char * name; + // device description const char * description; + // device free memory in bytes size_t memory_free; + // device total memory in bytes size_t memory_total; + // device type enum ggml_backend_dev_type type; + // device id + // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0") + // if the id is unknown, this should be NULL + const char * device_id; + // device capabilities struct ggml_backend_dev_caps caps; }; diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index be40b10097..9edd485136 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -101,7 +101,6 @@ extern "C" { GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void); GGML_BACKEND_API int ggml_cpu_has_vxe (void); - GGML_BACKEND_API int ggml_cpu_has_nnpa (void); GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); GGML_BACKEND_API int ggml_cpu_has_llamafile (void); @@ -135,6 +134,7 @@ extern "C" { GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); + GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t); diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h index a610694423..1163438bc2 100644 --- a/ggml/include/ggml-metal.h +++ b/ggml/include/ggml-metal.h @@ -43,14 +43,8 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void); GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); -GGML_DEPRECATED( - GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), - "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713"); - GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); - // helper to check if the device supports a specific family // ideally, the user code should be doing these checks // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c01b98ac78..b7b472c56e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1404,6 +1404,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // note: casting from f32 to i32 will discard the fractional part GGML_API struct ggml_tensor * ggml_cast( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1528,7 +1529,11 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // supports 3D: a->ne[2] == b->ne[1] + // supports 4D a: + // a [n_embd, ne1, ne2, ne3] + // b I32 [n_rows, ne2, ne3, 1] + // + // return [n_embd, n_rows, ne2, ne3] GGML_API struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, struct ggml_tensor * a, // data diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index c36c12d657..89d80db6e6 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -8,7 +8,7 @@ extern "C" { #endif - #define GGML_BACKEND_API_VERSION 1 + #define GGML_BACKEND_API_VERSION 2 // // Backend buffer type @@ -114,6 +114,9 @@ extern "C" { void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); // wait for an event on on a different stream void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + + // (optional) sort/optimize the nodes in the graph + void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph); }; struct ggml_backend { diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 5f02a710a1..7002cb07e0 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -400,9 +400,8 @@ ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const ggml_backend_t ggml_backend_init_best(void) { ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); - if (!dev) { - dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - } + dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU); + dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); if (!dev) { return nullptr; } diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index f615ab4bee..7646f3f134 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -463,6 +463,13 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) backend->iface.event_wait(backend, event); } +static void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend); + if (backend->iface.optimize_graph != NULL) { + backend->iface.optimize_graph(backend, cgraph); + } +} + // Backend device const char * ggml_backend_dev_name(ggml_backend_dev_t device) { @@ -1298,6 +1305,10 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra struct ggml_backend_sched_split * split = &sched->splits[i]; split->graph = ggml_graph_view(graph, split->i_start, split->i_end); + // Optimize this split of the graph. This needs to happen before we make graph_copy, + // so they are in sync. + ggml_backend_optimize_graph(sched->backends[split->backend_id], &split->graph); + // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split for (int j = 0; j < split->n_inputs; j++) { assert(graph_copy->size > (graph_copy->n_nodes + 1)); diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index aeac2e5744..cdfc5a9bc2 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = { /* .graph_compute = */ ggml_backend_blas_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_blas_guid(void) { diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index ac2e2e1adf..434023dd22 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2268,8 +2268,6 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, * stream, and persistent buffers for rope init/cache. * @param dst The destination ggml_tensor whose computation * depends on the RoPE values (usually Qcur/Kcur). - * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values. - * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values. * @param theta_scale Scalar exponent base for computing theta scale values. * @param freq_scale Frequency scaling factor, applied to theta scale. * @param attn_factor Attention scaling factor, applied to sin/cos. @@ -2277,17 +2275,23 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, * (dim expansion vs repeat_interleave). */ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, - void* sin_tensor_buffer, void* cos_tensor_buffer, float* corr_dims, float ext_factor, float theta_scale, float freq_scale, float attn_factor, bool is_neox) { - // int sin/cos cache, cache has different repeat method depond on - // @param.is_neox - ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src2 = dst->src[2]; // freq_factors + if(src2 == nullptr && ctx.rope_cache.cached + && ctx.rope_cache.ext_factor == ext_factor + && ctx.rope_cache.theta_scale == theta_scale + && ctx.rope_cache.freq_scale == freq_scale + && ctx.rope_cache.attn_factor == attn_factor + && ctx.rope_cache.is_neox == is_neox) { + // use cache. + return; + } + int64_t theta_scale_length = src0->ne[0] / 2; int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float), @@ -2316,8 +2320,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ctx.rope_cache.freq_scale != freq_scale) { ctx.rope_cache.theta_scale_length = theta_scale_length; - ctx.rope_cache.theta_scale = theta_scale; - ctx.rope_cache.freq_scale = freq_scale; if (ctx.rope_cache.theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); @@ -2342,7 +2344,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, // return MIN(1, MAX(0, y)) - 1; yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); void* yarn_ramp_buffer = yarn_ramp_allocator.get(); - acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t), + acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); @@ -2411,6 +2413,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor); } + // init sin_repeat && cos_repeat, only to accelerate first layer on each device + if (position_length > ctx.rope_cache.position_length) { + ctx.rope_cache.position_length = position_length; + if (ctx.rope_cache.sin_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache)); + } + if (ctx.rope_cache.cos_cache != nullptr) { + ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache)); + } + int64_t repeat_theta_length = theta_scale_length * position_length * 2; + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + } + // position aclTensor* acl_position_tensor = ggml_cann_create_tensor( src1->data, ggml_cann_type_mapping(src1->type), @@ -2462,10 +2478,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_repeat_tensor = - ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float), + ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_repeat_tensor = - ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float), + ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); // repeat @@ -2483,6 +2499,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, num_repeats, output_size); } + // Other layers use cache except first layer. + ctx.rope_cache.cached = true; + ctx.rope_cache.ext_factor = ext_factor; + ctx.rope_cache.theta_scale = theta_scale; + ctx.rope_cache.freq_scale = freq_scale; + ctx.rope_cache.attn_factor = attn_factor; + ctx.rope_cache.is_neox = is_neox; + ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor, acl_cos_repeat_tensor); @@ -2504,10 +2528,7 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace, #endif void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { - // TODO: use ascendc - // Only test with LLAMA model. ggml_tensor* src0 = dst->src[0]; // input - ggml_tensor* src1 = dst->src[1]; // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; @@ -2538,15 +2559,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - // sin/cos tensor length. - int64_t repeat_theta_length = src0->ne[0] * src1->ne[0]; - ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float)); - ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float)); - void *sin_tensor_buffer = sin_tensor_allocator.get(); - void *cos_tensor_buffer = cos_tensor_allocator.get(); - // init ctx.rope_cos/rope_sin cache - aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor, + aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox); int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; @@ -2556,10 +2570,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; } aclTensor* acl_sin_reshape_tensor = - ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float), + ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_cos_reshape_tensor = - ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float), + ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclTensor* acl_src = ggml_cann_create_tensor(src0); diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index e295f4ab47..c5fce8dc91 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -38,6 +38,7 @@ #include #include #include +#include #include "../include/ggml-cann.h" #include "../include/ggml.h" @@ -106,6 +107,7 @@ int32_t ggml_cann_get_device(); std::optional get_env(const std::string& name); bool parse_bool(const std::string& value); +int parse_integer(const std::string& value); /** * @brief Abstract base class for memory pools used by CANN. @@ -350,7 +352,7 @@ struct ggml_graph_node_properties { struct ggml_cann_graph { ~ggml_cann_graph() { if (graph != nullptr) { - aclmdlRIDestroy(graph); + ACL_CHECK(aclmdlRIDestroy(graph)); } } @@ -358,6 +360,64 @@ struct ggml_cann_graph { std::vector ggml_graph_properties; }; + +/** + * @brief LRU cache for managing ggml_cann_graph objects. + * + * This class maintains a list of shared_ptr to ggml_cann_graph objects + * and enforces a maximum capacity. It provides methods to push new graphs, + * move existing graphs to the front (most recently used), and clear the cache. + */ +struct ggml_cann_graph_lru_cache { + size_t capacity; /**< Maximum number of graphs in the cache. */ + + std::list cache_list; /**< List storing cached graphs as raw pointers. */ + + ggml_cann_graph_lru_cache() { + capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); + } + + /** + * @brief Push a new graph to the front of the cache. + * If the cache exceeds capacity, the least recently used graph is deleted. + * @param new_node Pointer to the new ggml_cann_graph to cache. + * Ownership is transferred to the cache (cache will delete it). + */ + void push(ggml_cann_graph* new_node) { + if (cache_list.size() >= capacity) { + ggml_cann_graph* old = cache_list.back(); + cache_list.pop_back(); + delete old; // free the old graph + } + cache_list.push_front(new_node); + } + + /** + * @brief Move an existing graph to the front of the cache. + * @param node Pointer to the ggml_cann_graph to move. + */ + void move_to_front(ggml_cann_graph* node) { + cache_list.remove(node); + cache_list.push_front(node); + } + + /** + * @brief Clear all graphs from the cache (also frees memory). + */ + void clear() { + for (auto ptr : cache_list) { + delete ptr; + } + cache_list.clear(); + } + + /** + * @brief Destructor that clears the cache and frees all cached graphs. + */ + ~ggml_cann_graph_lru_cache() { + clear(); + } +}; #endif // USE_ACL_GRAPH struct ggml_cann_rope_cache { @@ -365,12 +425,27 @@ struct ggml_cann_rope_cache { if(theta_scale_cache != nullptr) { ACL_CHECK(aclrtFree(theta_scale_cache)); } + if(sin_cache != nullptr) { + ACL_CHECK(aclrtFree(sin_cache)); + } + if(cos_cache != nullptr) { + ACL_CHECK(aclrtFree(cos_cache)); + } } void* theta_scale_cache = nullptr; int64_t theta_scale_length = 0; + // sin/cos cache, used only to accelerate first layer on each device + void* sin_cache = nullptr; + void* cos_cache = nullptr; + int64_t position_length = 0; + // Properties to check before reusing the sincos cache + bool cached = false; + float ext_factor = 0.0f; float theta_scale = 0.0f; float freq_scale = 0.0f; + float attn_factor = 0.0f; + bool is_neox = false; }; struct ggml_cann_tensor_cache { @@ -394,7 +469,7 @@ struct ggml_backend_cann_context { aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ #ifdef USE_ACL_GRAPH /// Cached CANN ACL graph used for executing the current ggml computation graph. - std::unique_ptr cann_graph; + ggml_cann_graph_lru_cache graph_lru_cache; bool acl_graph_mode = true; #endif cann_task_queue task_queue; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 756ad8dfad..19a18a281d 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -116,6 +116,24 @@ bool parse_bool(const std::string& value) { return valid_values.find(value) != valid_values.end(); } +/** + * @brief Parse a string as an integer, returning 0 if invalid. + * + * This function attempts to convert the input string `value` to an `int`. + * If the string is not a valid integer or is out of the `int` range, + * it returns 0. + * + * @param value The string to parse. + * @return The parsed integer, or 0 if conversion fails. + */ +int parse_integer(const std::string& value) { + try { + return std::stoi(value); + } catch (...) { + return 0; + } +} + /** * @brief Initialize the CANN device information. * @@ -2092,16 +2110,17 @@ static bool ggml_backend_cann_cpy_tensor_async( ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, cann_ctx_src->stream())); - // record event on src stream after the copy - if (!cann_ctx_src->copy_event) { - ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC)); - } - ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream())); + // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream + // if (!cann_ctx_src->copy_event) { + // ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC)); + // } + // ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream())); - // wait on dst stream for the copy to complete - ggml_cann_set_device(cann_ctx_dst->device); - ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event)); + // // wait on dst stream for the copy to complete + // ggml_cann_set_device(cann_ctx_dst->device); + // ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event)); + ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream())); } else { // src and dst are on the same backend ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, @@ -2130,30 +2149,52 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) { #ifdef USE_ACL_GRAPH /** - * @brief Populate the internal CANN graph node properties from the ggml computation graph. + * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph. * - * This function copies all node attributes (operation type, dimensions, strides, input sources, - * and operation parameters) into the cached CANN graph structure for later reuse or comparison. + * This function creates a new ggml_cann_graph object and fills its node properties + * (operation type, dimensions, strides, input sources, and operation parameters) + * based on the current ggml computation graph. * - * @param cann_ctx The CANN backend context. - * @param cgraph The ggml computational graph. + * Each node in the ggml graph is mapped to a property entry in the new CANN graph: + * - node address + * - operation type + * - shape (ne) and strides (nb) + * - source tensor addresses + * - operation parameters + * + * After initialization, the new graph is pushed into the LRU cache owned by the + * CANN backend context. The cache takes ownership of the graph and manages its + * lifetime (including deletion upon eviction). + * + * @param cann_ctx The CANN backend context containing the graph cache. + * @param cgraph The current ggml computation graph. */ -static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { - for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { - ggml_tensor * node = cgraph->nodes[node_idx]; - cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data; - cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op; +static void add_lru_matched_graph_node_properties( + ggml_backend_cann_context * cann_ctx, + ggml_cgraph * cgraph) { + // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache). + ggml_cann_graph * new_graph = new ggml_cann_graph(); + new_graph->ggml_graph_properties.resize(cgraph->n_nodes); - for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { - cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim]; - cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim]; + for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) { + ggml_tensor * node = cgraph->nodes[node_idx]; + auto & prop = new_graph->ggml_graph_properties[node_idx]; + + prop.node_address = node->data; + prop.node_op = node->op; + + std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne); + std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); + + for (int src = 0; src < GGML_MAX_SRC; ++src) { + prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr; } - for (int src = 0; src < GGML_MAX_SRC; src++) { - cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] = - node->src[src] ? node->src[src]->data : nullptr; - } - memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS); + + memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS); } + + // Insert into the LRU cache (cache takes ownership and will delete it when evicted). + cann_ctx->graph_lru_cache.push(new_graph); } /** @@ -2198,30 +2239,45 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } /** - * @brief Determine if the CANN graph needs to be rebuilt due to graph changes. + * @brief Check whether there is a cached CANN graph that matches the current ggml graph. * - * This checks whether the number or properties of ggml graph nodes have changed - * compared to the last captured CANN graph. If so, the CANN graph must be re-captured. + * This function iterates through the cached CANN graphs stored in the LRU cache and + * compares them against the given ggml computation graph. A match requires that the + * number of nodes is the same and that each node’s properties (operation type, + * dimensions, strides, inputs, and operation parameters) are identical. * - * @param cann_ctx The CANN backend context. + * If a matching graph is found, it is promoted to the front of the LRU cache and the + * function returns true. Otherwise, the function returns false, indicating that a new + * CANN graph needs to be captured. + * + * @param cann_ctx The CANN backend context containing the graph cache. * @param cgraph The current ggml computation graph. - * @return true if an update is required; false otherwise. + * @return true if a matching cached graph exists; false otherwise. */ -static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { - // The number of nodes is different, so the graph needs to be reconstructed. - if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { - cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes); - return true; - } +static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { + ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache; + for (auto &graph_ptr : lru_cache.cache_list) { + // Skip graphs with a different number of nodes. + if (graph_ptr->ggml_graph_properties.size() != static_cast(cgraph->n_nodes)) { + continue; + } - // The number of nodes is the same; iterate over each node to check whether they match. - for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = ggml_graph_node_has_matching_properties( - cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]); - if(!has_matching_properties) { + // Check if all nodes match. + bool all_match = true; + for (int i = 0; i < cgraph->n_nodes; ++i) { + if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) { + all_match = false; + break; + } + } + + if (all_match) { + // update cache_list && renturn graph_ptr + lru_cache.move_to_front(graph_ptr); return true; } } + return false; } #endif // USE_ACL_GRAPH @@ -2240,17 +2296,13 @@ static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, * @param cann_graph_update_required Whether graph capture is needed due to graph changes. */ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph, - bool & use_cann_graph, bool & cann_graph_update_required) { + bool & use_cann_graph, bool & cann_graph_update_required) { #ifdef USE_ACL_GRAPH + ggml_cann_graph* matched_graph = cann_ctx->graph_lru_cache.cache_list.front(); if (use_cann_graph && cann_graph_update_required) { - if (cann_ctx->cann_graph->graph != nullptr) { - ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph)); - cann_ctx->cann_graph->graph = nullptr; - } ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL)); } #endif // USE_ACL_GRAPH - // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. if (!use_cann_graph || cann_graph_update_required) { @@ -2271,12 +2323,12 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx #ifdef USE_ACL_GRAPH if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture - ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph)); + ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph)); } if (use_cann_graph) { // Execute graph - ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream())); + ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream())); } #endif // USE_ACL_GRAPH } @@ -2301,28 +2353,44 @@ static enum ggml_status ggml_backend_cann_graph_compute( ggml_cann_set_device(cann_ctx->device); g_nz_workspaces[cann_ctx->device].clear(); + // calculate rope cache for fist layer in current device. + cann_ctx->rope_cache.cached = false; + #ifdef USE_ACL_GRAPH bool use_cann_graph = true; bool cann_graph_update_required = false; + static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); + if (!prefill_use_graph) { + // Do not use acl_graph for prefill. + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + // TODO: Optimize here. Currently, we can only + // get seq_len by FA's input. + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + // Q -> src[0], shape: [B, S, N, D] + use_cann_graph = (node->src[0]->ne[1] == 1); + break; + } + } + } + if (!cann_ctx->acl_graph_mode) { use_cann_graph = false; } if (use_cann_graph) { - if (cann_ctx->cann_graph == nullptr) { - cann_ctx->cann_graph.reset(new ggml_cann_graph()); - cann_graph_update_required = true; + // If no matching graph is found, the graph needs to be recaptured. + cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph); + if (cann_graph_update_required) { + // If no matching graph is found, add a new ACL graph. + add_lru_matched_graph_node_properties(cann_ctx, cgraph); } - - cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph); - set_ggml_graph_node_properties(cann_ctx, cgraph); } #else bool use_cann_graph = false; bool cann_graph_update_required = false; #endif // USE_ACL_GRAPH - evaluate_and_capture_cann_graph( cann_ctx, cgraph, @@ -2689,6 +2757,7 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .graph_compute = */ ggml_backend_cann_graph_compute, /* .event_record = */ ggml_backend_cann_event_record, /* .event_wait = */ ggml_backend_cann_event_wait, + /* .optimize_graph = */ NULL, }; /** diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index dd8c1cf678..3699057507 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -224,7 +224,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos) if (NOT ${feature_pos} EQUAL -1) - message(STATUS "ARM feature ${feature} enabled") + # Special handling for MATMUL_INT8 when machine doesn't support i8mm + if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm) + message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm") + list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8) + else() + message(STATUS "ARM feature ${feature} enabled") + endif() endif() endforeach() endif() @@ -457,7 +463,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # TODO: Separation to determine activation of VX/VXE/VXE2 if (${S390X_M} MATCHES "8561|8562") - set(GGML_NNPA OFF) message(STATUS "z15 target") list(APPEND ARCH_FLAGS -march=z15) elseif (${S390X_M} MATCHES "3931") @@ -479,11 +484,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS -mvx -mzvector) list(APPEND ARCH_DEFINITIONS GGML_VXE) endif() - - if (GGML_NNPA) - message(STATUS "NNPA enabled") - list(APPEND ARCH_DEFINITIONS GGML_NNPA) - endif() elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm") message(STATUS "Wasm detected") list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c) diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 1c8176fb4d..dc1bba3a3e 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -53,9 +53,9 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i #if defined(__VXE__) || defined(__VXE2__) for (int i = 0; i < nb; i++) { - __vector float srcv [8]; - __vector float asrcv[8]; - __vector float amaxv[8]; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); @@ -74,8 +74,8 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_CPU_FP32_TO_FP16(d); for (int j = 0; j < 8; j++) { - const __vector float v = vec_mul(srcv[j], vec_splats(id)); - const __vector int32_t vi = vec_signed(v); + const float32x4_t v = vec_mul(srcv[j], vec_splats(id)); + const int32x4_t vi = vec_signed(v); y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 1] = vec_extract(vi, 1); @@ -98,9 +98,9 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i #if defined(__VXE__) || defined(__VXE2__) for (int i = 0; i < nb; i++) { - __vector float srcv [8]; - __vector float asrcv[8]; - __vector float amaxv[8]; + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); @@ -118,11 +118,11 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i y[i].d = GGML_CPU_FP32_TO_FP16(d); - __vector int32_t acc = vec_splats(0); + int32x4_t acc = vec_splats(0); for (int j = 0; j < 8; j++) { - const __vector float v = vec_mul(srcv[j], vec_splats(id)); - const __vector int32_t vi = vec_signed(v); + const float32x4_t v = vec_mul(srcv[j], vec_splats(id)); + const int32x4_t vi = vec_signed(v); y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 1] = vec_extract(vi, 1); @@ -162,37 +162,36 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi float sumf = 0; #if defined(__VXE__) || defined(__VXE2__) - __vector float acc = vec_splats(0.0f); + float32x4_t acc = vec_splats(0.0f); - const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F); - const __vector int8_t v_s = vec_splats( (const int8_t)0x08); + const uint8x16_t v_m = vec_splats((const uint8_t)0x0F); + const int8x16_t v_s = vec_splats( (const int8_t)0x08); for (; ib < nb; ++ib) { - const __vector uint8_t v_x = vec_xl(0, x[ib].qs); - const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m); - const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4); + const uint8x16_t v_x = vec_xl(0, x[ib].qs); + const int8x16_t v_xl = (const int8x16_t)(v_x & v_m); + const int8x16_t v_xh = (const int8x16_t)(v_x >> 4); - const __vector int8_t v_xls = vec_sub(v_xl, v_s); - const __vector int8_t v_xhs = vec_sub(v_xh, v_s); + const int8x16_t v_xls = vec_sub(v_xl, v_s); + const int8x16_t v_xhs = vec_sub(v_xh, v_s); - const __vector int8_t v_yl = vec_xl(0 , y[ib].qs); - const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs); + const int8x16_t v_yl = vec_xl(0 , y[ib].qs); + const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs); - const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl); - const __vector int16_t v_xylse = vec_mule(v_xls, v_yl); - const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh); - const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh); + const int16x8_t v_xylso = vec_mulo(v_xls, v_yl); + const int16x8_t v_xylse = vec_mule(v_xls, v_yl); + const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh); + const int16x8_t v_xyhse = vec_mule(v_xhs, v_yh); - __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); + int16x8_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); - const __vector float v_xy = vec_float(vec_unpackh(v_xy_)); - const __vector float v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); + const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_)); + const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); acc = vec_madd(v_xy, v_d, acc); } - sumf = acc[0] + acc[1] + acc[2] + acc[3]; - + sumf = vec_hsum_f32x4(acc); *s = sumf; #else UNUSED(nb); @@ -249,8 +248,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi acc = vec_madd(v_xy, v_d, acc); } - sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs; - + sumf = vec_hsum_f32x4(acc) + summs; *s = sumf; #else UNUSED(nb); @@ -351,7 +349,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1); } - sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1); + sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1); #pragma GCC unroll 4 for (; ib < nb; ++ib) { @@ -390,7 +388,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f)); - sumf += vec_hsum(v_acc); + sumf += vec_hsum_f32x4(v_acc); } *s = sumf; @@ -502,7 +500,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1); } - sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1) + summs0 + summs1; + sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1) + summs0 + summs1; #pragma GCC unroll 4 for (; ib < nb; ++ib) { @@ -543,7 +541,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc); - sumf += vec_hsum(v_acc) + summs; + sumf += vec_hsum_f32x4(v_acc) + summs; } *s = sumf; @@ -575,7 +573,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi float sumf = 0; #if defined(__VXE__) || defined(__VXE2__) - __vector float acc = vec_splats(0.0f); + float32x4_t acc = vec_splats(0.0f); #pragma GCC unroll 8 for (; ib < nb; ++ib) { @@ -594,7 +592,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi acc = vec_madd(v_xy, v_d, acc); } - sumf = acc[0] + acc[1] + acc[2] + acc[3]; + sumf = vec_hsum_f32x4(acc); *s = sumf; #else @@ -718,10 +716,10 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]); isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]); - isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0]; - isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1]; - isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2]; - isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3]; + isum += vec_hsum_i32x4(isum0) * scale[0]; + isum += vec_hsum_i32x4(isum1) * scale[1]; + isum += vec_hsum_i32x4(isum2) * scale[2]; + isum += vec_hsum_i32x4(isum3) * scale[3]; scale += 4; @@ -819,7 +817,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm); const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); - sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0]; + sumi1 += vec_hsum_i32x4(p1) * scales[2*j+0]; v_y[0] = vec_xl(0 , y0); v_y[1] = vec_xl(16, y0); @@ -829,7 +827,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4); const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); - sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1]; + sumi2 += vec_hsum_i32x4(p2) * scales[2*j+1]; } sumf += d * (sumi1 + sumi2); @@ -911,7 +909,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh); const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh); const int32x4_t v_mins = vec_add(v_minsho, v_minshe); - const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + const int32_t mins = vec_hsum_i32x4(v_mins); const uint8_t * scales = (const uint8_t *)utmp; const uint8_t * GGML_RESTRICT x0l = x[i].qs; @@ -948,8 +946,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]); int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]); - sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++; - sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++; + sumi += vec_hsum_i32x4(sumi0) * *scales++; + sumi += vec_hsum_i32x4(sumi1) * *scales++; } sumf += d * sumi - dmin * mins; @@ -1020,7 +1018,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh); const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe; - const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + const int32_t mins = vec_hsum_i32x4(v_mins); int32_t isum = 0; for (int j = 0; j < QK_K/128; ++j) { @@ -1060,10 +1058,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); - isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + - (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + - (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + - (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + isum += vec_hsum_i32x4(summs0) * scale[0] + + vec_hsum_i32x4(summs1) * scale[1] + + vec_hsum_i32x4(summs2) * scale[2] + + vec_hsum_i32x4(summs3) * scale[3]; scale += 4; @@ -1094,10 +1092,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); - isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + - (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + - (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + - (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + isum += vec_hsum_i32x4(summs0) * scale[0] + + vec_hsum_i32x4(summs1) * scale[1] + + vec_hsum_i32x4(summs2) * scale[2] + + vec_hsum_i32x4(summs3) * scale[3]; scale += 4; } @@ -1285,7 +1283,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs); const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); - sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]); + sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy); } *s = sumf; @@ -1354,8 +1352,8 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v h >>= 4; - sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1; - sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2; + sumi1 += vec_hsum_i32x4(vsumi0) * ls1; + sumi2 += vec_hsum_i32x4(vsumi1) * ls2; } sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index e08c30a348..799e2b1187 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -68,12 +68,6 @@ struct ggml_compute_params { #endif // __VXE2__ #endif // __s390x__ && __VEC__ -#if defined(__s390x__) && defined(GGML_NNPA) -#ifndef __NNPA__ -#define __NNPA__ -#endif // __NNPA__ -#endif // __s390x__ && GGML_NNPA - #if defined(__ARM_FEATURE_SVE) #include #endif @@ -489,11 +483,16 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) { /** * @see https://github.com/ggml-org/llama.cpp/pull/14037 */ -inline static float vec_hsum(float32x4_t v) { +inline static float vec_hsum_f32x4(float32x4_t v) { float32x4_t v_temp = v + vec_reve(v); return v_temp[0] + v_temp[1]; } +inline static int32_t vec_hsum_i32x4(int32x4_t v) { + int32x4_t v_temp = v + vec_reve(v); + return v_temp[0] + v_temp[1]; +} + inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) { const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b); return acc + (vec_unpackh(p) + vec_unpackl(p)); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 0d35d9333e..c131290849 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_I32] = { + .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32, + }, }; const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { @@ -2696,7 +2699,10 @@ struct ggml_cplan ggml_graph_plan( if (ggml_is_quantized(node->type) || // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || - (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { + (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) || + // conversion between F32 and I32 + (node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) || + (node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) { cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; } } break; @@ -3211,21 +3217,6 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) { __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); _mm_storel_epi64((__m128i *)(y + i), y_vec); } -#elif defined(__NNPA__) - for (; i + 7 < n; i += 8) { - float32x4_t v_xh = vec_xl(0, (const float *)(x + i + 0)); - float32x4_t v_xl = vec_xl(0, (const float *)(x + i + 4)); - uint16x8_t v_yd = vec_round_from_fp32(v_xh, v_xl, 0); - uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0); - vec_xst(v_y, 0, (ggml_fp16_t *)(y + i)); - } - for (; i + 3 < n; i += 4) { - float32x4_t v_x = vec_xl(0, (const float *)(x + i)); - float32x4_t v_zero = vec_splats(0.0f); - uint16x8_t v_yd = vec_round_from_fp32(v_x, v_zero, 0); - uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0); - vec_xst(v_y, 0, (ggml_fp16_t *)(y + i)); - } #elif defined(__riscv_zvfh) for (int vl; i < n; i += vl) { vl = __riscv_vsetvl_e32m2(n - i); @@ -3259,21 +3250,6 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) { __m128 y_vec = _mm_cvtph_ps(x_vec); _mm_storeu_ps(y + i, y_vec); } -#elif defined(__NNPA__) - for (; i + 7 < n; i += 8) { - uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i)); - uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0); - float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0); - float32x4_t v_yl = vec_extend_to_fp32_lo(v_yd, 0); - vec_xst(v_yh, 0, (float *)(y + i + 0)); - vec_xst(v_yl, 0, (float *)(y + i + 4)); - } - for (; i + 3 < n; i += 4) { - uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i)); - uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0); - float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0); - vec_xst(v_yh, 0, (float *)(y + i)); - } #endif for (; i < n; ++i) { @@ -3288,6 +3264,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) { } } +void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) { + int64_t i = 0; + for (; i < n; ++i) { + y[i] = x[i]; + } +} + void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) { int64_t i = 0; #if defined(__AVX2__) @@ -3477,14 +3460,6 @@ int ggml_cpu_has_vxe(void) { #endif } -int ggml_cpu_has_nnpa(void) { -#if defined(GGML_NNPA) - return 1; -#else - return 0; -#endif -} - int ggml_cpu_has_neon(void) { #if defined(__ARM_ARCH) && defined(__ARM_NEON) return 1; diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 8dacd36714..2b81f8b9af 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -190,6 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .graph_compute = */ ggml_backend_cpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_cpu_guid(void) { @@ -348,8 +349,10 @@ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * long pages = sysconf(_SC_PHYS_PAGES); long page_size = sysconf(_SC_PAGE_SIZE); *total = pages * page_size; + + // "free" system memory is ill-defined, for practical purposes assume that all of it is free: *free = *total; -#endif +#endif // _WIN32 GGML_UNUSED(dev); } @@ -576,9 +579,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r if (ggml_cpu_has_vxe()) { features.push_back({ "VXE", "1" }); } - if (ggml_cpu_has_nnpa()) { - features.push_back({ "NNPA", "1" }); - } if (ggml_cpu_has_wasm_simd()) { features.push_back({ "WASM_SIMD", "1" }); } diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 7a830448eb..8694ee15d3 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -154,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (dst->src[0]->type == GGML_TYPE_Q4_0) { return compute_forward_q4_0(params, dst); } else if (dst->src[0]->type == GGML_TYPE_F16) { - return compute_forward_kv_cache(params, dst); + return compute_forward_fp16(params, dst); } } else if (dst->op == GGML_OP_GET_ROWS) { if (dst->src[0]->type == GGML_TYPE_Q4_0) { @@ -164,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { return false; } - bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) { + bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) { static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; const ggml_tensor * src0 = dst->src[0]; @@ -515,9 +515,6 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) { - if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) { - return false; - } if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { return false; } @@ -534,13 +531,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } - else if (ggml_kleidiai_select_kernels(ctx.features, op) && - op->src[0]->op == GGML_OP_VIEW && - (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) && - op->src[1]->ne[1] > 1) { - if ((op->src[0]->nb[0] != 2) || - (op->src[1]->nb[0] != 4) || - (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || + else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) { + if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { return nullptr; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 0bb767e01a..212e52ef6a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32( id += ne00 * (ne01 - ir1); } } + } else if (dst->type == GGML_TYPE_I32) { + size_t id = 0; + int32_t * dst_ptr = (int32_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } } else { GGML_ABORT("fatal error"); // TODO: implement } @@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32( } } } + } else if (dst->type == GGML_TYPE_I32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(int32_t *) dst_ptr = *(const float *) src0_ptr; + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } +} + +static void ggml_compute_forward_dup_i32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + // TODO: not optimal, but works + if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = *(const int32_t *) src0_ptr; + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } } else { GGML_ABORT("fatal error"); // TODO: implement } @@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup( { ggml_compute_forward_dup_f32(params, dst); } break; + case GGML_TYPE_I32: + { + ggml_compute_forward_dup_i32(params, dst); + } break; default: { if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) { @@ -8438,6 +8598,7 @@ static void ggml_compute_forward_timestep_embedding_f32( embed_data[j + half] = sinf(arg); } if (dim % 2 != 0 && ith == 0) { + embed_data[2 * half] = 0.f; embed_data[dim] = 0.f; } } diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 8bd56bdac1..a84ba75c20 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -114,26 +114,6 @@ extern "C" { #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) riscv_compute_fp32_to_fp16(x) #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x) #define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x) -#elif defined(__NNPA__) - #define GGML_CPU_COMPUTE_FP16_TO_FP32(x) nnpa_compute_fp16_to_fp32(x) - #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) nnpa_compute_fp32_to_fp16(x) - - #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x) - #define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x) - - static inline float nnpa_compute_fp16_to_fp32(ggml_fp16_t h) { - uint16x8_t v_h = vec_splats(h); - uint16x8_t v_hd = vec_convert_from_fp16(v_h, 0); - return vec_extend_to_fp32_hi(v_hd, 0)[0]; - } - - static inline ggml_fp16_t nnpa_compute_fp32_to_fp16(float f) { - float32x4_t v_f = vec_splats(f); - float32x4_t v_zero = vec_splats(0.0f); - uint16x8_t v_hd = vec_round_from_fp32(v_f, v_zero, 0); - uint16x8_t v_h = vec_convert_to_fp16(v_hd, 0); - return vec_extract(v_h, 0); - } #endif // precomputed f32 table for f16 (256 KB) @@ -1156,11 +1136,6 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { #define GGML_F16_EPR GGML_F32_EPR static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) { -#if defined(__NNPA__) - uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)x); - uint16x8_t v_xd = vec_convert_from_fp16(v_x, 0); - return vec_extend_to_fp32_hi(v_xd, 0); -#else float tmp[4]; for (int i = 0; i < 4; i++) { @@ -1170,20 +1145,9 @@ static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) { // note: keep type-cast here to prevent compiler bugs // see: https://github.com/ggml-org/llama.cpp/issues/12846 return vec_xl(0, (const float *)(tmp)); -#endif } static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { -#if defined(__NNPA__) - float32x4_t v_zero = vec_splats(0.0f); - uint16x8_t v_xd = vec_round_from_fp32(v_y, v_zero, 0); - uint16x8_t v_x = vec_convert_to_fp16(v_xd, 0); - - x[0] = vec_extract(v_x, 0); - x[1] = vec_extract(v_x, 1); - x[2] = vec_extract(v_x, 2); - x[3] = vec_extract(v_x, 3); -#else float arr[4]; // note: keep type-cast here to prevent compiler bugs @@ -1193,7 +1157,6 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { for (int i = 0; i < 4; i++) { x[i] = GGML_CPU_FP32_TO_FP16(arr[i]); } -#endif } #define GGML_F16_VEC GGML_F32x4 diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index d3dfc7807d..0d8c5af473 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmf*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "template-instances/fattn-vec*.cu") diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 1c76566344..725e1a81a1 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) { return a / b; } +template +static __global__ void k_bin_bcast(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const int ne0, + const int ne1, + const int ne2, + const uint3 ne3, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*int s0, */ const int s1, + const int s2, + const int s3, + /*int s00,*/ const int s01, + const int s02, + const int s03, + /*int s10,*/ const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { + const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); + const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); + const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z); - -template -static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - const int ne0, const int ne1, const int ne2, const int ne3, - const int ne10, const int ne11, const int ne12, const int ne13, - /*int s0, */ const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, const int s12, const int s13, - src1_ptrs... src1s) { - const int i0s = blockDim.x*blockIdx.x + threadIdx.x; - const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); - const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3; - const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3; - - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) { return; } - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; + const uint32_t i11 = fastmodulo(i1, ne11); + const uint32_t i12 = fastmodulo(i2, ne12); + const uint32_t i13 = fastmodulo(i3, ne13); const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; @@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { - const int i10 = i0 % ne10; + for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { + const uint32_t i10 = fastmodulo(i0, ne10); float result = src0_row ? (float) src0_row[i0] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { @@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst } } -template -static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - const int ne0, const int ne1, const int ne2,const int ne3, - const int ne10, const int ne11, const int ne12, const int ne13, - /*int s0, */ const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, const int s12, const int s13, - src1_ptrs ... src1s) { +template +static __global__ void k_bin_bcast_unravel(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const uint3 ne0, + const uint3 ne1, + const uint3 ne2, + const uint32_t ne3, + const uint3 prod_012, + const uint3 prod_01, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*int s0, */ const int s1, + const int s2, + const int s3, + /*int s00,*/ const int s01, + const int s02, + const int s03, + /*int s10,*/ const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { const int i = blockDim.x*blockIdx.x + threadIdx.x; - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; + const uint32_t i3 = fastdiv(i, prod_012); + const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01); + const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0); + const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z; - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) { return; } - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; + const int i11 = fastmodulo(i1, ne11); + const int i12 = fastmodulo(i2, ne12); + const int i13 = fastmodulo(i3, ne13); const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; @@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - const int i10 = i0 % ne10; + const int i10 = fastmodulo(i0, ne10); float result = src0_row ? (float) src0_row[i0] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { @@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - size_t nb0 = cnb[0]; size_t nb1 = cnb[1]; size_t nb2 = cnb[2]; @@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * block_dims.y = std::min(ne1, block_size / block_dims.x); block_dims.z = std::min(std::min(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); - dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, - (ne1 + block_dims.y - 1) / block_dims.y, + dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y, (ne2 * ne3 + block_dims.z - 1) / block_dims.z); + const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]); + const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]); + const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]); + const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); + if (block_nums.z > 65535) { - int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); + const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); + const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0); + const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); + const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); + if constexpr (sizeof...(I) > 0) { - k_bin_bcast_unravel - <<>>(src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12,s13, - (const src1_t *) dst->src[I + 1]->data...); + k_bin_bcast_unravel<<>>( + src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, + ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast_unravel - <<>>(src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12,s13); + <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, + ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13); } } else { + const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); if constexpr (sizeof...(I) > 0) { - k_bin_bcast - <<>>(src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12,s13, - (const src1_t *) dst->src[I + 1]->data...); + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { - k_bin_bcast - <<>>(src0_dd, src1_dd, dst_dd, - ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12,s13); + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s00,*/ s01, s02, s03, + /* s10,*/ s11, s12, s13); } } } diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a2dc26eab7..b0feea3623 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -545,6 +545,45 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #endif // defined(GGML_USE_HIP) } +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) { + acc += v*u; +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) { + acc += v.x*u.x; + acc += v.y*u.y; +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { +#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) + asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u)); +#else +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(v*u); + acc += tmp.x + tmp.y; +#else + const float2 tmpv = __half22float2(v); + const float2 tmpu = __half22float2(u); + acc += tmpv.x * tmpu.x; + acc += tmpv.y * tmpu.y; +#endif // FAST_FP16_AVAILABLE +#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA)) +} + +// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. +template +static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { + if constexpr (nbytes == 4) { + *(int *) dst = *(const int *) src; + } else if constexpr (nbytes == 8) { + *(int2 *) dst = *(const int2 *) src; + } else if constexpr (nbytes == 16) { + *(int4 *) dst = *(const int4 *) src; + } else { + static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); + } +} + static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #if CUDART_VERSION >= 12080 const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); @@ -570,6 +609,8 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { // // n/d = (mulhi(n, mp) + n) >> L; static const uint3 init_fastdiv_values(uint32_t d) { + GGML_ASSERT(d != 0); + // compute L = ceil(log2(d)); uint32_t L = 0; while (L < 32 && (uint32_t{ 1 } << L) < d) { diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index c62e8a1b10..ef9e129950 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -38,6 +38,8 @@ template return __float2bfloat16(float(x)); } else if constexpr(std::is_same_v) { return __bfloat162float(x); + } else if constexpr(std::is_same_v) { + return int32_t(x); } else { return float(x); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c40db08ced..8567c3d5a1 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu deleted file mode 100644 index a900799a99..0000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ /dev/null @@ -1,371 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f16.cuh" - -#define FATTN_KQ_STRIDE_TILE_F16 64 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16]; - half2 * KQ2 = (half2 *) KQ; - - __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts. - - half kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -HALF_MAX_HALF; - } - half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}}; - - half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ half2 Q_h2[ncols][D/2]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); - Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) { - // Calculate KQ tile and keep track of new maximum KQ values: - - half kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; - } - } - - __syncthreads(); - - half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) { - half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE]; - half2 Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - half sum; - if (use_logit_softcap) { - const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - sum = logit_softcap * tanhf(tmp.x + tmp.y); - } else { - sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); - - kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps])); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]); - const half2 val = h2exp(diff); - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val; - KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) { - half2 V_k[(D/2)/WARP_SIZE][2]; - half2 KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i]; - V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]); - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]); - } - } - } - - __syncthreads(); - } - - //Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const half sink = __float2half(sinksf[head]); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j)); - kqmax[j0/nwarps] = kqmax_new_j; - - const half val = hexp(sink - kqmax[j0/nwarps]); - kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; - if (threadIdx.x == 0) { - kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val); - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale; - } - } - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); - kqsum_j = warp_reduce_sum((float)kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val /= __half2half2(kqsum_j); - } - dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val); - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) -} - -template -void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - const int32_t precision = KQV->op_params[3]; - GGML_ASSERT(precision == GGML_PREC_DEFAULT); - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cuh b/ggml/src/ggml-cuda/fattn-tile-f16.cuh deleted file mode 100644 index ffc5878427..0000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu deleted file mode 100644 index b96a9ef971..0000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ /dev/null @@ -1,379 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f32.cuh" - -#define FATTN_KQ_STRIDE_TILE_F32 32 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f32( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; - return; - } - - // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32]; - - __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts. - float2 * KV_tmp2 = (float2 *) KV_tmp; - - float kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -FLT_MAX/2.0f; - } - float kqsum[ncols/nwarps] = {0.0f}; - - float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ float Q_f[ncols][D]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { - float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); - Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; - Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) { - // Calculate KQ tile and keep track of new maximum KQ values: - - float kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; - KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); - KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); - } - } - - __syncthreads(); - - float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D; ++k_KQ) { - float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE]; - float Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - if (use_logit_softcap) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - - float kqsum_add = 0.0f; -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps]; - const float val = expf(diff); - kqsum_add += val; - KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val; - } - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - KV_tmp2[k*(D/2) + i].x = __low2float(tmp); - KV_tmp2[k*(D/2) + i].y = __high2float(tmp); - } - } - - __syncthreads(); - -#pragma unroll - for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) { - float2 V_k[(D/2)/WARP_SIZE]; - float KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps]; - VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps]; - } - } - } - - __syncthreads(); - } - - - //Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j); - kqmax[j0/nwarps] = kqmax_new_j; - - const float val = expf(sink - kqmax[j0/nwarps]); - kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; - if (threadIdx.x == 0) { - kqsum[j0/nwarps] += val; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; - } - } - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - float kqsum_j = kqsum[j_VKQ_0/nwarps]; - kqsum_j = warp_reduce_sum(kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val.x /= kqsum_j; - dst_val.y /= kqsum_j; - } - dst2[j_dst_unrolled*(D/2) + i0] = dst_val; - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} - -template -void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cuh b/ggml/src/ggml-cuda/fattn-tile-f32.cuh deleted file mode 100644 index b1c546c805..0000000000 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu new file mode 100644 index 0000000000..c6a399ce5d --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -0,0 +1,660 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-tile.cuh" + +#define FATTN_TILE_NTHREADS 256 + +static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + switch (D) { + case 64: + return 64; + case 128: + case 256: + if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { + return ncols <= 16 ? 64 : 32; + } else { + return 64; + } + default: + GGML_ABORT("fatal error"); + return -1; + } + } + if (fast_fp16_available(cc)) { + switch (D) { + case 64: + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + GGML_ABORT("fatal error"); + return -1; + } + } + switch (D) { + case 64: + return ncols <= 16 ? 128 : 64; + case 128: + return ncols <= 16 ? 64 : 32; + case 256: + return 32; + default: + GGML_ABORT("fatal error"); + return -1; + } + GGML_UNUSED(warp_size); +} + +static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { +#ifdef GGML_USE_HIP + switch (D) { + case 64: + return 64; + case 128: +#if defined(GCN) || defined(CDNA) + return ncols <= 16 ? 64 : 32; +#else + return 64; +#endif // defined(GCN) || defined(CDNA) + case 256: +#if defined(GCN) || defined(CDNA) + return ncols <= 16 ? 64 : 32; +#else + return 64; +#endif // defined(GCN) || defined(CDNA) + default: + return -1; + } +#else +#ifdef FAST_FP16_AVAILABLE + switch (D) { + case 64: + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + return -1; + } +#else + switch (D) { + case 64: + return ncols <= 16 ? 128 : 64; + case 128: + return ncols <= 16 ? 64 : 32; + case 256: + return 32; + default: + return -1; + } +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP + GGML_UNUSED_VARS(ncols, warp_size); +} + +static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) { +#ifdef GGML_USE_HIP + switch (D) { + case 64: + return 64; + case 128: +#if defined(GCN) || defined(CDNA) + return ncols <= 16 ? 64 : 128; +#else + return 64; +#endif // defined(GCN) || defined(CDNA) + case 256: +#if defined(GCN) || defined(CDNA) + return ncols <= 16 ? 64 : 128; +#else + return ncols <= 16 ? 64 : 256; +#endif // defined(GCN) || defined(CDNA) + default: + return -1; + } +#else +#ifdef FAST_FP16_AVAILABLE + switch (D) { + case 64: + return 64; + case 128: + return ncols <= 16 ? 128 : 64; + case 256: + return ncols <= 16 ? 64 : 128; + default: + return -1; + } +#else + switch (D) { + case 64: + return 64; + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + return -1; + } +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP + GGML_UNUSED_VARS(ncols, warp_size); +} + +template // D == head size +#ifdef GGML_USE_HIP +__launch_bounds__(FATTN_TILE_NTHREADS, 1) +#else +__launch_bounds__(FATTN_TILE_NTHREADS, 2) +#endif // GGML_USE_HIP +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE + + if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + constexpr int warp_size = 32; + constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size; + constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); + static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); + constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); + static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch"); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const float * sinksf = (const float *) (sinks); + + const int stride_KV2 = nb11 / sizeof(half2); + + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + +#if defined(GGML_USE_HIP) + constexpr int cpy_nb = 16; +#else + constexpr int cpy_nb = 8; +#endif // defined(GGML_USE_HIP) && defined(GCN) + constexpr int cpy_ne = cpy_nb / 4; + + __shared__ float KQ[ncols][kq_stride]; +#ifdef FAST_FP16_AVAILABLE + __shared__ half2 Q_tmp[ncols][D/2]; + __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts. + half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; +#else + __shared__ float Q_tmp[ncols][D]; + __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts. + float2 * KV_tmp_f2 = (float2 *) KV_tmp_f; + float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; +#endif // FAST_FP16_AVAILABLE + + + float kqmax[ncols/nwarps]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + kqmax[j0/nwarps] = -FLT_MAX/2.0f; + } + float kqsum[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f); +#ifdef FAST_FP16_AVAILABLE + Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale); +#else + Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale; + Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale; +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { + // Calculate KQ tile and keep track of new maximum KQ values: + + float kqmax_new[ncols/nwarps]; +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + kqmax_new[j] = kqmax[j]; + } + + float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}}; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) { + const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x]; +#ifdef FAST_FP16_AVAILABLE + KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2; +#else + const float2 tmp_f2 = __half22float2(tmp_h2); + KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x; + KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y; +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) { + half2 K_k[kq_stride/warp_size][cpy_ne]; + half2 Q_k[ncols/nwarps][cpy_ne]; +#else +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) { + float K_k[kq_stride/warp_size][cpy_ne]; + float Q_k[ncols/nwarps][cpy_ne]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { + const int i_KQ = i_KQ_0 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { + const int j_KQ = j_KQ_0 + threadIdx.y; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]); + } + } + } + } + + if (k_KQ_0 + kq_nbatch < D) { + __syncthreads(); // Sync not needed on last iteration. + } + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { + const int i_KQ = i_KQ_0 + threadIdx.x; + +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { + const int j_KQ = j_KQ_0 + threadIdx.y; + + if (use_logit_softcap) { + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]); + } + + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + + kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]); + + KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); + const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]); + kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; + + float kqsum_add = 0.0f; + if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) { +#pragma unroll + for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) { + const int i = i0 + 4*threadIdx.x; + + float4 val = *(const float4 *) &KQ[j][i]; + val.x = expf(val.x - kqmax[j0/nwarps]); + val.y = expf(val.y - kqmax[j0/nwarps]); + val.z = expf(val.z - kqmax[j0/nwarps]); + val.w = expf(val.w - kqmax[j0/nwarps]); + kqsum_add += val.x + val.y + val.z + val.w; + +#ifdef FAST_FP16_AVAILABLE + const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)}; + ggml_cuda_memcpy_1(&KQ[j][i/2], &tmp); +#else + ggml_cuda_memcpy_1(&KQ[j][i], &val); +#endif // FAST_FP16_AVAILABLE + } + } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) { +#pragma unroll + for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) { + const int i = i0 + 2*threadIdx.x; + + float2 val = *(const float2 *) &KQ[j][i]; + val.x = expf(val.x - kqmax[j0/nwarps]); + val.y = expf(val.y - kqmax[j0/nwarps]); + kqsum_add += val.x + val.y; +#ifdef FAST_FP16_AVAILABLE + const half2 tmp = make_half2(val.x, val.y); + ggml_cuda_memcpy_1(&KQ[j][i/2], &tmp); +#else + ggml_cuda_memcpy_1(&KQ[j][i], &val); +#endif // FAST_FP16_AVAILABLE + } + } else { + for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const float diff = KQ[j][i] - kqmax[j0/nwarps]; + const float val = expf(diff); + kqsum_add += val; +#ifdef FAST_FP16_AVAILABLE + ((half *) KQ[j])[i] = val; +#else + KQ[j][i] = val; +#endif // FAST_FP16_AVAILABLE + } + } + kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale; + VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + + constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; + static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); +#pragma unroll + for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { + const int k_tile = k1 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i]; +#ifdef FAST_FP16_AVAILABLE + KV_tmp_h2[k_tile*(D/2) + i] = tmp; +#else + KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp); +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { +#ifdef FAST_FP16_AVAILABLE + half2 V_k[(D/2)/warp_size]; + half2 KQ_k[ncols/nwarps]; +#else + float2 V_k[(D/2)/warp_size]; + float KQ_k[ncols/nwarps]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i]; +#else + V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i]; +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#ifdef FAST_FP16_AVAILABLE + KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]); +#else + KQ_k[j0/nwarps] = KQ[j][k0 + k1]; +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { +#ifdef FAST_FP16_AVAILABLE + VKQ[j0/nwarps][i0/warp_size] += V_k[i0/warp_size] *KQ_k[j0/nwarps]; +#else + VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps]; + VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps]; +#endif // FAST_FP16_AVAILABLE + } + } + } + + __syncthreads(); + } + } + + + // Attention sink: adjust running max and sum once per head + if (sinksf && blockIdx.y == 0) { + const float sink = sinksf[head]; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j); + kqmax[j0/nwarps] = kqmax_new_j; + + const float val = expf(sink - kqmax[j0/nwarps]); + kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; + if (threadIdx.x == 0) { + kqsum[j0/nwarps] += val; + } + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale; + VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + float2 * dst2 = (float2 *) dst; + +#pragma unroll + for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { + const int j_VKQ = j_VKQ_0 + threadIdx.y; + + if (ic0 + j_VKQ >= ne01) { + return; + } + + float kqsum_j = kqsum[j_VKQ_0/nwarps]; + kqsum_j = warp_reduce_sum(kqsum_j); + + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + +#pragma unroll + for (int i00 = 0; i00 < D/2; i00 += warp_size) { + const int i0 = i00 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]); +#else + float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size]; +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y == 1) { + dst_val.x /= kqsum_j; + dst_val.y /= kqsum_j; + } + dst2[j_dst_unrolled*(D/2) + i0] = dst_val; + } + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template +static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + const int nwarps = FATTN_TILE_NTHREADS / warp_size; + + constexpr size_t nbytes_shared = 0; + + if (Q->ne[1] > 16) { + constexpr int cols_per_block = 32; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); + return; + } + + constexpr int cols_per_block = 16; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); +} + +template +static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + switch (Q->ne[0]) { + case 64: { + launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst); + } break; + case 128: { + launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst); + } break; + case 256: { + launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst); + } break; + default: { + GGML_ABORT("Unsupported head size"); + } break; + } +} + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_head_size(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_head_size(ctx, dst); + } +} diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh new file mode 100644 index 0000000000..10dc22d1bf --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 4883427266..7626d89ca0 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1,8 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" #include "fattn-mma-f16.cuh" -#include "fattn-tile-f16.cuh" -#include "fattn-tile-f32.cuh" +#include "fattn-tile.cuh" #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" @@ -271,8 +270,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg // Best FlashAttention kernel for a specific GPU: enum best_fattn_kernel { BEST_FATTN_KERNEL_NONE = 0, - BEST_FATTN_KERNEL_TILE_F32 = 200, - BEST_FATTN_KERNEL_TILE_F16 = 210, + BEST_FATTN_KERNEL_TILE = 200, BEST_FATTN_KERNEL_VEC_F32 = 100, BEST_FATTN_KERNEL_VEC_F16 = 110, BEST_FATTN_KERNEL_WMMA_F16 = 300, @@ -411,10 +409,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: - if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { - return BEST_FATTN_KERNEL_TILE_F16; - } - return BEST_FATTN_KERNEL_TILE_F32; + return BEST_FATTN_KERNEL_TILE; } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -422,11 +417,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { case BEST_FATTN_KERNEL_NONE: GGML_ABORT("fatal error"); - case BEST_FATTN_KERNEL_TILE_F32: - ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); - break; - case BEST_FATTN_KERNEL_TILE_F16: - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + case BEST_FATTN_KERNEL_TILE: + ggml_cuda_flash_attn_ext_tile(ctx, dst); break; case BEST_FATTN_KERNEL_VEC_F32: ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 83d02474f5..2fab33243d 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -2,39 +2,39 @@ #include "dequantize.cuh" #include "convert.cuh" -#define MAX_GRIDDIM_Y 65535 - template static __global__ void k_get_rows( const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index - const int iybs = i00 - i00%qk; // dst block start index - const int y_offset = qr == 1 ? 1 : qk/2; + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; - // dequantize - float2 v; - dequantize_kernel(src0_row, ib, iqs, v); + // dequantize + float2 v; + dequantize_kernel(src0_row, ib, iqs, v); - dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); - dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); + dst_row[iybs + iqs + 0] = ggml_cuda_cast(v.x); + dst_row[iybs + iqs + y_offset] = ggml_cuda_cast(v.y); + } } } @@ -42,27 +42,29 @@ template static __global__ void k_get_rows_float( const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { - // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. - const int i10 = blockIdx.x; - const int i11 = blockIdx.z / ne12; - const int i12 = blockIdx.z % ne12; + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; - if (i00 >= ne00) { - return; + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = ggml_cuda_cast(src0_row[i00]); } - - const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; - - dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); - - dst_row[i00] = ggml_cuda_cast(src0_row[i00]); } } @@ -98,7 +100,7 @@ static void get_rows_cuda_q( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -116,7 +118,7 @@ static void get_rows_cuda_q( k_get_rows<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); @@ -131,7 +133,7 @@ static void get_rows_cuda_float( cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; - const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); // strides in elements // const size_t s0 = nb0 / sizeof(dst_t); @@ -147,7 +149,7 @@ static void get_rows_cuda_float( k_get_rows_float<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0c01eb6fa8..9ea8f4589d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); return; } + + if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) { + ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst); + return; + } } cudaStream_t stream = ctx.stream(); @@ -3135,6 +3140,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_cuda_guid() { @@ -3204,6 +3210,7 @@ struct ggml_backend_cuda_device_context { int device; std::string name; std::string description; + std::string pci_bus_id; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -3228,9 +3235,12 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend } static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); props->type = ggml_backend_cuda_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; @@ -3461,6 +3471,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) { + return true; + } + if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { return true; } @@ -3574,9 +3590,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: + case GGML_OP_PAD: return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: - case GGML_OP_PAD: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: @@ -3792,6 +3808,10 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; + char pci_bus_id[16] = {}; + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); + dev_ctx->pci_bus_id = pci_bus_id; + ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 667deb9c65..c1f24243fe 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1,3 +1,4 @@ +#pragma once // This file contains primitives that expose the tensor core PTX instructions for CUDA code. // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. // The documentation for the PTX instructions can be found under: diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index cfa5c5cce2..16331e9ecf 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -1,343 +1,12 @@ #include "ggml.h" -#include "common.cuh" -#include "mma.cuh" #include "mmf.cuh" -using namespace ggml_cuda_mma; - -#define MMF_ROWS_PER_BLOCK 32 - -template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; - constexpr int ntA = rows_per_block / tile_A::I; - constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; - - const int row0 = blockIdx.x * rows_per_block; - const int channel_dst = blockIdx.y; - const int channel_x = channel_dst / channel_ratio; - const int channel_y = channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; - - x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; - y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; - dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; - - const float2 * y2 = (const float2 *) y; - - extern __shared__ char data_mmv[]; - - tile_C C[ntA][ntB]; - - T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded); - - for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { - tile_A A[ntA][warp_size / tile_A::J]; -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int i = 0; i < tile_A::I; ++i) { - tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { - load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); - } - } - -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; - } - } else if constexpr (std::is_same_v || std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; - } - } else { - static_assert(std::is_same_v, "unsupported type"); - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { - tile_B B; - load_ldmatrix(B, tile_xy + k0, tile_k_padded); -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { - mma(C[itA][itB], A[itA][k0/tile_B::J], B); - } - } - } - } - - float * buf_iw = (float *) data_mmv; - constexpr int kiw = nwarps*rows_per_block + 4; - - if (nwarps > 1) { - __syncthreads(); - } -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); - const int j = itB*tile_C::J + tile_C::get_j(l); - buf_iw[j*kiw + i] = C[itA][itB].x[l]; - } - } - } - - if (nwarps > 1) { - __syncthreads(); - } - -#pragma unroll - for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > cols_per_block && j >= cols_per_block) { - return; - } - - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); -#pragma unroll - for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; - - sum += buf_iw[j*kiw + i]; - } - dst[j*stride_col_dst + row0 + threadIdx.x] = sum; - } -#else - GGML_UNUSED_VARS(x, y, ids, dst, - ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) -} - -template -static void mul_mat_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - - GGML_ASSERT(!ids && "mul_mat_id not implemented"); - - GGML_ASSERT(ncols_x % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(stride_col_y % 2 == 0); - GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); - GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const int64_t channel_ratio = nchannels_dst / nchannels_x; - const int64_t sample_ratio = nsamples_dst / nsamples_x; - - const int device = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[device].warp_size; - - int64_t nwarps_best = 1; - int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); - int64_t max_block_size = 256; - for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { - const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); - if (niter < niter_best) { - niter_best = niter; - nwarps_best = nwarps; - } - } - - constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; - const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; - const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); - const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst); - const dim3 block_dims(warp_size, nwarps_best, 1); - switch (nwarps_best) { - case 1: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 2: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 3: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 4: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 5: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 6: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 7: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 8: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - -template -static void mul_mat_f_switch_cols_per_block( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - switch (ncols_dst) { - case 1: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 2: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 3: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 4: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 5: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 6: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 7: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 8: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 9: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 10: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 11: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 12: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 13: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 14: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 15: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 16: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_TENSOR_BINARY_OP_LOCALS; const size_t ts_src0 = ggml_type_size(src0->type); @@ -365,55 +34,72 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s3 = dst->nb[3] / ts_dst; + const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; + const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; - const int64_t nchannels_y = ids ? ne11 : ne12; - const int64_t nchannels_dst = ids ? ne1 : ne2; - const int64_t stride_channel_dst = ids ? s1 : s2; - const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t nchannels_dst = ids ? ne1 : ne2; - GGML_ASSERT(!ids || ncols_dst == 1); + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; + + int64_t stride_channel_y = ids ? s11 : s12; + int64_t nchannels_y = ids ? ne11 : ne12; + + //mul_mat_id: handle broadcast + if (ids && nchannels_y == 1) { + stride_channel_y = 0; + nchannels_y = ids->ne[0]; + } switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; constexpr int vals_per_T = 1; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); } } -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) { +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) { + + if (ggml_is_quantized(type)) { + return false; + } + if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) { return false; } if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { return false; } - if (ne11 > 16) { + if (src1_ncols > 16) { return false; } + switch (type) { case GGML_TYPE_F32: return ampere_mma_available(cc); diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index 785f9f211c..bf724bc57b 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -1,5 +1,473 @@ +#pragma once + +#include "mma.cuh" #include "common.cuh" +using namespace ggml_cuda_mma; + +#define MMF_ROWS_PER_BLOCK 32 + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11); +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols); + +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int stride_col_id, const int stride_row_id, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + const int expert_idx = has_ids ? blockIdx.y : 0; + const int channel_dst = has_ids ? 0 : blockIdx.y; + + const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio); + const int channel_y = channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; + y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y); + dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst); + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + + char * shmem_base = data_mmv; + int * slot_map = (int *) shmem_base; + char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + if constexpr (has_ids) { + __shared__ int has_any; + if (threadIdx.y == 0) { + int local_has_any = 0; + for (int j = threadIdx.x; j < cols_per_block; j += warp_size) { + int slot = -1; + for (int k = 0; k < nchannels_dst; ++k) { + const int idv = ids[j*stride_row_id + k*stride_col_id]; + if (idv == expert_idx) { + slot = k; + break; + } + } + if (j < cols_per_block) { + local_has_any |= (slot >= 0); + slot_map[j] = slot; + } + } + has_any = warp_reduce_any(local_has_any); + } + __syncthreads(); + if (has_any == 0) { + return; + } + } + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; + } else { + float val = 0.0f; + if (j < cols_per_block) { + const int slot = slot_map[j]; + if (slot >= 0) { + val = y[slot*stride_channel_y + j*stride_col_y + col]; + } + } + tile_xy[j0*tile_k_padded + threadIdx.x] = val; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } else { + float2 tmp = make_float2(0.0f, 0.0f); + if (j < cols_per_block) { + const int slot = slot_map[j]; + if (slot >= 0) { + const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y); + tmp = y2_slot[j*stride_col_y + col]; + } + } + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + + if constexpr (!has_ids) { + dst[j*stride_col_dst + row0 + threadIdx.x] = sum; + } else { + const int slot = (j < cols_per_block) ? slot_map[j] : -1; + if (slot >= 0) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids, dst, + ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + +template +static inline void mul_mat_f_switch_ids( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nchannels_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { + if (ids) { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } else { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } +} + +template +void mul_mat_f_cuda( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + + GGML_ASSERT(ncols_x % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t nwarps_best = 1; + int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); + int64_t max_block_size = 256; + for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { + const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); + if (niter < niter_best) { + niter_best = niter; + nwarps_best = nwarps; + } + } + + constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; + const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); + const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; + const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; + const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present + + const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); + const dim3 block_dims(warp_size, nwarps_best, 1); + + switch (nwarps_best) { + case 1: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 2: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 3: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 4: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 5: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 6: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 7: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 8: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } + + GGML_UNUSED_VARS(nchannels_y); +} + +template +static void mul_mat_f_switch_cols_per_block( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + switch (ncols_dst) { + case 1: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 2: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 3: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 4: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 5: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 6: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 7: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 8: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 9: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 10: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 11: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 12: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 13: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 14: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 15: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 16: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ + template void mul_mat_f_cuda( \ + const T * x, const float * y, const int32_t * ids, float * dst, \ + const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ + const int64_t stride_col_id, const int64_t stride_row_id, \ + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ + cudaStream_t stream); + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#define DECL_MMF_CASE_EXTERN(ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +#define DECL_MMF_CASE(ncols_dst) \ + DECL_MMF_CASE_HELPER(float, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +DECL_MMF_CASE_EXTERN(1); +DECL_MMF_CASE_EXTERN(2); +DECL_MMF_CASE_EXTERN(3); +DECL_MMF_CASE_EXTERN(4); +DECL_MMF_CASE_EXTERN(5); +DECL_MMF_CASE_EXTERN(6); +DECL_MMF_CASE_EXTERN(7); +DECL_MMF_CASE_EXTERN(8); +DECL_MMF_CASE_EXTERN(9); +DECL_MMF_CASE_EXTERN(10); +DECL_MMF_CASE_EXTERN(11); +DECL_MMF_CASE_EXTERN(12); +DECL_MMF_CASE_EXTERN(13); +DECL_MMF_CASE_EXTERN(14); +DECL_MMF_CASE_EXTERN(15); +DECL_MMF_CASE_EXTERN(16); +#else +#define DECL_MMF_CASE(ncols_dst) +#endif diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index b7c3079308..52de4e78d1 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -141,9 +141,10 @@ template __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; @@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q( constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. - const int channel_dst = blockIdx.y; - const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio; - const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; + const uint32_t channel_dst = blockIdx.y; + const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + const uint32_t sample_dst = blockIdx.z; + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); + const uint32_t sample_y = sample_dst; // partial sum for each thread float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; @@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst( GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); - const int channel_ratio = nchannels_dst / nchannels_x; - const int sample_ratio = nsamples_dst / nsamples_x; + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); const int warp_size = ggml_cuda_info().devices[device].warp_size; @@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst( GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { - case 1: - { + case 1: { constexpr int c_ncols_dst = 1; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 2: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 2: { constexpr int c_ncols_dst = 2; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 3: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 3: { constexpr int c_ncols_dst = 3; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 4: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 4: { constexpr int c_ncols_dst = 4; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 5: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 5: { constexpr int c_ncols_dst = 5; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 6: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 6: { constexpr int c_ncols_dst = 6; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 7: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 7: { constexpr int c_ncols_dst = 7; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } - case 8: - { + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 8: { constexpr int c_ncols_dst = 8; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - break; - } + (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; default: GGML_ABORT("fatal error"); break; diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index a0b03a740d..5117f9ffc0 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -1,26 +1,27 @@ #include "quantize.cuh" #include +__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) static __global__ void quantize_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t ne0, const int ne1, const int ne2) { + const int64_t ne0, const uint32_t ne1, const uint3 ne2) { const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { return; } + const int64_t i3 = fastdiv(blockIdx.z, ne2); + const int64_t i2 = blockIdx.z - i3*ne2.z; const int64_t i1 = blockIdx.y; - const int64_t i2 = blockIdx.z % ne2; - const int64_t i3 = blockIdx.z / ne2; const int64_t & i00 = i0; const int64_t & i01 = i1; const int64_t & i02 = i2; const int64_t & i03 = i3; - const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0; + const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0; block_q8_1 * y = (block_q8_1 *) vy; @@ -31,10 +32,10 @@ static __global__ void quantize_q8_1( float amax = fabsf(xi); float sum = xi; - amax = warp_reduce_max(amax); - sum = warp_reduce_sum(sum); + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); - const float d = amax / 127; + const float d = amax / 127.0f; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); y[ib].qs[iqs] = q; @@ -43,8 +44,7 @@ static __global__ void quantize_q8_1( return; } - reinterpret_cast(y[ib].ds.x) = d; - reinterpret_cast(y[ib].ds.y) = sum; + y[ib].ds = make_half2(d, sum); } template @@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda( GGML_ASSERT(!ids); GGML_ASSERT(ne0 % QK8_1 == 0); + const uint3 ne2_fastdiv = init_fastdiv_values(ne2); + const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); GGML_UNUSED(type_src0); } diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 3428113dc8..da2d7b7c3b 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -24,7 +24,7 @@ TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", - "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. @@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do DECL_MMQ_CASE({type}); """ +SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE({type}); +""" + def get_short_name(long_quant_name): return long_quant_name.replace("GGML_TYPE_", "").lower() @@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]: for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: f.write(SOURCE_MMQ.format(type=type)) + +for type in range(1, 17): + with open(f"mmf-instance-ncols_{type}.cu", "w") as f: + f.write(SOURCE_MMF.format(type=type)) diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu new file mode 100644 index 0000000000..f594d5d51d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(1); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu new file mode 100644 index 0000000000..9cc6772542 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(10); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu new file mode 100644 index 0000000000..317f487d7a --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(11); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu new file mode 100644 index 0000000000..dc0033227c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(12); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu new file mode 100644 index 0000000000..0782101753 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(13); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu new file mode 100644 index 0000000000..a23ad6ae26 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(14); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu new file mode 100644 index 0000000000..0fe3f7821e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(15); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu new file mode 100644 index 0000000000..544086375e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(16); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu new file mode 100644 index 0000000000..3b901797cf --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(2); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu new file mode 100644 index 0000000000..56e940bba0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(3); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu new file mode 100644 index 0000000000..a7665d49d0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(4); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu new file mode 100644 index 0000000000..3a1dff2587 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(5); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu new file mode 100644 index 0000000000..400fb7c663 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(6); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu new file mode 100644 index 0000000000..954a1c7e03 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(7); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu new file mode 100644 index 0000000000..f1bd09c945 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(8); diff --git a/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu new file mode 100644 index 0000000000..1255ac2af6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(9); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index c6a33d5de3..12bbee4556 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -162,6 +162,14 @@ #define GCN #endif +#if defined(__gfx900__) || defined(__gfx906__) +#define GCN5 +#endif + +#if defined(__gfx803__) +#define GCN4 +#endif + #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) #define CDNA // For the entire family #endif diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index b9d3639448..651943fa92 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -20,8 +20,8 @@ #define N_R0_Q5_1 4 #define N_SG_Q5_1 2 -#define N_R0_Q8_0 4 -#define N_SG_Q8_0 2 +#define N_R0_Q8_0 2 +#define N_SG_Q8_0 4 #define N_R0_MXFP4 2 #define N_SG_MXFP4 2 @@ -68,6 +68,11 @@ #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 +// function constants offsets +#define FC_FLASH_ATTN_EXT 100 +#define FC_FLASH_ATTN_EXT_VEC 200 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -236,9 +241,11 @@ typedef struct { int32_t ne11; int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_3; + int32_t ns10; uint64_t nb11; uint64_t nb12; uint64_t nb13; + int32_t ns20; uint64_t nb21; uint64_t nb22; uint64_t nb23; @@ -258,10 +265,43 @@ typedef struct { float logit_softcap; } ggml_metal_kargs_flash_attn_ext; +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + int32_t ns10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ns20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; + int32_t ne1; + int32_t ne2; + int32_t ne3; + float scale; + float max_bias; + float m0; + float m1; + int32_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext_vec; + typedef struct { int32_t nrows; - int32_t ne20; -} ggml_metal_kargs_flash_attn_ext_reduce; +} ggml_metal_kargs_flash_attn_ext_vec_reduce; typedef struct { int32_t ne00; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c1a0a2bef1..07b96dbdd5 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -48,6 +48,11 @@ static struct ggml_backend_metal_device_context { int mtl_device_ref_count; id mtl_library; + // a single global queue shared by all Metal backends + // technically not needed for devices with unified memory, but enables discrete GPUs support + // ref: https://github.com/ggml-org/llama.cpp/pull/15906 + id mtl_queue; + NSLock * mtl_lock; bool has_simdgroup_reduction; @@ -56,6 +61,7 @@ static struct ggml_backend_metal_device_context { bool has_bfloat; bool use_bfloat; bool use_fusion; + bool use_shared_buffers; int debug_fusion; @@ -69,6 +75,7 @@ static struct ggml_backend_metal_device_context { /*.mtl_device =*/ nil, /*.mtl_device_ref_count =*/ 0, /*.mtl_library =*/ nil, + /*.mtl_queue =*/ nil, /*.mtl_lock =*/ nil, /*.has_simdgroup_reduction =*/ false, /*.has_simdgroup_mm =*/ false, @@ -76,6 +83,7 @@ static struct ggml_backend_metal_device_context { /*.has_bfloat =*/ false, /*.use_bfloat =*/ false, /*.use_fusion =*/ true, + /*.use_shared_buffers =*/ true, /*.debug_fusion =*/ 0, /*.fuse_cnt =*/ { 0 }, /*.max_size =*/ 0, @@ -94,6 +102,11 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev ctx->mtl_device = MTLCreateSystemDefaultDevice(); if (ctx->mtl_device) { + ctx->mtl_queue = [ctx->mtl_device newCommandQueue]; + if (ctx->mtl_queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + } + ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -118,6 +131,12 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev ctx->debug_fusion = val ? atoi(val) : 0; } + ctx->use_shared_buffers = ctx->mtl_device.hasUnifiedMemory; + + if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) { + ctx->use_shared_buffers = false; + } + memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt)); ctx->max_size = ctx->mtl_device.maxBufferLength; @@ -161,6 +180,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte ctx->mtl_library = nil; } + if (ctx->mtl_queue) { + [ctx->mtl_queue release]; + ctx->mtl_queue = nil; + } + if (ctx->mtl_device) { [ctx->mtl_device release]; ctx->mtl_device = nil; @@ -174,6 +198,19 @@ struct ggml_metal_kernel { id pipeline; }; +@interface ggml_metal_kernel_wrapper : NSObject + +@property (nonatomic, assign) struct ggml_metal_kernel kernel; + +@end + +@implementation ggml_metal_kernel_wrapper +- (void) dealloc { + [_kernel.pipeline release]; + [super dealloc]; +} +@end + enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ADD, GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, @@ -454,128 +491,6 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, - GGML_METAL_KERNEL_TYPE_SET_I32, - GGML_METAL_KERNEL_TYPE_SET_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, @@ -583,6 +498,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_CPY_F16_F32, GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, + GGML_METAL_KERNEL_TYPE_CPY_F32_I32, + GGML_METAL_KERNEL_TYPE_CPY_I32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, @@ -878,12 +795,16 @@ struct ggml_metal_command_buffer { struct ggml_backend_metal_context { id device; - id queue; + id queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND] dispatch_queue_t d_queue; + // the set of pre-compiled kernels for this context struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; + // additional, inference-time compiled kernels + NSMutableDictionary * kernels_ext; + // capture state bool capture_next_compute; bool capture_started; @@ -904,6 +825,12 @@ struct ggml_backend_metal_context { // n_cb command buffers + 1 used by the main thread struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + // extra command buffers for things like getting, setting and copying tensors + NSMutableArray * cmd_bufs_ext; + + // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend + id cmd_buf_last; + // abort ggml_metal_graph_compute if callback returns true ggml_abort_callback abort_callback; void * abort_callback_data; @@ -949,6 +876,8 @@ static void * ggml_metal_host_malloc(size_t n) { // - if not found, load the source and compile it // - if that fails, return NULL static id ggml_metal_load_library(id device, bool use_bfloat) { + const int64_t t_start = ggml_time_us(); + id metal_library = nil; NSError * error = nil; NSString * src = nil; @@ -1072,6 +1001,8 @@ static id ggml_metal_load_library(id device, bool use_bfl [src release]; #endif // GGML_METAL_EMBED_LIBRARY + GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6); + return metal_library; } @@ -1096,7 +1027,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); ctx->device = device; - ctx->queue = [device newCommandQueue]; + + // TODO: question - would it be better to have one queue for the backend and one queue for the device? + // the graph encoders and async ops would use the backend queue while the sync ops would use the device queue? + //ctx->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND] + ctx->queue = ctx_dev->mtl_queue; if (ctx->queue == nil) { GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); return NULL; @@ -1155,6 +1090,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false"); GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: use fusion = %s\n", __func__, ctx_dev->use_fusion ? "true" : "false"); + GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, ctx_dev->use_shared_buffers ? "true" : "false"); GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); ctx->capture_next_compute = false; @@ -1170,6 +1107,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ctx->cmd_bufs[i].mem_pool->device = device; } + ctx->cmd_bufs_ext = [[NSMutableArray alloc] init]; + + ctx->cmd_buf_last = nil; + #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6); @@ -1269,7 +1210,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); @@ -1487,128 +1428,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); @@ -1616,6 +1435,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_I32, cpy_f32_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_I32_F32, cpy_i32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); @@ -1651,9 +1472,219 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } + ctx->kernels_ext = [[NSMutableDictionary alloc] init]; + return ctx; } +static id ggml_metal_get_kernel(struct ggml_backend_metal_context * ctx, const char * name) { + NSString * key = [NSString stringWithUTF8String:name]; + + ggml_metal_kernel_wrapper * obj = [ctx->kernels_ext objectForKey:key]; + if (obj) { + return obj.kernel.pipeline; + } + + return nil; +} + +static id ggml_metal_compile_kernel(ggml_backend_t backend, const char * base, const char * name, MTLFunctionConstantValues * cv) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + id res = nil; + + @autoreleasepool { + NSError * error = nil; + + NSString * base_func = [NSString stringWithUTF8String:base]; + + GGML_LOG_DEBUG("%s: compiling kernel: base = '%s', name = '%s'\n", __func__, base, name); + + // TODO: make sure it is thread-safe to compile kernels in parallel + id metal_function = [ctx_dev->mtl_library newFunctionWithName:base_func constantValues:cv error:&error]; + if (!metal_function) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + + return nil; + } + + struct ggml_metal_kernel kernel = { + /*.pipeline =*/ [ctx_dev->mtl_device newComputePipelineStateWithFunction:metal_function error:&error], + }; + + ggml_metal_kernel_wrapper * obj = [[ggml_metal_kernel_wrapper alloc] init]; + obj.kernel = kernel; + + res = obj.kernel.pipeline; + + NSString * key = [NSString stringWithUTF8String:name]; + [ctx->kernels_ext setObject:obj forKey:key]; + + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) kernel.pipeline, + (int) kernel.pipeline.maxTotalThreadsPerThreadgroup, + (int) kernel.pipeline.threadExecutionWidth); + } + + return res; +} + +static id ggml_metal_get_pipeline_flash_attn_ext( + ggml_backend_t backend, struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d", + "flash_attn_ext", + ggml_type_name(op->src[1]->type), + dk, + dv, + has_mask, + has_sinks, + has_bias, + has_scap, + ns10, + ns20, + nsg); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + cv = [[MTLFunctionConstantValues alloc] init]; + + [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 0]; + [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 1]; + [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 2]; + [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 3]; + + [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 20]; + [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 21]; + [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 22]; + + return ggml_metal_compile_kernel(backend, base, name, cv); + } +} + +static id ggml_metal_get_pipeline_flash_attn_ext_vec( + ggml_backend_t backend, struct ggml_tensor * op, + bool has_mask, + bool has_sinks, + bool has_bias, + bool has_scap, + int32_t nsg, + int32_t nwg) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init]; + + const int32_t dk = (int32_t) op->src[1]->ne[0]; + const int32_t dv = (int32_t) op->src[2]->ne[0]; + + const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; + const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + "flash_attn_ext_vec", + ggml_type_name(op->src[1]->type), + dk, + dv); + + snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", + "flash_attn_ext_vec", + ggml_type_name(op->src[1]->type), + dk, + dv, + has_mask, + has_sinks, + has_bias, + has_scap, + ns10, + ns20, + nsg, nwg); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + cv = [[MTLFunctionConstantValues alloc] init]; + + [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 0]; + [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 1]; + [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 2]; + [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 3]; + + [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 20]; + [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 21]; + [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22]; + [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 23]; + + return ggml_metal_compile_kernel(backend, base, name, cv); + } +} + +static id ggml_metal_get_pipeline_flash_attn_ext_vec_reduce( + ggml_backend_t backend, struct ggml_tensor * op, + int32_t dv, + int32_t nwg) { + struct ggml_backend_metal_context * ctx = backend->context; + + char base[256]; + char name[256]; + + @autoreleasepool { + MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init]; + + snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); + snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg); + + id res = ggml_metal_get_kernel(ctx, name); + if (res) { + // kernel found + return res; + } + + cv = [[MTLFunctionConstantValues alloc] init]; + + [cv setConstantValue:&dv type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 0]; + [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 1]; + + return ggml_metal_compile_kernel(backend, base, name, cv); + } + + GGML_UNUSED(op); +} + static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { GGML_LOG_INFO("%s: deallocating\n", __func__); @@ -1661,16 +1692,26 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { [ctx->kernels[i].pipeline release]; } + if (ctx->kernels_ext) { + [ctx->kernels_ext release]; + ctx->kernels_ext = nil; + } + Block_release(ctx->encode_async); - [ctx->queue release]; + //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND] for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { - // ctx->cmd_bufs[i].obj is auto released + if (ctx->cmd_bufs[i].obj) { + [ctx->cmd_bufs[i].obj release]; + } ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool); } + [ctx->cmd_bufs_ext removeAllObjects]; + [ctx->cmd_bufs_ext release]; + dispatch_release(ctx->d_queue); free(ctx); @@ -1688,14 +1729,21 @@ struct ggml_backend_metal_buffer { struct ggml_backend_metal_buffer_context { void * all_data; size_t all_size; - bool owned; + + // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host + bool is_shared; // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap int n_buffers; struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; // optional MTLResidencySet + // note: cannot use explicity "id" here because it is not available on certain OSes id rset; + + // pointers to global device objects + id device; + id queue; }; // rset init @@ -1761,7 +1809,7 @@ static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the // Metal buffer based on the host memory pointer // -static id ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) { +static id ggml_metal_get_buffer(const struct ggml_tensor * t, size_t * offs) { //GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); const int64_t tsize = ggml_nbytes(t); @@ -1945,6 +1993,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_I32: return true; default: return false; @@ -1977,16 +2026,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex default: return false; } - default: - return false; - }; - } - case GGML_OP_SET: - { - switch (op->src[0]->type) { - case GGML_TYPE_F32: case GGML_TYPE_I32: - return true; + return op->type == GGML_TYPE_F32; default: return false; }; @@ -3765,6 +3806,7 @@ static int ggml_metal_encode_node( { nsg = N_SG_Q8_0; nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; case GGML_TYPE_MXFP4: @@ -3901,7 +3943,12 @@ static int ggml_metal_encode_node( if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + + if (src0t == GGML_TYPE_Q8_0) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } } break; case GGML_OP_MUL_MAT_ID: @@ -4122,6 +4169,7 @@ static int ggml_metal_encode_node( { nsg = N_SG_Q8_0; nr0 = N_R0_Q8_0; + smem = 32*sizeof(float)*N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; case GGML_TYPE_MXFP4: @@ -4267,7 +4315,12 @@ static int ggml_metal_encode_node( if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; } - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + + if (src0t == GGML_TYPE_Q8_0) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } } } break; case GGML_OP_GET_ROWS: @@ -5118,6 +5171,7 @@ static int ggml_metal_encode_node( float scale; float max_bias; float logit_softcap; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); @@ -5126,398 +5180,24 @@ static int ggml_metal_encode_node( scale /= logit_softcap; } + const bool has_mask = src3 != NULL; + const bool has_sinks = src4 != NULL; + const bool has_bias = max_bias != 0.0f; + const bool has_scap = logit_softcap != 0.0f; + const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - id pipeline = nil; - - bool use_vec_kernel = false; + GGML_ASSERT(ne01 < 65536); // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size - if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) { - switch (src1->type) { - case GGML_TYPE_F16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_BF16: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q4_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q5_1: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - case GGML_TYPE_Q8_0: - { - if (ne00 == 192 && ne20 == 128) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; - } else if (ne00 == 576 && ne20 == 512) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; - } else { - switch (ne00) { - case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break; - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; - case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - } break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 64: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 96: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 128: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 192: - { - if (ne20 == 128) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } - } break; - case 256: - { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } break; - case 576: - { - if (ne20 == 512) { - switch (src1->type) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break; - default: - { - GGML_LOG_ERROR("unsupported type: %d\n", src1->type); - GGML_LOG_ERROR("add template specialization for this type\n"); - GGML_ABORT("add template specialization for this type"); - } - } - } else { - GGML_LOG_ERROR("unsupported size: %lld\n", ne20); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } break; - default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - ggml_metal_kargs_flash_attn_ext args = { - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne11 =*/ ne11, - /*.ne_12_2 =*/ ne12, - /*.ne_12_3 =*/ ne13, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb23 =*/ nb23, - /*.ne32 =*/ ne32, - /*.ne33 =*/ ne33, - /*.nb31 =*/ nb31, - /*.nb32 =*/ nb32, - /*.nb33 =*/ nb33, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.scale =*/ scale, - /*.max_bias =*/ max_bias, - /*.m0 =*/ m0, - /*.m1 =*/ m1, - /*.n_head_log2 =*/ n_head_log2, - /*.logit_softcap =*/ logit_softcap, - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; - } - if (id_src4) { - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; - } - - if (!use_vec_kernel) { + if (ne01 >= 20 || (ne00 % 32 != 0)) { // half8x8 kernel const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 8 == 0); @@ -5525,34 +5205,90 @@ static int ggml_metal_encode_node( const int is_q = ggml_is_quantized(src1->type) ? 1 : 0; - // 2*(2*ncpsg + nqptg)*(nsg) - // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) + // 2*(2*ncpsg) + // ncpsg soft_max values + ncpsg mask values // // 16*32*(nsg) // the shared memory needed for the simdgroups to load the KV cache // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) - int64_t nsgmax = 2; - - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - if (smem > device.maxThreadgroupMemoryLength/2) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; + //int64_t nsgmax = 4; + // + //if (is_q) { + // nsgmax = 2; + // while (true) { + // const size_t smem = FATTN_SMEM(nsgmax); + // if (smem > device.maxThreadgroupMemoryLength/2) { + // break; + // } + // nsgmax *= 2; + // } + // nsgmax /= 2; + //} // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + int32_t nsg = 4; const size_t smem = FATTN_SMEM(nsg); + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ nb11/nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ nb21/nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + id pipeline = ggml_metal_get_pipeline_flash_attn_ext(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; + } + if (id_src4) { + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); + //printf("smem: %zu, max: %zu, nsg = %d, ne02 = %d, ne12 = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, ne02, ne12); GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; @@ -5561,7 +5297,7 @@ static int ggml_metal_encode_node( // half4x4 kernel const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable + const int64_t nkpsg = 1*ncpsg; GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); @@ -5574,8 +5310,7 @@ static int ggml_metal_encode_node( // ne20*(nsg) // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16)) -//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; while (true) { @@ -5589,7 +5324,8 @@ static int ggml_metal_encode_node( nsgmax /= 2; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32))); int64_t nsg = 1; while (nsg <= nsgt) { @@ -5599,28 +5335,86 @@ static int ggml_metal_encode_node( // workgroups // each workgroup handles nsg*nkpsg cache values - uint16_t nwg = 1; - if (4*nsg*nkpsg >= ne11) { - const size_t smem = FATTN_SMEM(nsg); + int32_t nwg = 1; + if (false) { + // for small KV caches, we could launch a single workgroup and write the results directly to dst/ + // however, this does not lead to significant improvement, so disabled + nwg = 1; + nsg = 4; + } else { + nwg = 32; + nsg = 1; + while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) { + nsg *= 2; + } + } - //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + ggml_metal_kargs_flash_attn_ext_vec args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.ns10 =*/ nb11/nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ns20 =*/ nb21/nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + id pipeline = ggml_metal_get_pipeline_flash_attn_ext_vec(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg, nwg); + + GGML_ASSERT(nsg*32 <= (int) pipeline.maxTotalThreadsPerThreadgroup); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; + } + if (id_src4) { + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; + } + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax); + GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + + if (nwg == 1) { // using 1 workgroup -> write the result directly into dst - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } else { - nwg = 32; - nsg = MIN(4, nsg); - - const size_t smem = FATTN_SMEM(nsg); - - //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax); - GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); - // sanity checks GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31)); @@ -5640,20 +5434,18 @@ static int ggml_metal_encode_node( //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20); //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f); - [encoder setBuffer:h_tmp offset:0 atIndex:6]; - [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7]; + [encoder setBuffer:h_tmp offset:0 atIndex:6]; [encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; // reduce the results from the workgroups { - ggml_metal_kargs_flash_attn_ext_reduce args0 = { + ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { nrows, - ne20, }; - id pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline; + id pipeline0 = ggml_metal_get_pipeline_flash_attn_ext_vec_reduce(backend, node, ne20, nwg); [encoder setComputePipelineState:pipeline0]; [encoder setBytes:&args0 length:sizeof(args0) atIndex:0]; @@ -5661,7 +5453,7 @@ static int ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20); - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*nwg, 1, 1)]; } } #undef FATTN_SMEM @@ -5680,6 +5472,7 @@ static int ggml_metal_encode_node( switch (dstt) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_I32].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; @@ -5691,6 +5484,13 @@ static int ggml_metal_encode_node( default: GGML_ABORT("not implemented"); }; } break; + case GGML_TYPE_I32: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_I32_F32].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; case GGML_TYPE_F16: { switch (dstt) { @@ -5807,68 +5607,6 @@ static int ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)]; } break; - case GGML_OP_SET: - { - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - - // src0 and dst as viewed during set - const size_t dst_nb0 = ggml_element_size(src0); - - const size_t dst_nb1 = ((int32_t *) dst->op_params)[0]; - const size_t dst_nb2 = ((int32_t *) dst->op_params)[1]; - const size_t dst_nb3 = ((int32_t *) dst->op_params)[2]; - const size_t offset = ((int32_t *) dst->op_params)[3]; - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst)); - } - - const int im0 = (ne10 == 0 ? 0 : ne10-1); - const int im1 = (ne11 == 0 ? 0 : ne11-1); - const int im2 = (ne12 == 0 ? 0 : ne12-1); - const int im3 = (ne13 == 0 ? 0 : ne13-1); - - GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst)); - - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - GGML_ASSERT(nb10 == sizeof(float)); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break; - case GGML_TYPE_I32: - GGML_ASSERT(nb10 == sizeof(int32_t)); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - ggml_metal_kargs_set args = { - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb1 =*/ dst_nb1, - /*.nb2 =*/ dst_nb2, - /*.nb3 =*/ dst_nb3, - /*.offs =*/ offset, - /*.inplace =*/ inplace, - }; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10); - - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; case GGML_OP_POOL_2D: { GGML_ASSERT(ggml_is_contiguous(src0)); @@ -5997,6 +5735,12 @@ static enum ggml_status ggml_metal_graph_compute( if (should_capture) { ctx->capture_next_compute = false; + // make sure all previous computations have finished before starting the capture + if (ctx->cmd_buf_last) { + [ctx->cmd_buf_last waitUntilCompleted]; + ctx->cmd_buf_last = nil; + } + if (!ctx->capture_started) { // create capture scope ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device]; @@ -6019,78 +5763,103 @@ static enum ggml_status ggml_metal_graph_compute( // the main thread commits the first few commands immediately // cmd_buf[n_cb] { - id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed + // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009 + //id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id cmd_buf = [ctx->queue commandBuffer]; + [cmd_buf retain]; + ctx->cmd_bufs[n_cb].obj = cmd_buf; [cmd_buf enqueue]; + ctx->encode_async(n_cb); } - // prepare the rest of the command buffers asynchronously + // remember the command buffer for the next iteration + ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj; + + // prepare the rest of the command buffers asynchronously (optional) // cmd_buf[0.. n_cb) for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + //id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id cmd_buf = [ctx->queue commandBuffer]; + [cmd_buf retain]; + + if (ctx->cmd_bufs[cb_idx].obj) { + [ctx->cmd_bufs[cb_idx].obj release]; + } ctx->cmd_bufs[cb_idx].obj = cmd_buf; // always enqueue the first two command buffers // enqueue all of the command buffers if we don't need to abort if (cb_idx < 2 || ctx->abort_callback == NULL) { [cmd_buf enqueue]; + + // update the pointer to the last queued command buffer + // this is needed to implement synchronize() + ctx->cmd_buf_last = cmd_buf; } } dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); - // wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) - { - id cmd_buf = ctx->cmd_bufs[n_cb].obj; - [cmd_buf waitUntilCompleted]; - - MTLCommandBufferStatus status = [cmd_buf status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); - if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); - } - - return GGML_STATUS_FAILED; - } - } - - for (int i = 0; i < n_cb; ++i) { - id cmd_buf = ctx->cmd_bufs[i].obj; - [cmd_buf waitUntilCompleted]; - - MTLCommandBufferStatus status = [cmd_buf status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); - } - - return GGML_STATUS_FAILED; - } - - id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); - if (!next_buffer) { - continue; - } - - const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; - } - - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return GGML_STATUS_ABORTED; - } - - [next_buffer commit]; - } + // for debugging: block until graph is computed + //[ctx->cmd_buf_last waitUntilCompleted]; + // enter here only when capturing in order to wait for all computation to finish + // otherwise, we leave the graph to compute asynchronously if (!should_capture && ctx->capture_started) { + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id cmd_buf = ctx->cmd_bufs[n_cb].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + } + + for (int i = 0; i < n_cb; ++i) { + id cmd_buf = ctx->cmd_bufs[i].obj; + [cmd_buf waitUntilCompleted]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + + id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; + } + [ctx->capture_scope endScope]; [[MTLCaptureManager sharedCaptureManager] stopCapture]; } @@ -6100,10 +5869,12 @@ static enum ggml_status ggml_metal_graph_compute( } //////////////////////////////////////////////////////////////////////////////// - // backend interface +//////////////////////////////////////////////////////////////////////////////// -static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { +// shared buffer + +static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) { struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; for (int i = 0; i < ctx->n_buffers; i++) { @@ -6112,7 +5883,9 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) ggml_backend_metal_buffer_rset_free(ctx); - if (ctx->owned) { + GGML_ASSERT(ctx->is_shared); + + { #if TARGET_OS_OSX vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size); #else @@ -6123,66 +5896,254 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) free(ctx); } -static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { +static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) { struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; return ctx->all_data; } -static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - memset((char *)tensor->data + offset, value, size); - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - memcpy((char *)tensor->data + offset, data, size); - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - memcpy(data, (const char *)tensor->data + offset, size); - - GGML_UNUSED(buffer); -} - -static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { - if (ggml_backend_buffer_is_host(src->buffer)) { - memcpy(dst->data, src->data, ggml_nbytes(src)); - return true; - } - return false; - - GGML_UNUSED(buffer); -} - -static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { +static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + GGML_ASSERT(ctx->is_shared); + + memset((char *)tensor->data + offset, value, size); +} + +static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(ctx->is_shared); + + memcpy((char *)tensor->data + offset, data, size); +} + +static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(ctx->is_shared); + + memcpy(data, (const char *)tensor->data + offset, size); +} + +static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_UNUSED(buffer); + GGML_UNUSED(src); + GGML_UNUSED(dst); + + return false; +} + +static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(ctx->is_shared); + memset(ctx->all_data, value, ctx->all_size); } -static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_get_base, +static struct ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_clear, + /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_shared_clear, /* .reset = */ NULL, }; -// default buffer type +// private buffer -static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "Metal"; +static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - GGML_UNUSED(buft); + for (int i = 0; i < ctx->n_buffers; i++) { + [ctx->buffers[i].metal release]; + } + + ggml_backend_metal_buffer_rset_free(ctx); + + GGML_ASSERT(!ctx->is_shared); + + free(ctx); } +static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + return ctx->all_data; +} + +static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(!ctx->is_shared); + + @autoreleasepool { + // dst + size_t buf_dst_offset = 0; + id buf_dst = ggml_metal_get_buffer(tensor, &buf_dst_offset); + + buf_dst_offset += offset; + + id queue = ctx->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder fillBuffer:buf_dst + range:NSMakeRange(buf_dst_offset, buf_dst_offset + size) + value:value]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(!ctx->is_shared); + + @autoreleasepool { + // src + void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data + id buf_src = [ctx->device newBufferWithBytesNoCopy:data_ptr + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + // dst + size_t buf_dst_offset = 0; + id buf_dst = ggml_metal_get_buffer(tensor, &buf_dst_offset); + + buf_dst_offset += offset; + + // note: for experimentation purposes, here we use a semaphore to wait for the copy to complete + // this is alternative to waitUntilCompleted, which should be faster, but don't seem to make much difference + dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0); + + id queue = ctx->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:0 + toBuffer:buf_dst + destinationOffset:buf_dst_offset + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf addCompletedHandler:^(id cb) { + // TODO: can check for errors here + GGML_UNUSED(cb); + + dispatch_semaphore_signal(completion_semaphore); + }]; + + [cmd_buf commit]; + + dispatch_semaphore_wait(completion_semaphore, DISPATCH_TIME_FOREVER); + //[cmd_buf waitUntilCompleted]; + } +} + +static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(!ctx->is_shared); + + @autoreleasepool { + // src + size_t buf_src_offset = 0; + id buf_src = ggml_metal_get_buffer(tensor, &buf_src_offset); + + buf_src_offset += offset; + + // dst + id buf_dst = [ctx->device newBufferWithBytesNoCopy:data + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + id queue = ctx->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:buf_src_offset + toBuffer:buf_dst + destinationOffset:0 + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + GGML_UNUSED(buffer); + GGML_UNUSED(src); + GGML_UNUSED(dst); + + return false; +} + +static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + GGML_ASSERT(!ctx->is_shared); + + @autoreleasepool { + id queue = ctx->queue; + id cmd_buf = [queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder fillBuffer:ctx->buffers[0].metal + range:NSMakeRange(0, ctx->buffers[0].size) + value:value]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + +static struct ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, +}; + +// +// buffer types +// + static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { #ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) @@ -6208,7 +6169,8 @@ static void ggml_backend_metal_log_allocated_size(id device, size_t s GGML_UNUSED(size_aligned); } -static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +// common method for allocating shread or private Metal buffers +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) { struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); const size_t size_page = sysconf(_SC_PAGESIZE); @@ -6224,22 +6186,40 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba id device = ctx_dev->mtl_device; - ctx->all_data = ggml_metal_host_malloc(size_aligned); + // allocate shared buffer if the device supports it and it is required by the buffer type + if (ctx_dev->use_shared_buffers && shared) { + ctx->all_data = ggml_metal_host_malloc(size_aligned); + ctx->is_shared = true; + } else { + // dummy, non-NULL value - we'll populate this after creating the Metal buffer below + ctx->all_data = (void *) 0x000000400ULL; + ctx->is_shared = false; + } ctx->all_size = size_aligned; - ctx->owned = true; + + ctx->device = device; + ctx->queue = ctx_dev->mtl_queue; + ctx->n_buffers = 1; if (ctx->all_data != NULL) { - ctx->buffers[0].data = ctx->all_data; ctx->buffers[0].size = size; ctx->buffers[0].metal = nil; if (size_aligned > 0) { - ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data - length:size_aligned - options:MTLResourceStorageModeShared - deallocator:nil]; + if (ctx_dev->use_shared_buffers) { + ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data + length:size_aligned + options:MTLResourceStorageModeShared + deallocator:nil]; + } else { + ctx->buffers[0].metal = [device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; + + ctx->all_data = (void *) (ctx->buffers[0].metal.gpuAddress); + } } + + ctx->buffers[0].data = ctx->all_data; } if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { @@ -6256,36 +6236,50 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba //ggml_backend_metal_log_allocated_size(device, size_aligned); - return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); + struct ggml_backend_buffer_i buf_i = ctx->is_shared ? ggml_backend_metal_buffer_shared_i : ggml_backend_metal_buffer_private_i; + + return ggml_backend_buffer_init(buft, buf_i, ctx, size); } -static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { +// default (shared) buffer type + +static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { + return "Metal"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); +} + +static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) { return 32; GGML_UNUSED(buft); } -static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) { const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size; return max_size; } -static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; +static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) { + return false; GGML_UNUSED(buft); } -ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) { static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = { /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, + /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_metal_buffer_type_is_host, + /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, }, /* .device = */ &g_ggml_backend_metal_device, /* .context = */ NULL, @@ -6294,116 +6288,101 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { return &ggml_backend_buffer_type_metal; } -static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Mapped"; +// default (private) buffer type + +static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Private"; GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) { - static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = { +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false); +} + +static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) { + const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size; + + return max_size; +} + +static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = { /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_from_ptr_type_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, + /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_metal_buffer_type_is_host, + /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, }, /* .device = */ &g_ggml_backend_metal_device, /* .context = */ NULL, }; - return &ggml_backend_buffer_from_ptr_type_metal; + return &ggml_backend_buffer_type_metal; } -// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr -ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { - struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); +// mapped buffer type - ctx->all_data = data; - ctx->all_size = size; - ctx->owned = false; - ctx->n_buffers = 0; +static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Mapped"; - const size_t size_page = sysconf(_SC_PAGESIZE); + GGML_UNUSED(buft); +} - // page-align the data ptr - { - const uintptr_t offs = (uintptr_t) data % size_page; - data = (void *) ((char *) data - offs); - size += offs; - } +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + // for mapped buffers, prefer shared memory + return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); +} - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } +static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; - struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; + GGML_UNUSED(buft); +} - GGML_ASSERT(ctx_dev->mtl_device != nil); +static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) { + const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size; - id device = ctx_dev->mtl_device; + return max_size; +} - // the buffer fits into the max buffer size allowed by the device - if (size_aligned <= device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].data = data; - ctx->buffers[ctx->n_buffers].size = size; - ctx->buffers[ctx->n_buffers].metal = nil; +static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) { + return false; - if (size_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + GGML_UNUSED(buft); +} - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - return false; - } - } +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) { + // note: not obvious, but this buffer type still needs to implement .alloc_buffer: + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 + static struct ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, + }, + /* .device = */ &g_ggml_backend_metal_device, + /* .context = */ NULL, + }; - ggml_backend_metal_log_allocated_size(device, size_aligned); - - ++ctx->n_buffers; - } else { - // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into - // one of the views - const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case - const size_t size_step = device.maxBufferLength - size_ovlp; - const size_t size_view = device.maxBufferLength; - - for (size_t i = 0; i < size; i += size_step) { - const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; - ctx->buffers[ctx->n_buffers].metal = nil; - - if (size_step_aligned > 0) { - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); - return false; - } - } - - ggml_backend_metal_log_allocated_size(device, size_step_aligned); - - if (i + size_step < size) { - GGML_LOG_INFO("\n"); - } - - ++ctx->n_buffers; - } - } - - if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { - GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); - free(ctx); - return NULL; - } - - return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); + return &ggml_backend_buffer_type_mapped_metal; } // backend @@ -6422,6 +6401,137 @@ static void ggml_backend_metal_free(ggml_backend_t backend) { free(backend); } +static void ggml_backend_metal_synchronize(ggml_backend_t backend) { + struct ggml_backend_metal_context * ctx = backend->context; + + // wait for any backend operations to finish + if (ctx->cmd_buf_last) { + [ctx->cmd_buf_last waitUntilCompleted]; + ctx->cmd_buf_last = nil; + } + + // release any completed command buffers + if (ctx->cmd_bufs_ext.count > 0) { + for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) { + id cmd_buf = ctx->cmd_bufs_ext[i]; + + MTLCommandBufferStatus status = [cmd_buf status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); + } + GGML_ABORT("fatal error"); + } + + [cmd_buf release]; + } + + [ctx->cmd_bufs_ext removeAllObjects]; + } +} + +static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + @autoreleasepool { + id device = ctx_dev->mtl_device; + + // wrap the source data into a Metal buffer + id buf_src = [device newBufferWithBytes:data + length:size + options:MTLResourceStorageModeShared]; + + size_t buf_dst_offset = 0; + id buf_dst = ggml_metal_get_buffer(tensor, &buf_dst_offset); + + if (buf_dst == nil) { + GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); + } + + buf_dst_offset += offset; + + // queue the copy operation into the queue of the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:0 + toBuffer:buf_dst + destinationOffset:buf_dst_offset + size:size]; + + [encoder endEncoding]; + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + @autoreleasepool { + id device = ctx_dev->mtl_device; + + id buf_dst = [device newBufferWithBytesNoCopy:data + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; + + size_t buf_src_offset = 0; + id buf_src = ggml_metal_get_buffer(tensor, &buf_src_offset); + + if (buf_src == nil) { + GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); + } + + buf_src_offset += offset; + + // queue the copy operation into the queue of the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:buf_src + sourceOffset:buf_src_offset + toBuffer:buf_dst + destinationOffset:0 + size:size]; + + [encoder endEncoding]; + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { + return false; + + GGML_UNUSED(backend_src); + GGML_UNUSED(backend_dst); + GGML_UNUSED(src); + GGML_UNUSED(dst); +} + static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { return ggml_metal_graph_compute(backend, cgraph); } @@ -6452,7 +6562,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const int n_nodes_per_cb = ctx->n_nodes_per_cb; - id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; + + ggml_metal_mem_pool_reset(mem_pool); id encoder = [cmd_buf computeCommandEncoder]; @@ -6466,9 +6579,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const bool should_capture = ctx->capture_next_compute; - struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; - ggml_metal_mem_pool_reset(mem_pool); - for (int idx = node_start; idx < node_end;) { if (should_capture) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; @@ -6502,17 +6612,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { static struct ggml_backend_i ggml_backend_metal_i = { /* .get_name = */ ggml_backend_metal_name, /* .free = */ ggml_backend_metal_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, + /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups + /* .synchronize = */ ggml_backend_metal_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_metal_graph_compute, + + // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal + // in any case, these docs seem relevant if we ever decide to implement it: + // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_metal_guid(void) { @@ -6613,7 +6728,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct g props->type = ggml_backend_metal_device_get_type(dev); ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = (struct ggml_backend_dev_caps) { - /* .async = */ false, + /* .async = */ true, /* .host_buffer = */ false, /* .buffer_from_host_ptr = */ true, /* .events = */ false, @@ -6644,17 +6759,19 @@ static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, con } static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) { - return ggml_backend_metal_buffer_type(); + struct ggml_backend_metal_device_context * ctx_dev = dev->context; - GGML_UNUSED(dev); + return ctx_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private(); } -static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { +static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); ctx->all_data = ptr; ctx->all_size = size; - ctx->owned = false; + + ctx->is_shared = true; + ctx->n_buffers = 0; const size_t size_page = sysconf(_SC_PAGESIZE); @@ -6677,6 +6794,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back id device = ctx_dev->mtl_device; + ctx->device = device; + ctx->queue = ctx_dev->mtl_queue; + // the buffer fits into the max buffer size allowed by the device if (size_aligned <= device.maxBufferLength) { ctx->buffers[ctx->n_buffers].data = ptr; @@ -6734,7 +6854,7 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back return NULL; } - return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); + return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, ctx, size); } static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { @@ -6745,14 +6865,30 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { return - buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || - buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; + buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name; GGML_UNUSED(dev); } +static int64_t get_op_batch_size(const struct ggml_tensor * op) { + switch (op->op) { + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - return false; + const int min_batch_size = 32; + + return (op->op == GGML_OP_MUL_MAT || + op->op == GGML_OP_MUL_MAT_ID) && + get_op_batch_size(op) >= min_batch_size; GGML_UNUSED(dev); GGML_UNUSED(op); @@ -6767,7 +6903,7 @@ static struct ggml_backend_device_i ggml_backend_metal_device_i = { /* .init_backend = */ ggml_backend_metal_device_init, /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, - /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr, + /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped, /* .supports_op = */ ggml_backend_metal_device_supports_op, /* .supports_buft = */ ggml_backend_metal_device_supports_buft, /* .offload_op = */ ggml_backend_metal_device_offload_op, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2d56c62674..157d0cc6d0 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -15,6 +15,10 @@ using namespace metal; #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } +#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x) + #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf @@ -2755,7 +2759,47 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } -template +template +static inline void helper_mv_reduce_and_write( + device float * dst_f32, + float sumf[NR0], + const int r0, + const int ne01, + ushort tiisg, + ushort sgitg, + threadgroup char * shmem) { + threadgroup float * shmem_f32[NR0]; + + for (short row = 0; row < NR0; ++row) { + shmem_f32[row] = (threadgroup float *) shmem + NW*row; + + if (sgitg == 0) { + shmem_f32[row][tiisg] = 0.0f; + } + + sumf[row] = simd_sum(sumf[row]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0; ++row) { + if (tiisg == 0) { + shmem_f32[row][sgitg] = sumf[row]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short row = 0; row < NR0 && r0 + row < ne01; ++row) { + float tot = simd_sum(shmem_f32[row][tiisg]); + + if (tiisg == 0 && sgitg == 0) { + dst_f32[r0 + row] = tot; + } + } +} + +template void mul_vec_q_n_f32_impl( args_t args, device const char * src0, @@ -2765,45 +2809,51 @@ void mul_vec_q_n_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + constexpr short NQ = 16; + const int nb = args.ne00/QK4_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr0; + const int r0 = (tgpig.x*NSG + sgitg)*NR0; + //const int r0 = tgpig.x*NR0; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q_type * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q_type * ax[NR0]; + FOR_UNROLL (int row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } + float sumf[NR0] = {0.f}; + + const short ix = (tiisg/(NW/NQ)); + const short il = (tiisg%(NW/NQ))*8; + + //const int ib0 = sgitg*NQ + ix; + const int ib0 = ix; + float yl[16]; // src1 vector cache - float sumf[nr0] = {0.f}; - const short ix = (tiisg/2); - const short il = (tiisg%2)*8; - - device const float * yb = y + ix*QK4_0 + il; + //device const float * yb = y + ix*QK4_0 + il; + device const float * yb = y + ib0*QK4_0 + il; // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { + //for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (int ib = ib0; ib < nb; ib += NQ) { float sumy[2] = { 0.f, 0.f }; -#pragma unroll - for (short i = 0; i < 8; i += 2) { + FOR_UNROLL (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -2813,21 +2863,23 @@ void mul_vec_q_n_f32_impl( yl[i + 9] = yb[i + 17]/4096.f; } -#pragma unroll - for (short row = 0; row < nr0; row++) { + FOR_UNROLL (short row = 0; row < NR0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } yb += QK4_0 * 16; + //yb += NSG*NQ*QK4_0; } device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - for (int row = 0; row < nr0; ++row) { + //helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); + + for (int row = 0; row < NR0; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; + if (tiisg == 0 && r0 + row < args.ne01) { + dst_f32[r0 + row] = tot; } } } @@ -2837,10 +2889,11 @@ kernel void kernel_mul_mv_q4_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -2848,10 +2901,11 @@ kernel void kernel_mul_mv_q4_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -2859,10 +2913,11 @@ kernel void kernel_mul_mv_q5_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -2870,15 +2925,14 @@ kernel void kernel_mul_mv_q5_1_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -#define NB_Q8_0 8 - -template +template void kernel_mul_mv_q8_0_f32_impl( args_t args, device const char * src0, @@ -2888,66 +2942,65 @@ void kernel_mul_mv_q8_0_f32_impl( uint3 tgpig, ushort tiisg, ushort sgitg) { + constexpr short NQ = 8; + const int nb = args.ne00/QK8_0; - const int r0 = tgpig.x; + const int r0 = tgpig.x*NR0; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; - const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows - device const block_q8_0 * ax[nr0]; - for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + device const block_q8_0 * ax[NR0]; + FOR_UNROLL (short row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } - float yl[NB_Q8_0]; - float sumf[nr0] = { 0.f }; + float sumf[NR0] = { 0.f }; - const short ix = tiisg/4; - const short il = tiisg%4; + const short ix = tiisg/(NW/NQ); + const short il = tiisg%(NW/NQ); - device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; + const int ib0 = sgitg*NQ + ix; - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (short i = 0; i < NB_Q8_0; ++i) { + float yl[NQ]; + + device const float * yb = y + ib0*QK8_0 + il*NQ; + + // each thread in a SIMD group deals with NQ quants at a time + for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (short i = 0; i < NQ; ++i) { yl[i] = yb[i]; } - for (short row = 0; row < nr0; row++) { - device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; + for (short row = 0; row < NR0; row++) { + device const int8_t * qs = ax[row][ib].qs + il*NQ; + float sumq = 0.f; - for (short iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; + FOR_UNROLL (short i = 0; i < NQ; ++i) { + sumq += qs[i] * yl[i]; } + sumf[row] += sumq*ax[row][ib].d; } - yb += nw*NB_Q8_0; + yb += NSG*NQ*QK8_0; } device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0; ++row) { - const float tot = simd_sum(sumf[row]); - - if (tiisg == 0 && first_row + row < args.ne01) { - dst_f32[first_row + row] = tot; - } - } + helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } [[host_name("kernel_mul_mv_q8_0_f32")]] @@ -2956,10 +3009,11 @@ kernel void kernel_mul_mv_q8_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } // mat-vec kernel processing in chunks of float4 @@ -4197,6 +4251,19 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope; } +constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; +constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; +constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; +constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; + +//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; +//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; +//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]]; + +constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]]; +constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]]; +constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]]; + // ref: https://arxiv.org/pdf/2307.08691.pdf template< typename q_t, // query types in shared memory @@ -4211,6 +4278,7 @@ template< typename qk_t, // Q*K types typename qk8x8_t, typename s_t, // soft-max types + typename s2_t, typename s8x8_t, typename o_t, // attention accumulation types typename o4_t, @@ -4221,12 +4289,12 @@ template< typename vd4x4_t, // value type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), - short DK, // K head size - short DV, // V head size - short Q = 8, // queries per threadgroup - short KV = 8, // key/value processed per each simdgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext( + short DK, // K head size + short DV, // V head size + short Q, // queries per threadgroup + short C, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_impl( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, @@ -4234,46 +4302,85 @@ kernel void kernel_flash_attn_ext( device const char * mask, device const char * sinks, device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups + threadgroup half * shmem_f16, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const ushort iq3 = tgpig[2]; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]*Q; - const int iq3 = tgpig[2]; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]*Q; +#define NS10 (FC_flash_attn_ext_ns10) +#define NS20 (FC_flash_attn_ext_ns20) + + // note: I had some concerns that using this instead of the ugly macros above was affecting performance + // need to re-check carefully and if no regressions are observerd - remove the macros + // the concerns is that maybe using const variables requires extra registers? but not sure if the compiler + // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC + //const short NS10 = FC_flash_attn_ext_ns10; + //const short NS20 = FC_flash_attn_ext_ns20; + + constexpr short KV = 8; constexpr short DK4 = DK/4; constexpr short DK8 = DK/8; constexpr short DK16 = DK/16; constexpr short DV4 = DV/4; - constexpr short DV8 = DV/8; + //constexpr short DV8 = DV/8; constexpr short DV16 = DV/16; + constexpr short PV = PAD2(DV, 64); + constexpr short PV4 = PV/4; + constexpr short PV8 = PV/8; + //constexpr short PV16 = PV/16; + constexpr short NW = N_SIMDWIDTH; - constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + constexpr short NQ = Q/NSG; + constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float) - const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = 2*DK + 2*TS; // shared memory size per query in (half) + constexpr short TS = 2*SH; + constexpr short T = DK + 2*PV; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper) + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK); + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix + threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t - threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - o8x8_t lo[DV8]; + // mask storage in shared mem + threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C); + + // per-query mask pointers + device const half2 * pm2[NQ]; + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + } + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01); for (short i = tiisg; i < DK4; i += NW) { if (iq1 + j < args.ne01) { @@ -4284,43 +4391,30 @@ kernel void kernel_flash_attn_ext( } } - // zero out lo - for (short i = 0; i < DV8; ++i) { - lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); - } + // zero out + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] = 0; + } - // zero out shared memory SH - for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TS + i] = 0.0f; + ss[j*SH + i] = 0.0f; } } threadgroup_barrier(mem_flags::mem_threadgroup); + float S[NQ] = { [0 ... NQ-1] = 0.0f }; + { - float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 }; - - // thread indices inside the simdgroup - // TODO: see if we can utilize quad-group functions for better performance - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) - const short tx = tiisg%4; - const short ty = tiisg/4; - - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; + float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 }; float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4331,177 +4425,277 @@ kernel void kernel_flash_attn_ext( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= args.ne11) { - break; - } + for (int ic = 0; ic < args.ne11; ic += C) { + // read the mask into shared mem + if (FC_flash_attn_ext_has_mask) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - if (has_mask) { - // used to detect blocks full of -INF - float smax = -INFINITY; - - // load the mask in shared memory - #pragma unroll(Q) - for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); - - const float m = pm[ic + tiisg]; - - ss[j*TS + C + tiisg] = m; - smax = max(smax, m); + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + pm2[jj] += NW; } - smax = simd_max(smax); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // used to detect blocks full of -INF + // skip only when the entire threadgroup is masked + half2 smax2(-MAXHALF/2, -MAXHALF/2); + + FOR_UNROLL (short j = 0; j < Q; ++j) { + smax2 = max(smax2, sm2[j*SH + tiisg]); + } + + smax2 = simd_max(smax2); + + if (max(smax2[0], smax2[1]) <= -MAXHALF/2) { + // this barrier is important + threadgroup_barrier(mem_flags::mem_threadgroup); - if (smax == -INFINITY) { continue; } } // Q*K^T - { - for (short cc = 0; cc < C/8; ++cc) { + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + threadgroup const q_t * pq = sq; + threadgroup s_t * ps = ss; + + pk += sgitg*(8*NS10); + ps += sgitg*(8*1); + + static_assert((C/8) % NSG == 0, ""); + + constexpr short NC = (C/8)/NSG; + + // TODO: not good to unroll for large contexts - not sure why? + for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - // this is compile-time check, so it does not have runtime overhead - if (is_same::value) { - // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + if (DK8 % 16 != 0) { + k8x8_t mk; + q8x8_t mq; - #pragma unroll(DK8) - for (short i = 0; i < DK8; ++i) { - k8x8_t mk; - simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + FOR_UNROLL (short i = 0; i < DK8; ++i) { + simdgroup_barrier(mem_flags::mem_none); + + simdgroup_load(mk, pk, NS10, 0, true); + simdgroup_load(mq, pq, DK); + + simdgroup_barrier(mem_flags::mem_none); - q8x8_t mq; - simdgroup_load(mq, sq + i*8, DK); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + pk += 8; + pq += 8; } } else { - for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + k8x8_t mk[2]; + q8x8_t mq[2]; - if (DK16%4 == 0) { - // the head is evenly divisible by 4*16 = 64, so no need for bound checks - { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + simdgroup_barrier(mem_flags::mem_none); - simdgroup_barrier(mem_flags::mem_threadgroup); + simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_load(mq[0], pq + 0*8, DK); + simdgroup_load(mq[1], pq + 1*8, DK); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } - } else { - if (ii + tx < DK16) { - k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); - sk4x4[4*ty + tx] = tmp; - } + simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); + simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - simdgroup_barrier(mem_flags::mem_threadgroup); + pk += 16; + pq += 16; + } + } - for (short k = 0; k < 4 && ii + k < DK16; ++k) { - k8x8_t mk; - q8x8_t mq; + simdgroup_store(mqk, ps, SH, 0, false); - simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + pk += 8*(NSG*NS10 - DK8); + pq += 8*(NSG*0 - DK8); + ps += 8*(NSG); + } + } else { + // TODO: this is the quantized K cache branch - not optimized yet + for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) { + const short cc = ccc*NSG + sgitg; - simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose - simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } + const short tx = tiisg%4; + const short ty = tiisg/4; + + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + + for (short ii = 0; ii < DK16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + + if (DK16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short k = 0; k < 4; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + if (ii + tx < DK16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < DK16; ++k) { + k8x8_t mk; + q8x8_t mq; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); } } } - // cast qk_t -> s_t - //s8x8_t mqks(1.0f); - //simdgroup_multiply(mqks, mqk, mqks); - //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); - - simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + simdgroup_store(mqk, ss + 8*cc, SH, 0, false); } } + threadgroup_barrier(mem_flags::mem_threadgroup); + // online softmax - { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - // scale and apply the logitcap / mask - float s = ss[j*TS + tiisg]*args.scale; + const float m = M[jj]; - if (args.logit_softcap != 0.0f) { - s = args.logit_softcap*precise::tanh(s); + // scale and apply the logitcap / mask + float2 s2 = ss2[j*SH/2 + tiisg]*args.scale; + + if (FC_flash_attn_ext_has_scap) { + s2 = args.logit_softcap*precise::tanh(s2); + } + + // mqk = mqk + slope*mask + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } + + M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); + + const float ms = exp(m - M[jj]); + const float2 vs2 = exp(s2 - M[jj]); + + S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]); + + // the P matrix from the paper (Q rows, C columns) + ss2[j*SH/2 + tiisg] = vs2; + + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; + + so4[j*PV4 + i] *= ms; } - - // mqk = mqk + mask*slope - s += slope*ss[j*TS + C + tiisg]; - - M[j] = simd_max(max(M[j], s)); - - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); - - S[j] = S[j]*ms + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[j*TS + tiisg] = vs; - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; + } else { + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); - } - } + threadgroup_barrier(mem_flags::mem_threadgroup); // O = O + (Q*K^T)*V { - for (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, ss + 8*cc, TS, 0, false); + // we can read directly from global memory + if (is_same::value) { + static_assert(PV8 % NSG == 0, ""); - if (is_same::value) { - // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + constexpr short NO = PV8/NSG; - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - v8x8_t mv; - simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 + o8x8_t lo[NO]; - simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_load(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; } - } else { - for (short ii = 0; ii < DV16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + } + + { + auto sst = ss; + + device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + + pv += 8*sgitg; + + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, sst, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + v8x8_t mv; + + simdgroup_load(mv, pv, NS20, 0, false); + simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + + pv += 8*NSG; + } + + pv += 8*(NS20 - NO*NSG); + sst += 8; + } + } + + { + auto sot = so + 8*sgitg; + + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_store(lo[ii], sot, PV, 0, false); + + sot += 8*NSG; + } + } + } else { + // TODO: this is the quantized V cache branch - not optimized yet + + const short tx = tiisg%4; + const short ty = tiisg/4; + + for (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); + + for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -4513,15 +4707,20 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); - #pragma unroll(4) - for (short k = 0; k < 4; ++k) { - v8x8_t mv; + FOR_UNROLL (short k = 0; k < 4; ++k) { + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } else { if (ii + tx < DV16) { @@ -4533,243 +4732,249 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); for (short k = 0; k < 4 && ii + k < DV16; ++k) { - v8x8_t mv; + v8x8_t mv[2]; + o8x8_t lo[2]; - simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); + simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); - simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); + simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); + + simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); + simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); } } } } } } + + threadgroup_barrier(mem_flags::mem_threadgroup); } - if (sinks != q && sgitg == 0) { - for (ushort j = 0; j < Q; ++j) { - const float m = M[j]; + if (FC_flash_attn_ext_has_sinks) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + const float m = M[jj]; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; - M[j] = simd_max(max(M[j], s)); + M[jj] = simd_max(max(M[jj], s)); - const float ms = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float ms = exp(m - M[jj]); + const float vs = exp(s - M[jj]); - S[j] = S[j]*ms + simd_sum(vs); + S[jj] = S[jj]*ms + simd_sum(vs); - if (tiisg == j) { - ss[j*TS + 2*C + j] = ms; - } - } - - // O = diag(ms)*O - { - s8x8_t ms; - simdgroup_load(ms, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], ms, lo[i]); + for (short i = tiisg; i < DV4; i += NW) { + so4[j*PV4 + i] *= ms; } } } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = tiisg; j < Q; j += NW) { - ss[j*TS + 0] = S[j]; - ss[j*TS + 1] = M[j]; - } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); - - // store result to shared memory in F32 - if (sgitg == 0) { - for (short i = 0; i < DV8; ++i) { - //simdgroup_store(lo[i], so + i*8, DV, 0, false); - simdgroup_float8x8 t(1.0f); - simdgroup_multiply(t, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); + // store to global memory + for (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + if (iq1 + j >= args.ne01) { + break; } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // reduce the warps sequentially - for (ushort sg = 1; sg < nsg; ++sg) { - if (sgitg == sg) { - for (short j = tiisg; j < Q; j += NW) { - const float S0 = ss[j*TS - 1*SH + 0]; - const float S1 = ss[j*TS + 0]; - - const float M0 = ss[j*TS - 1*SH + 1]; - const float M1 = ss[j*TS + 1]; - - const float M = max(M0, M1); - - float ms0 = exp(M0 - M); - float ms1 = exp(M1 - M); - - const float S = S0*ms0 + S1*ms1; - - ss[j*TS + 0] = S; - ss[j*TS + 1] = M; - - ss[j*TS + 2*C + j - 1*SH] = ms0; - ss[j*TS + 2*C + j ] = ms1; - } - - //simdgroup_barrier(mem_flags::mem_threadgroup); - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { - s8x8_t ms0; - s8x8_t ms1; - - simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); - simdgroup_load(ms1, ss + 2*C, TS, 0, false); - - #pragma unroll(DV8) - for (short i = 0; i < DV8; ++i) { - simdgroup_float8x8 t; - - simdgroup_load (t, so + i*8, DV, 0, false); - simdgroup_multiply(t, ms0, t); - - simdgroup_multiply_accumulate(t, ms1, lo[i], t); - simdgroup_store(t, so + i*8, DV, 0, false); - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); - - // final rescale with 1/S and store to global memory - for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { - const float S = 1.0f/sf[j*TS + 0]; device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - for (short i = tiisg; i < DV4; i += NW) { - dst4[i] = (float4) so4[j*DV4 + i]*S; + const float scale = 1.0f/S[jj]; + + if (DV4 % NW == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { + const short i = ii*NW + tiisg; + + dst4[i] = (float4) so4[j*PV4 + i]*scale; + } + } else { + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*PV4 + i]*scale; + } } } + +#undef NS10 +#undef NS20 +} + +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s2_t, + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // value type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short DK, // K head size + short DV, // V head size + short Q = 8, // queries per threadgroup + short C = 64> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_nsg) { + // note: disabled cases to reduce library load time + //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; + //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as // template to be able to explore different combinations // #define FA_TYPES \ - float, float4, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ half, half4x4, simdgroup_half8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ - half, half4, simdgroup_half8x8 - //float, float4, simdgroup_float8x8 + float, float2, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 #define FA_TYPES_BF \ bfloat, bfloat4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ bfloat, bfloat4x4, simdgroup_bfloat8x8, \ float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ + float, float2, simdgroup_float8x8, \ half, half4, simdgroup_half8x8 //float, float4, simdgroup_float8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; -template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif -template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES #undef FA_TYPES_BF +constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; +constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; +constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; +constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; + +//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; +//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; +//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]]; + +constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]]; +constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]]; +constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]]; +constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]]; + template< typename q4_t, // query types in shared memory typename k4_t, // key types in shared memory @@ -4788,63 +4993,86 @@ template< short DV, // V head size short NE = 4, // head elements per thread short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext & args, + short C = 32, // cache items per threadgroup + short NSG> // number of simd groups +void kernel_flash_attn_ext_vec_impl( + constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, device const char * v, device const char * mask, device const char * sinks, device char * dst, - constant uint16_t & nwg, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { static_assert(DK % 32 == 0, "DK must be divisible by 32"); static_assert(DV % 32 == 0, "DV must be divisible by 32"); - const short nsg = ntg.y; // number of simdgroups - const short iwg = tgpig[2]%nwg; +#define NWG (FC_flash_attn_ext_vec_nwg) - const int iq3 = tgpig[2]/nwg; - const int iq2 = tgpig[1]; - const int iq1 = tgpig[0]; +#define NS10 (FC_flash_attn_ext_vec_ns10) +#define NS20 (FC_flash_attn_ext_vec_ns20) + + const short iwg = tgpig[2]%NWG; + + const ushort iq3 = tgpig[2]/NWG; + const ushort iq2 = tgpig[1]; + const ushort iq1 = tgpig[0]; constexpr short DK4 = DK/4; constexpr short DV4 = DV/4; + + constexpr short PK = PAD2(DK, 128); + constexpr short PK4 = PK/4; + + constexpr short PV = PAD2(DV, 128); + constexpr short PV4 = PV/4; + constexpr short NW = N_SIMDWIDTH; constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads constexpr short SH = 4*C; // shared memory per simdgroup - const short T = DK + nsg*SH; // shared memory size per query in (half) + static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); + static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results + const short T = PK + NSG*SH; // shared memory size per query in (half) - // store the result for all queries in local memory (the O matrix from the paper) - o4_t lo[DV4/NL]; + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results + + // store the result for all queries in shared memory (the O matrix from the paper) + so4 += tiisg; + + { + q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += ikv2*args.nb12 + ikv3*args.nb13; + v += ikv2*args.nb22 + ikv3*args.nb23; + } // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < DK4; i += NW) { - if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (iq1 < args.ne01 && i < DK4) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; } } - // zero out lo + // zero out so for (short i = 0; i < DV4/NL; ++i) { - lo[i] = (o4_t) 0.0f; + so4[i*NL] = (o4_t) 0.0f; } // zero out shared memory SH @@ -4856,28 +5084,19 @@ kernel void kernel_flash_attn_ext_vec( { float S = 0.0f; - float M = -__FLT_MAX__/2; + float M = -FLT_MAX/2; // thread indices inside the simdgroup const short tx = tiisg%NL; const short ty = tiisg/NL; - // broadcast kv - //const short rk2 = args.ne02/args.ne12; - //const short rk3 = args.ne03/args.ne13; - - const short ikv2 = iq2/(args.ne02/args.ne_12_2); - const short ikv3 = iq3/(args.ne03/args.ne_12_3); - - const bool has_mask = mask != q; - // pointer to the mask device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); float slope = 1.0f; // ALiBi - if (args.max_bias > 0.0f) { + if (FC_flash_attn_ext_vec_has_bias) { const short h = iq2; const float base = h < args.n_head_log2 ? args.m0 : args.m1; @@ -4888,13 +5107,13 @@ kernel void kernel_flash_attn_ext_vec( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) { + for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { const int ic = ic0 + C*sgitg; if (ic >= args.ne11) { break; } - if (has_mask) { + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -4905,70 +5124,82 @@ kernel void kernel_flash_attn_ext_vec( // Q*K^T { - // each simdgroup processes 1 query and NE (NW/NL) head elements - for (short cc = 0; cc < C/NE; ++cc) { - qk_t mqk = 0.0f; + device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + threadgroup const q4_t * pq4 = sq4; - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13)); + pk4 += ty*NS10/4 + tx; + pq4 += tx; - #pragma unroll(DK4/NL) - for (short ii = 0; ii < DK4; ii += NL) { - const short i = ii + tx; + qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f }; + + // each simdgroup processes 1 query and NE (NW/NL) cache elements + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + if (is_same::value) { + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); + } + } else { + device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; - deq_k_t4(pk + i/nl_k, i%nl_k, mk); - // note: this is less precise than the version below - //mqka[0] += dot(mq[0], mk[0]); - //mqka[1] += dot(mq[1], mk[1]); - //mqka[2] += dot(mq[2], mk[2]); - //mqka[3] += dot(mq[3], mk[3]); + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + const short i = ii*NL + tx; - //q4x4_t mq = sq4x4[i]; - //mqka[0] += dot((float4) mq[0], (float4) mk[0]); - //mqka[1] += dot((float4) mq[1], (float4) mk[1]); - //mqka[2] += dot((float4) mq[2], (float4) mk[2]); - //mqka[3] += dot((float4) mq[3], (float4) mk[3]); + deq_k_t4(pk + i/nl_k, i%nl_k, mk); - mqk += dot((float4) mk, (float4) sq4[i]); + mqk[cc] += dot((float4) mk, (float4) sq4[i]); + } } - static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails - - // simdgroup reduce (NE = 4) - // [ 0 .. 7] -> [ 0] - // [ 8 .. 15] -> [ 8] - // [16 .. 23] -> [16] - // [24 .. 31] -> [24] - if (NE <= 1) { - mqk += simd_shuffle_down(mqk, 16); - } - if (NE <= 2) { - mqk += simd_shuffle_down(mqk, 8); - } - if (NE <= 4) { - mqk += simd_shuffle_down(mqk, 4); - } - if (NE <= 8) { - mqk += simd_shuffle_down(mqk, 2); - } - if (NE <= 16) { - mqk += simd_shuffle_down(mqk, 1); - } - - // mqk = mqk*scale + mask*slope - if (tx == 0) { - mqk *= args.scale; - - if (args.logit_softcap != 0.0f) { - mqk = args.logit_softcap*precise::tanh(mqk); + if (NE == 1) { + mqk[cc] = simd_sum(mqk[cc]); + } else { + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk[cc] += simd_shuffle_down(mqk[cc], 16); + } + if (NE <= 2) { + mqk[cc] += simd_shuffle_down(mqk[cc], 8); + } + if (NE <= 4) { + mqk[cc] += simd_shuffle_down(mqk[cc], 4); + } + if (NE <= 8) { + mqk[cc] += simd_shuffle_down(mqk[cc], 2); + } + if (NE <= 16) { + mqk[cc] += simd_shuffle_down(mqk[cc], 1); } - mqk += sm[NE*cc + ty]*slope; - - ss[NE*cc + ty] = mqk; + // broadcast + mqk[cc] = simd_shuffle(mqk[cc], NL*ty); } } + + if (FC_flash_attn_ext_vec_has_mask && + !FC_flash_attn_ext_vec_has_scap && + !FC_flash_attn_ext_vec_has_bias) { + ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]); + } else { + mqk[tx] *= args.scale; + + if (FC_flash_attn_ext_vec_has_scap) { + mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]); + } + + if (FC_flash_attn_ext_vec_has_bias) { + mqk[tx] += (qk_t) sm[NE*tx + ty]*slope; + } else { + mqk[tx] += (qk_t) sm[NE*tx + ty]; + } + + ss[NE*tx + ty] = mqk[tx]; + } } simdgroup_barrier(mem_flags::mem_threadgroup); @@ -4989,9 +5220,10 @@ kernel void kernel_flash_attn_ext_vec( ss[tiisg] = vs; // O = diag(ms)*O - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -4999,26 +5231,84 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { - //#pragma unroll(C/NE) - for (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23)); + o4_t lo[DV4/NL]; + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] = 0.0f; + } - const s4_t ms(ss[NE*cc + ty]); + if (is_same::value) { + device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); - #pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - const short i = ii + tx; + pv4 += ty*NS20/4 + tx; - v4_t mv; - deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + const auto sst = ss + ty; - lo[ii/NL] += o4_t(float4(mv)*float4(ms)); + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE])); + } + } + } else { + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + const short i = ii*NL + tx; + + v4_t mv; + deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); + } + } + } + + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + if (NE > 1) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 16); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 16); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 16); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 16); + } + + if (NE > 2) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 8); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 8); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 8); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 8); + } + + if (NE > 4) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 4); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 4); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 4); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 4); + } + + if (NE > 8) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 2); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 2); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 2); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 2); + } + + if (NE > 16) { + lo[ii][0] += simd_shuffle_down(lo[ii][0], 1); + lo[ii][1] += simd_shuffle_down(lo[ii][1], 1); + lo[ii][2] += simd_shuffle_down(lo[ii][2], 1); + lo[ii][3] += simd_shuffle_down(lo[ii][3], 1); + } + } + + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] += lo[ii]; } } } } - if (sinks != q && sgitg == 0 && iwg == 0) { + if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) { const float m = M; const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; @@ -5029,9 +5319,10 @@ kernel void kernel_flash_attn_ext_vec( S = S*ms + simd_sum(vs); -#pragma unroll(DV4/NL) - for (short ii = 0; ii < DV4; ii += NL) { - lo[ii/NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[ii*NL] *= ms; + } } } @@ -5042,63 +5333,12 @@ kernel void kernel_flash_attn_ext_vec( } } - // simdgroup reduce (NE = 4) - // [ 0, 8, 16, 24] -> [ 0] - // [ 1, 9, 17, 25] -> [ 1] - // [ 2, 10, 18, 26] -> [ 2] - // [ 3, 11, 19, 27] -> [ 3] - // [ 4, 12, 20, 28] -> [ 4] - // [ 5, 13, 21, 29] -> [ 5] - // [ 6, 14, 22, 30] -> [ 6] - // [ 7, 15, 23, 31] -> [ 7] - for (short ii = 0; ii < DV4; ii += NL) { - if (NE > 1) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); - } - - if (NE > 2) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); - } - - if (NE > 4) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); - } - - if (NE > 8) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); - } - - if (NE > 16) { - lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); - lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); - lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); - lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // store results to shared memory - for (short i = tiisg; i < DV4; i += NL) { - sr4[i] = lo[i/NL]; - } + so4 -= tiisg; threadgroup_barrier(mem_flags::mem_threadgroup); // parallel reduce - for (short r = nsg/2; r > 0; r >>= 1) { + for (short r = NSG/2; r > 0; r >>= 1) { if (sgitg < r) { const float S0 = ss[ 0]; const float S1 = ss[r*(SH/2) + 0]; @@ -5120,7 +5360,7 @@ kernel void kernel_flash_attn_ext_vec( // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 for (short i = tiisg; i < DV4; i += NW) { - sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1; + so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1; } } @@ -5133,21 +5373,73 @@ kernel void kernel_flash_attn_ext_vec( const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1; device float4 * dst4 = (device float4 *) dst; - device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results + device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f; + const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { - dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S; + dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S; } // store S and M - if (nwg > 1 && tiisg == 0) { - dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0]; - dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1]; + if (NWG > 1) { + if (tiisg == 0) { + dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0]; + dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1]; + } } } + +#undef NWG +#undef NS10 +#undef NS20 +} + +template< + typename q4_t, // query types in shared memory + typename k4_t, // key types in shared memory + typename v4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename o4_t, // attention accumulation types + typename kd4_t, // key type in device memory + short nl_k, + void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), + typename vd4_t, // value type in device memory + short nl_v, + void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), + short DK, // K head size + short DV, // V head size + short NE = 4, // head elements per thread + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext_vec & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device const char * sinks, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C +#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg + switch (FC_flash_attn_ext_vec_nsg) { + // note: disabled cases to reduce library load time + case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 2: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 8: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; + } +#undef FWD_TMPL +#undef FWD_ARGS } // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem @@ -5163,145 +5455,122 @@ kernel void kernel_flash_attn_ext_vec( typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; -template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #endif -template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; -template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES -kernel void kernel_flash_attn_ext_reduce( - constant ggml_metal_kargs_flash_attn_ext_reduce & args, +constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]]; +constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]]; + +kernel void kernel_flash_attn_ext_vec_reduce( + constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args, device const char * htmp, device char * dst, uint tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { +#define NWG (FC_flash_attn_ext_vec_reduce_NWG) +#define DV (FC_flash_attn_ext_vec_reduce_DV) + const uint64_t rid = tgpig; - const short nwg = 32; const short iwg = tiisg; - const short DV = args.ne20; - const short DV4 = DV/4; - device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg; - device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg; - device float4 * dst4 = (device float4 *) dst + rid*DV4; + device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG; - float S = ss[rid*(2*nwg) + 2*iwg + 0]; - float M = ss[rid*(2*nwg) + 2*iwg + 1]; + float S = ss[rid*(2*NWG) + 2*iwg + 0]; + float M = ss[rid*(2*NWG) + 2*iwg + 1]; const float m = simd_max(M); const float ms = exp(M - m); S = 1.0f/simd_sum(S*ms); - for (int i = sgitg; i < DV4; i += nwg) { - const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms); + const short DV4 = DV/4; + + device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG; + device float4 * dst4 = (device float4 *) dst + rid*DV4; + + for (short i = sgitg; i < DV4; i += NWG) { + const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms); if (iwg == 0) { dst4[i] = v*S; } } + +#undef NWG +#undef DV } -template -kernel void kernel_set( - constant ggml_metal_kargs_set & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i13 = tgpig[2]; - const int i12 = tgpig[1]; - const int i11 = tgpig[0]; - - const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10; - - const int64_t i3 = n / (args.ne12*args.ne11*args.ne10); - const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10); - const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10; - - device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs); - - for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) { - device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10); - dst_data[i10] = (T) src[0]; - } -} - -typedef decltype(kernel_set) kernel_set_t; - -template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set; -template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set; - template kernel void kernel_cpy( constant ggml_metal_kargs_cpy & args, @@ -5338,6 +5607,8 @@ typedef decltype(kernel_cpy) kernel_cpy_t; template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy; #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; #endif @@ -7395,7 +7666,7 @@ kernel void kernel_set_rows_f( const int32_t i10 = i01; const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; - device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { @@ -7494,18 +7765,20 @@ kernel void kernel_mul_mm( #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(4) for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(8) for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 727163b7fd..b188c5af34 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2838,6 +2838,7 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .graph_compute = */ ggml_backend_opencl_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; ggml_backend_t ggml_backend_opencl_init(void) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index e84ff93efc..d4833068d0 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -795,6 +795,7 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .graph_compute = */ ggml_backend_rpc_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 877fbf7e86..619ccaefc0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4063,6 +4063,7 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .graph_compute = */ ggml_backend_sycl_graph_compute, /* .event_record = */ ggml_backend_sycl_event_record, /* .event_wait = */ ggml_backend_sycl_event_wait, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_sycl_guid() { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index cd1c66ba7b..178d8eb3de 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -506,8 +506,8 @@ struct vk_device_struct { vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT]; @@ -554,6 +554,7 @@ struct vk_device_struct { vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; vk_pipeline pipeline_pool2d_f32; @@ -582,6 +583,7 @@ struct vk_device_struct { bool disable_fusion; bool disable_host_visible_vidmem; bool allow_sysmem_fallback; + bool disable_optimize_graph; #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; @@ -803,6 +805,57 @@ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_ten p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + return p; // offsets are initialized later in ggml_vk_op +} + +struct vk_op_pad_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + + uint32_t lp0; uint32_t rp0; + uint32_t lp1; uint32_t rp1; + uint32_t lp2; uint32_t rp2; + uint32_t lp3; uint32_t rp3; +}; + +static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) { + int64_t ne = ggml_nelements(dst); + GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); + + vk_op_pad_push_constants p{}; + p.ne = (uint32_t)ne; + + size_t src0_tsize = ggml_type_size(src0->type); + p.ne00 = (uint32_t)src0->ne[0]; + p.ne01 = (uint32_t)src0->ne[1]; + p.ne02 = (uint32_t)src0->ne[2]; + p.ne03 = (uint32_t)src0->ne[3]; + p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); + p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); + p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); + p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); + + size_t dst_tsize = ggml_type_size(dst->type); + p.ne10 = (uint32_t)dst->ne[0]; + p.ne11 = (uint32_t)dst->ne[1]; + p.ne12 = (uint32_t)dst->ne[2]; + p.ne13 = (uint32_t)dst->ne[3]; + p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); + p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); + p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); + p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + + p.lp0 = dst->op_params[0]; + p.rp0 = dst->op_params[1]; + p.lp1 = dst->op_params[2]; + p.rp1 = dst->op_params[3]; + p.lp2 = dst->op_params[4]; + p.rp2 = dst->op_params[5]; + p.lp3 = dst->op_params[6]; + p.rp3 = dst->op_params[7]; + return p; // fastdiv values and offsets are initialized later in ggml_vk_op } @@ -931,6 +984,37 @@ struct vk_op_im2col_push_constants { int32_t d0; int32_t d1; }; +struct vk_op_im2col_3d_push_constants { + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +}; + struct vk_op_timestep_embedding_push_constants { uint32_t nb1; uint32_t dim; @@ -1853,7 +1937,9 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - for (auto &req_flags : req_flags_list) { + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; + uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags); if (memory_type_index == UINT32_MAX) { @@ -1866,6 +1952,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std break; } catch (const vk::SystemError& e) { // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } } } @@ -3143,12 +3234,16 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -3250,7 +3345,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -3299,7 +3394,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -3329,10 +3424,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); } else { ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); @@ -3502,6 +3600,9 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK"); device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr; + const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH"); + device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -3642,6 +3743,12 @@ static vk_device ggml_vk_get_device(size_t idx) { device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); +#ifdef __APPLE__ + // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846) + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_arithmetic = false; + } +#endif device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && @@ -5607,6 +5714,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_i32; + } else { + return ctx->device->pipeline_cpy_f32_i32; + } + } + if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_i32_f32; + } else { + return ctx->device->pipeline_cpy_i32_f32; + } + } if (src->type == GGML_TYPE_F32) { switch (to) { case GGML_TYPE_Q4_0: @@ -7666,6 +7787,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_im2col_f32_f16; } return nullptr; + case GGML_OP_IM2COL_3D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_3d_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_3d_f32_f16; + } + return nullptr; case GGML_OP_TIMESTEP_EMBEDDING: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_timestep_embedding_f32; @@ -7781,6 +7910,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_RMS_NORM: case GGML_OP_CONV_2D_DW: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_SET_ROWS: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: @@ -7829,6 +7959,26 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src2); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src0); + GGML_UNUSED(src2); +} + template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); @@ -8069,6 +8219,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { OW * KW * KH, OH, batch * IC }; } break; + case GGML_OP_IM2COL_3D: + { + const uint32_t IC = ((const uint32_t *)(dst->op_params))[9]; + + const uint32_t N = ne13 / IC; + + const uint32_t KD = ne02; + const uint32_t KH = ne01; + const uint32_t KW = ne00; + + const uint32_t OD = ned3 / N; + const uint32_t OH = ned2; + const uint32_t OW = ned1; + + const uint32_t IC_KD_KH_KW = IC*KD*KH*KW; + const uint32_t N_OD_OH = N*OD*OH; + + elements = { IC_KD_KH_KW, OW, N_OD_OH }; + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } break; case GGML_OP_TIMESTEP_EMBEDDING: { const uint32_t dim = dst->op_params[0]; @@ -8225,7 +8395,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); - } else if (op == GGML_OP_IM2COL) { + } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { // im2col uses only src1 and dst buffers ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_COUNT_EQUAL) { @@ -8771,7 +8941,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con } static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); } @@ -8982,7 +9152,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { @@ -9086,6 +9256,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co }, dryrun); } +static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + vk_op_im2col_3d_push_constants pc {}; + + pc.nb10 = nb10 / ggml_type_size(src1->type); + pc.nb11 = nb11 / ggml_type_size(src1->type); + pc.nb12 = nb12 / ggml_type_size(src1->type); + pc.nb13 = nb13 / ggml_type_size(src1->type); + pc.s0 = s0; + pc.s1 = s1; + pc.s2 = s2; + pc.p0 = p0; + pc.p1 = p1; + pc.p2 = p2; + pc.d0 = d0; + pc.d1 = d1; + pc.d2 = d2; + pc.IW = IW; + pc.IH = IH; + pc.ID = ID; + pc.IC = IC; + pc.KW = KW; + pc.OH = OH; + pc.KD_KH_KW = KD*KH*KW; + pc.KH_KW = KH*KW; + pc.IC_KD_KH_KW = IC*KD*KH*KW; + pc.N_OD_OH = N*OD*OH; + pc.OD_OH = OD*OH; + pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); +} + static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t dim = dst->op_params[0]; const uint32_t max_period = dst->op_params[1]; @@ -10291,6 +10521,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: @@ -10361,6 +10592,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: @@ -10656,6 +10888,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_IM2COL_3D: + ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_TIMESTEP_EMBEDDING: ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); @@ -10807,6 +11043,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: @@ -11633,6 +11870,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg UNUSED(backend); } +// Sort the graph for improved parallelism. +static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph) +{ + VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + if (ctx->device->disable_optimize_graph) { + return; + } + + auto const &is_empty = [](ggml_tensor * node) -> bool { + return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor *src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + // This function tries to reorder the graph to allow nodes to run in parallel. + // This helps with small batches, but for large batches its a slowdown, probably + // due to cache contention. So only reorder if the majority of nodes have few rows. + int num_small_nodes = 0; + int num_counted_nodes = 0; + for (int i = 0; i < graph->n_nodes; ++i) { + if (!is_empty(graph->nodes[i]) && + graph->nodes[i]->op != GGML_OP_SET_ROWS) { + if (ggml_nrows(graph->nodes[i]) <= 8) { + num_small_nodes++; + } + num_counted_nodes++; + } + } + if (num_small_nodes < num_counted_nodes / 2) { + return; + } + + std::vector new_order; + std::vector used(graph->n_nodes, false); + int first_unused = 0; + while (first_unused < graph->n_nodes) { + std::vector current_set; + + // First, grab the next unused node. + current_set.push_back(first_unused); + + // Loop through the next N nodes. Grab any that don't depend on other nodes that + // haven't already been run. Nodes that have already been run have used[i] set + // to true. Allow nodes that depend on the previous node if it's a fusion pattern + // that we support (e.g. RMS_NORM + MUL). + // This first pass only grabs "real" (non-view nodes). Second pass grabs view nodes. + // The goal is to not interleave real and view nodes in a way that breaks fusion. + const int NUM_TO_CHECK = 20; + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + // Second pass grabs view nodes. + // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add). + if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) { + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (!is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); + // skip views whose srcs haven't been processed. + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !c_in_current_set) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + } + + // Push the current set into new_order + for (auto c : current_set) { + new_order.push_back(graph->nodes[c]); + used[c] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } + } + // Replace the graph with the new order. + for (int i = 0; i < graph->n_nodes; ++i) { + graph->nodes[i] = new_order[i]; + } +} + // TODO: enable async and synchronize static ggml_backend_i ggml_backend_vk_interface = { /* .get_name = */ ggml_backend_vk_name, @@ -11648,6 +12010,7 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .graph_compute = */ ggml_backend_vk_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ ggml_vk_optimize_graph, }; static ggml_guid_t ggml_backend_vk_guid() { @@ -11750,6 +12113,7 @@ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(gg static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { UNUSED(dev); + // TODO: return GGML_BACKEND_DEVICE_TYPE_IGPU for integrated GPUs return GGML_BACKEND_DEVICE_TYPE_GPU; } @@ -11757,6 +12121,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); props->type = ggml_backend_vk_device_get_type(dev); + // TODO: set props->device_id to PCI bus id ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { /* .async = */ false, @@ -12022,6 +12387,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; } + if ( + src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32 || + src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32 + ) { + return true; + } + // We can handle copying from a type to the same type if it's // contiguous (memcpy). We use f16 or f32 shaders to do the copy, // so the type/block size must be a multiple of 4. @@ -12076,10 +12448,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_ACC: case GGML_OP_CONCAT: case GGML_OP_SCALE: - return true; case GGML_OP_PAD: - return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && - (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_ROLL: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -12092,6 +12461,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_ARGMAX: case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_2D_DW: case GGML_OP_POOL_2D: @@ -12520,7 +12890,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_PAD) { - tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); + tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3], + tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]); } else if (tensor->op == GGML_OP_REPEAT) { tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); } else if (tensor->op == GGML_OP_REPEAT_BACK) { @@ -12666,6 +13037,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const bool is_2D = tensor->op_params[6] == 1; tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_IM2COL_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s1 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p1 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d1 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + + tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp new file mode 100644 index 0000000000..3b010bdeb5 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -0,0 +1,112 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "rte.comp" + +layout (push_constant) uniform parameter +{ + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +} p; + +#include "types.comp" + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint32_t i = gl_GlobalInvocationID.x; + + uint32_t nb10 = p.nb10; + uint32_t nb11 = p.nb11; + uint32_t nb12 = p.nb12; + uint32_t nb13 = p.nb13; + uint32_t s0 = p.s0; + uint32_t s1 = p.s1; + uint32_t s2 = p.s2; + uint32_t p0 = p.p0; + uint32_t p1 = p.p1; + uint32_t p2 = p.p2; + uint32_t d0 = p.d0; + uint32_t d1 = p.d1; + uint32_t d2 = p.d2; + uint32_t IW = p.IW; + uint32_t IH = p.IH; + uint32_t ID = p.ID; + uint32_t IC = p.IC; + uint32_t KW = p.KW; + uint32_t OH = p.OH; + uint32_t KD_KH_KW = p.KD_KH_KW; + uint32_t KH_KW = p.KH_KW; + uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW; + uint32_t N_OD_OH = p.N_OD_OH; + uint32_t OD_OH = p.OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW; + + if (i >= IC_KD_KH_KW) { + return; + } + + const uint32_t iic = i / KD_KH_KW; + const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const uint32_t ikw = i % KW; + + const uint32_t iow = gl_GlobalInvocationID.y; + for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) { + const uint32_t in_ = iz / OD_OH; + const uint32_t iod = (iz - in_*OD_OH) / OH; + const uint32_t ioh = iz % OH; + + const uint32_t iiw = iow * s0 + ikw * d0 - p0; + const uint32_t iih = ioh * s1 + ikh * d1 - p1; + const uint32_t iid = iod * s2 + ikd * d2 - p2; + + const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih >= IH || iiw >= IW || iid >= ID) { + data_d[offset_dst + get_doffset()] = D_TYPE(0.0f); + } else { + const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10; + data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 7e10e99e9e..f6a7761ffa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -315,21 +315,23 @@ void main() { #if LOAD_VEC_A == 8 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); - buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); - buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); - buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); - buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); + A_TYPE32 aa = A_TYPE32(data_a[idx]); + buf_a[buf_idx ] = FLOAT_TYPE(aa[0].x); + buf_a[buf_idx + 1] = FLOAT_TYPE(aa[0].y); + buf_a[buf_idx + 2] = FLOAT_TYPE(aa[0].z); + buf_a[buf_idx + 3] = FLOAT_TYPE(aa[0].w); + buf_a[buf_idx + 4] = FLOAT_TYPE(aa[1].x); + buf_a[buf_idx + 5] = FLOAT_TYPE(aa[1].y); + buf_a[buf_idx + 6] = FLOAT_TYPE(aa[1].z); + buf_a[buf_idx + 7] = FLOAT_TYPE(aa[1].w); #elif LOAD_VEC_A == 4 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); - buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); - buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); - buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); + A_TYPE32 aa = A_TYPE32(data_a[idx]); + buf_a[buf_idx ] = FLOAT_TYPE(aa.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(aa.y); + buf_a[buf_idx + 2] = FLOAT_TYPE(aa.z); + buf_a[buf_idx + 3] = FLOAT_TYPE(aa.w); #else if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); @@ -808,14 +810,19 @@ void main() { const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); - buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); - buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); - buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); - buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); +#if defined(DATA_B_BF16) + B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]); +#else + B_TYPE32 bb = B_TYPE32(data_b[idx]); +#endif + buf_b[buf_idx + 0] = FLOAT_TYPE(bb[0].x); + buf_b[buf_idx + 1] = FLOAT_TYPE(bb[0].y); + buf_b[buf_idx + 2] = FLOAT_TYPE(bb[0].z); + buf_b[buf_idx + 3] = FLOAT_TYPE(bb[0].w); + buf_b[buf_idx + 4] = FLOAT_TYPE(bb[1].x); + buf_b[buf_idx + 5] = FLOAT_TYPE(bb[1].y); + buf_b[buf_idx + 6] = FLOAT_TYPE(bb[1].z); + buf_b[buf_idx + 7] = FLOAT_TYPE(bb[1].w); #elif LOAD_VEC_B == 4 #ifdef MUL_MAT_ID const u16vec2 row_idx = row_ids[loadc_b + l]; @@ -824,10 +831,15 @@ void main() { const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x); - buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y); - buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z); - buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w); +#if defined(DATA_B_BF16) + B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]); +#else + B_TYPE32 bb = B_TYPE32(data_b[idx]); +#endif + buf_b[buf_idx + 0] = FLOAT_TYPE(bb.x); + buf_b[buf_idx + 1] = FLOAT_TYPE(bb.y); + buf_b[buf_idx + 2] = FLOAT_TYPE(bb.z); + buf_b[buf_idx + 3] = FLOAT_TYPE(bb.w); #elif !MUL_MAT_ID if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index 450b67fc55..0d81220c71 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -1,7 +1,25 @@ #version 450 #include "types.comp" -#include "generic_unary_head.comp" + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + + uint lp0; uint rp0; + uint lp1; uint rp1; + uint lp2; uint rp2; + uint lp3; uint rp3; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -19,10 +37,13 @@ void main() { const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; - const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00; const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; - const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && + i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && + i2 >= p.lp2 && i2 < p.ne12 - p.rp2 && + i3 >= p.lp3 && i3 < p.ne13 - p.rp3; data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp index 29bd77d7e1..144ea58e6f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -20,6 +20,10 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + if (row >= p.KY) { + return; + } + FLOAT_TYPE scale = p.param1; // partial sums for thread in warp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 408722c878..c2acc803f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -13,10 +13,13 @@ #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 #define A_TYPE float +#define A_TYPE32 float #elif LOAD_VEC_A == 4 #define A_TYPE vec4 +#define A_TYPE32 vec4 #elif LOAD_VEC_A == 8 #define A_TYPE mat2x4 +#define A_TYPE32 mat2x4 #endif #endif @@ -26,10 +29,13 @@ #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 #define A_TYPE float16_t +#define A_TYPE32 float #elif LOAD_VEC_A == 4 #define A_TYPE f16vec4 +#define A_TYPE32 vec4 #elif LOAD_VEC_A == 8 #define A_TYPE f16mat2x4 +#define A_TYPE32 mat2x4 #endif #endif @@ -1424,6 +1430,11 @@ float bf16_to_fp32(uint32_t u) return uintBitsToFloat(u << 16); } +vec4 bf16_to_fp32(uvec4 u) +{ + return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w)); +} + float e8m0_to_fp32(uint8_t x) { uint32_t bits; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 613498d0d5..b6570e0203 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -364,11 +364,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -384,8 +384,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE32", "vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); } } @@ -408,13 +408,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) @@ -560,10 +560,14 @@ void process_shaders() { string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -713,6 +717,10 @@ void process_shaders() { string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); + string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f1d4d584c3..17701cc437 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -815,6 +815,7 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* .graph_compute = */ ggml_backend_webgpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; /* End GGML Backend Interface */ @@ -1394,17 +1395,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; wgpu::RequestAdapterOptions options = {}; - auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, - void * userdata) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - *static_cast(userdata) = std::move(adapter); - }; - void * userdata = &ctx->adapter; ctx->instance.WaitAny( - ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX); + ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, + [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + ctx->adapter = std::move(adapter); + }), UINT64_MAX); GGML_ASSERT(ctx->adapter != nullptr); ctx->adapter.GetLimits(&ctx->limits); diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 7507a52aea..a4c51ab4e7 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -586,6 +586,7 @@ static ggml_backend_i ggml_backend_zdnn_i = { /* .graph_compute = */ ggml_backend_zdnn_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_zdnn_guid(void) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f35c337952..50dc1aa24f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3623,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows( struct ggml_tensor * a, struct ggml_tensor * b) { GGML_ASSERT(a->ne[2] == b->ne[1]); + GGML_ASSERT(a->ne[3] == b->ne[2]); GGML_ASSERT(b->ne[3] == 1); GGML_ASSERT(b->type == GGML_TYPE_I32); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index af57a88c60..8cc4ef1cf4 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -1166,50 +1166,51 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const } -struct gguf_writer { - std::vector & buf; +struct gguf_writer_base { + size_t written_bytes {0u}; - gguf_writer(std::vector & buf) : buf(buf) {} + ~gguf_writer_base(void) {} + + // we bet on devirtualization + virtual void write(int8_t val) = 0; + virtual void write(const std::vector & val) = 0; + virtual void write_tensor_data(const struct gguf_tensor_info & info, size_t offset_data, size_t alignment) = 0; template - void write(const T & val) const { + void write(const T & val) { for (size_t i = 0; i < sizeof(val); ++i) { - buf.push_back(reinterpret_cast(&val)[i]); + write(reinterpret_cast(&val)[i]); } } - void write(const std::vector & val) const { - buf.insert(buf.end(), val.begin(), val.end()); - } - - void write(const bool & val) const { + void write(const bool & val) { const int8_t val8 = val ? 1 : 0; write(val8); } - void write(const std::string & val) const { + void write(const std::string & val) { { const uint64_t n = val.length(); write(n); } for (size_t i = 0; i < val.length(); ++i) { - buf.push_back(reinterpret_cast(val.data())[i]); + write((val.data())[i]); } } - void write(const char * val) const { + void write(const char * val) { write(std::string(val)); } - void write(const enum ggml_type & val) const { + void write(const enum ggml_type & val) { write(int32_t(val)); } - void write(const enum gguf_type & val) const { + void write(const enum gguf_type & val) { write(int32_t(val)); } - void write(const struct gguf_kv & kv) const { + void write(const struct gguf_kv & kv) { const uint64_t ne = kv.get_ne(); write(kv.get_key()); @@ -1250,7 +1251,7 @@ struct gguf_writer { } } - void write_tensor_meta(const struct gguf_tensor_info & info) const { + void write_tensor_meta(const struct gguf_tensor_info & info) { write(info.t.name); const uint32_t n_dims = ggml_n_dims(&info.t); @@ -1263,14 +1264,33 @@ struct gguf_writer { write(info.offset); } - void pad(const size_t alignment) const { - while (buf.size() % alignment != 0) { + void pad(const size_t alignment) { + while (written_bytes % alignment != 0) { const int8_t zero = 0; write(zero); } } +}; - void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const { +// vector buffer based writer +struct gguf_writer_buf final : public gguf_writer_base { + std::vector & buf; + + gguf_writer_buf(std::vector & buf) : buf(buf) {} + + using gguf_writer_base::write; + + void write(const int8_t val) override { + buf.push_back(val); + written_bytes++; + } + + void write(const std::vector & val) override { + buf.insert(buf.end(), val.begin(), val.end()); + written_bytes += val.size(); + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override { GGML_ASSERT(buf.size() - offset_data == info.offset); GGML_ASSERT(ggml_is_contiguous(&info.t)); @@ -1284,14 +1304,58 @@ struct gguf_writer { GGML_ASSERT(info.t.data); memcpy(buf.data() + offset, info.t.data, nbytes); } + written_bytes += nbytes; pad(alignment); } }; -void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) { - const struct gguf_writer gw(buf); +// file based writer +struct gguf_writer_file final : public gguf_writer_base { + FILE * file; + gguf_writer_file(FILE* file) : file(file) {} + + using gguf_writer_base::write; + + void write(const int8_t val) override { + const auto real_val = static_cast(val); + const auto ret = fputc(real_val, file); + written_bytes++; + if (ret != real_val) { + throw std::runtime_error("unexpected fputc result '" + std::to_string(ret) + "' instead of '" + std::to_string((int)real_val) + "'"); + } + } + + void write(const std::vector & val) override { + const auto ret = fwrite(val.data(), 1, val.size(), file); + written_bytes += val.size(); + if (ret != val.size()) { + throw std::runtime_error("unexpected fwrite number of bytes written, '" + std::to_string(ret) + "' instead of '" + std::to_string(val.size()) + "'"); + } + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override { + GGML_ASSERT(written_bytes - offset_data == info.offset); + + GGML_ASSERT(ggml_is_contiguous(&info.t)); + const size_t nbytes = ggml_nbytes(&info.t); + + std::vector buf(nbytes); + if (info.t.buffer) { + ggml_backend_tensor_get(&info.t, buf.data(), 0, nbytes); + } else { + GGML_ASSERT(info.t.data); + memcpy(buf.data(), info.t.data, nbytes); + } + write(buf); + + pad(alignment); + } +}; + +template +static void gguf_write_out(const struct gguf_context * ctx, writer_t & gw, bool only_meta) { const int64_t n_kv = gguf_get_n_kv(ctx); const int64_t n_tensors = gguf_get_n_tensors(ctx); @@ -1321,7 +1385,7 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & bu return; } - const size_t offset_data = gw.buf.size(); + const size_t offset_data = gw.written_bytes; // write tensor data for (int64_t i = 0; i < n_tensors; ++i) { @@ -1329,6 +1393,11 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & bu } } +void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) { + gguf_writer_buf gw(buf); + gguf_write_out(ctx, gw, only_meta); +} + bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { FILE * file = ggml_fopen(fname, "wb"); @@ -1337,11 +1406,17 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo return false; } - std::vector buf; - gguf_write_to_buf(ctx, buf, only_meta); - const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size(); + try { + gguf_writer_file gw(file); + gguf_write_out(ctx, gw, only_meta); + } catch (const std::runtime_error& ex) { + GGML_LOG_ERROR("%s: failed to write GGUF data into '%s': %s\n", __func__, fname, ex.what()); + fclose(file); + return false; + } + fclose(file); - return ok; + return true; } size_t gguf_get_meta_size(const struct gguf_context * ctx) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c922bacf6e..1e88b6505b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -109,6 +109,7 @@ class Keys: POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" + DECODER_BLOCK_COUNT = "{arch}.decoder_block_count" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" SWIN_NORM = "{arch}.swin_norm" @@ -231,10 +232,11 @@ class Keys: MIDDLE_ID = "tokenizer.ggml.middle_token_id" class Adapter: - TYPE = "adapter.type" - LORA_ALPHA = "adapter.lora.alpha" - LORA_TASK_NAME = "adapter.lora.task_name" - LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix" + TYPE = "adapter.type" + LORA_ALPHA = "adapter.lora.alpha" + LORA_TASK_NAME = "adapter.lora.task_name" + LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix" + ALORA_INVOCATION_TOKENS = "adapter.alora.invocation_tokens" class IMatrix: CHUNK_COUNT = "imatrix.chunk_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a6cc8a931e..7ff12f7f57 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -676,6 +676,9 @@ class GGUFWriter: def add_decoder_start_token_id(self, id: int) -> None: self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id) + def add_decoder_block_count(self, value: int) -> None: + self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value) + def add_embedding_length_per_layer_input(self, value: int) -> None: self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value) diff --git a/include/llama.h b/include/llama.h index 11f8a363a5..453190e852 100644 --- a/include/llama.h +++ b/include/llama.h @@ -583,6 +583,10 @@ extern "C" { // Note: loaded adapters will be free when the associated model is deleted LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); + // Get the invocation tokens if the current lora is an alora + LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); + LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens (const struct llama_adapter_lora * adapter); + // The following functions operate on a llama_context, hence the naming: llama_verb_... // Add a loaded LoRA adapter to given context diff --git a/media/llama1-icon-transparent.png b/media/llama1-icon-transparent.png new file mode 100644 index 0000000000..432d6c2223 Binary files /dev/null and b/media/llama1-icon-transparent.png differ diff --git a/media/llama1-icon-transparent.svg b/media/llama1-icon-transparent.svg new file mode 100644 index 0000000000..e28203f4e8 --- /dev/null +++ b/media/llama1-icon-transparent.svg @@ -0,0 +1,77 @@ + + + + + + + + + + + + + + + + + diff --git a/media/llama1-icon.png b/media/llama1-icon.png new file mode 100644 index 0000000000..0e44672e54 Binary files /dev/null and b/media/llama1-icon.png differ diff --git a/media/llama1-icon.svg b/media/llama1-icon.svg new file mode 100644 index 0000000000..dcbe9cce9b --- /dev/null +++ b/media/llama1-icon.svg @@ -0,0 +1,87 @@ + + + + + + + + + + + + + + + + + + diff --git a/models/templates/NVIDIA-Nemotron-Nano-v2.jinja b/models/templates/NVIDIA-Nemotron-Nano-v2.jinja new file mode 100644 index 0000000000..c8ab584830 --- /dev/null +++ b/models/templates/NVIDIA-Nemotron-Nano-v2.jinja @@ -0,0 +1,162 @@ +{%- set ns = namespace(enable_thinking=true) -%} +{%- for message in messages -%} + {%- set content = message['content'] -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {%- if '/think' in content -%} + {%- set ns.enable_thinking = true -%} + {%- elif '/no_think' in content -%} + {%- set ns.enable_thinking = false -%} + {%- endif -%} + {%- endif -%} +{%- endfor -%} + +{%- if messages[0]['role'] != 'system' -%} + {%- set ns.non_tool_system_content = '' -%} + {{- 'System +' -}} +{%- else -%} + {%- set ns.non_tool_system_content = (messages[0]['content'] | default('', true)).replace('/think', '').replace('/no_think', '').strip() -%} + {{- 'System +' + ns.non_tool_system_content }} +{%- endif -%} + +{%- if tools -%} + {%- if ns.non_tool_system_content is defined and ns.non_tool_system_content != '' -%} + {{- ' + +' -}} + {%- endif -%} + {{- 'You can use the following tools to assist the user if required:' -}} + {{- ' +[' -}} + {%- for tool in tools -%} + {{- (tool.function if tool.function is defined else tool) | tojson -}} + {{- ', ' if not loop.last else '' -}} + {%- endfor -%} + {{- '] + +' -}} + {{- 'If you decide to call any tool(s), use the following format: +' -}} + {{- '[{{"name": "tool_name1", "arguments": "tool_args1"}}, ' -}} + {{- '{{"name": "tool_name2", "arguments": "tool_args2"}}] + +' -}} + {{- 'The user will execute tool-calls and return responses from tool(s) in this format: +' -}} + {{- '[{{"tool_response1"}}, {{"tool_response2"}}] + +' -}} + {{- 'Based on the tool responses, you can call additional tools if needed, correct tool calls if any errors are found, or just respond to the user.' -}} +{%- endif -%} +{{- ' + +' -}} +{%- set messages = messages[1:] if messages[0]['role'] == 'system' else messages -%} +{%- if messages[-1]['role'] == 'assistant' -%} + {%- set ns.last_turn_assistant_content = (messages[-1]['content'] | default('', true)).strip() -%} + {%- set ns.last_turn_assistant_tool_calls = messages[-1]['tool_calls'] if 'tool_calls' in messages[-1] else [] -%} + {%- set messages = messages[:-1] -%} +{%- endif -%} + +{%- for message in messages %} + {%- set content = message['content'] %} + {%- if message['role'] == 'user' -%} + {{- 'User +' + (content | default('', true)).replace('/think', '').replace('/no_think', '').strip() + ' +' }} + {%- elif message['role'] == 'tool' -%} + {%- if loop.first or (messages[loop.index0 - 1].role != 'tool') -%} + {{- 'User +' + '[' }} + {%- endif -%} + {{- message['content'] -}} + {{- ', ' if not loop.last and (messages[loop.index0 + 1].role == 'tool') else '' -}} + {%- if loop.last or (messages[loop.index0 + 1].role != 'tool') -%} + {{- ']' -}} + {%- endif -%} + {%- elif message['role'] == 'assistant' -%} + {%- if content and '' in content -%} + {%- set content = (content.split('')[1] | default('', true)).strip() %} + {%- endif -%} + {{- 'Assistant +' + ((content | default('', true)).strip() if content is not none else '') }} + {%- if message.tool_calls -%} + {%- if (content | default('', true)).strip() != '' -%} + {{- ' +' -}} + {%- endif -%} + {{- '[' -}} + {%- for call in message.tool_calls -%} + {%- set fn = call.function if call.function is defined else call -%} + {{- '{"name": "' + fn.name + '", "arguments": ' -}} + {%- if fn.arguments is string -%} + {{- fn.arguments -}} + {%- else -%} + {{- fn.arguments | tojson -}} + {%- endif -%} + {{- '}' + (', ' if not loop.last else '') -}} + {%- endfor -%} + {{- ']' -}} + {%- endif -%} + {{- ' + +' -}} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {{- 'Assistant +' -}} + {%- if ns.enable_thinking is defined and ns.enable_thinking is false -%} + {{- '' -}} + {%- else -%} + {{- ' +' -}} + {%- endif -%} + {%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%} + {{- ns.last_turn_assistant_content -}} + {%- endif -%} +{%- else -%} + {%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%} + {{- 'Assistant +' -}} + {%- if ns.enable_thinking is defined and ns.enable_thinking is false -%} + {{- '' -}} + {%- else -%} + {{- ' +' -}} + {%- endif -%} + {{- ns.last_turn_assistant_content -}} + {%- if continue_final_message is defined -%} + {%- if continue_final_message is false -%} + {{- ' + +' -}} + {%- endif -%} + {%- else -%} + {{- ' + +' -}} + {%- endif -%} + {%- endif -%} + {%- if ns.last_turn_assistant_tool_calls is defined and ns.last_turn_assistant_tool_calls | length > 0 -%} + {{- 'Assistant +' -}} + {{- '[' -}} + {%- for call in ns.last_turn_assistant_tool_calls -%} + {%- set fn = call.function if call.function is defined else call -%} + {{- '{"name": "' + fn.name + '", "arguments": ' -}} + {%- if fn.arguments is string -%} + {{- fn.arguments -}} + {%- else -%} + {{- fn.arguments | tojson -}} + {%- endif -%} + {{- '}' + (', ' if not loop.last else '') -}} + {%- endfor -%} + {{- ']' -}} + {{- ' + +' -}} + {%- endif -%} +{%- endif -%} \ No newline at end of file diff --git a/models/templates/README.md b/models/templates/README.md index 2e8eaa5953..3a649b8f4d 100644 --- a/models/templates/README.md +++ b/models/templates/README.md @@ -22,4 +22,5 @@ These templates can be updated with the following commands: ./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja ./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja ./scripts/get_chat_template.py zai-org/GLM-4.5 > models/templates/zai-org-GLM-4.5.jinja +./scripts/get_chat_template.py deepseek-ai/DeepSeek-V3.1 > models/templates/deepseek-ai-DeepSeek-V3.1.jinja ``` diff --git a/models/templates/deepseek-ai-DeepSeek-V3.1.jinja b/models/templates/deepseek-ai-DeepSeek-V3.1.jinja new file mode 100644 index 0000000000..e5656196a3 --- /dev/null +++ b/models/templates/deepseek-ai-DeepSeek-V3.1.jinja @@ -0,0 +1,3 @@ +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% if not thinking is defined %}{% set thinking = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + ' + +' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{%- set ns.is_first = false -%}{%- set ns.is_last_user = true -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}{%- if ns.is_last_user %}{{'<|Assistant|>'}}{%- endif %}{%- set ns.is_last_user = false -%}{%- set ns.is_first = false %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- else %}{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}{%- if ns.is_last_user %}{{'<|Assistant|>'}}{%- if message['prefix'] is defined and message['prefix'] and thinking %}{{''}} {%- else %}{{''}}{%- endif %}{%- endif %}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{%- set content = message['content'] -%}{%- if '' in content %}{%- set content = content.split('', 1)[1] -%}{%- endif %}{{content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_last_user = false -%}{%- set ns.is_tool = true -%}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endfor -%}{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}{{'<|Assistant|>'}}{%- if not thinking %}{{''}}{%- else %}{{''}}{%- endif %}{% endif %} \ No newline at end of file diff --git a/requirements/requirements-convert_hf_to_gguf.txt b/requirements/requirements-convert_hf_to_gguf.txt index 766745f42f..90c98c3ffe 100644 --- a/requirements/requirements-convert_hf_to_gguf.txt +++ b/requirements/requirements-convert_hf_to_gguf.txt @@ -2,7 +2,9 @@ mistral-common>=1.8.3 -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.4.0; platform_machine != "s390x" + +## Embedding Gemma requires PyTorch 2.6.0 or later +torch~=2.6.0; platform_machine != "s390x" # torch s390x packages can only be found from nightly builds --extra-index-url https://download.pytorch.org/whl/nightly diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index 859204b27e..f6076142ce 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -1,5 +1,14 @@ numpy~=1.26.4 sentencepiece~=0.2.0 -transformers>=4.45.1,<5.0.0 + +# Embedding Gemma is currently a preview release: +# https://github.com/huggingface/transformers/releases/tag/v4.56.0-Embedding-Gemma-preview + +# The version is needed to be able to convert Embedding Gemma models to GGUF format: +git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview + +# Once Embedding Gemma is officially released, we can switch to: +#transformers>=4.57.1,<5.0.0 + gguf>=0.1.0 protobuf>=4.21.0,<5.0.0 diff --git a/requirements/requirements-tool_bench.txt b/requirements/requirements-tool_bench.txt index b94521fc7f..f7912aff72 100644 --- a/requirements/requirements-tool_bench.txt +++ b/requirements/requirements-tool_bench.txt @@ -1,6 +1,6 @@ aiohttp~=3.9.3 pytest~=8.3.3 -huggingface_hub~=0.23.2 +huggingface_hub>=0.34.0,<1.0 matplotlib~=3.10.0 numpy~=1.26.4 openai~=1.55.3 diff --git a/scripts/jinja/jinja-tester.py b/scripts/jinja/jinja-tester.py new file mode 100755 index 0000000000..a489305ee7 --- /dev/null +++ b/scripts/jinja/jinja-tester.py @@ -0,0 +1,504 @@ +#!/usr/bin/env python3 +import sys +import json +import argparse +import jinja2.ext as jinja2_ext +from PySide6.QtWidgets import ( + QApplication, + QMainWindow, + QWidget, + QVBoxLayout, + QHBoxLayout, + QLabel, + QPlainTextEdit, + QTextEdit, + QPushButton, + QFileDialog, +) +from PySide6.QtGui import QColor, QColorConstants, QTextCursor, QTextFormat +from PySide6.QtCore import Qt, QRect, QSize +from jinja2 import TemplateSyntaxError +from jinja2.sandbox import ImmutableSandboxedEnvironment +from datetime import datetime + + +def format_template_content(template_content): + """Format the Jinja template content using Jinja2's lexer.""" + if not template_content.strip(): + return template_content + + env = ImmutableSandboxedEnvironment() + tc_rstrip = template_content.rstrip() + tokens = list(env.lex(tc_rstrip)) + result = "" + indent_level = 0 + i = 0 + + while i < len(tokens): + token = tokens[i] + _, token_type, token_value = token + + if token_type == "block_begin": + block_start = i + # Collect all tokens for this block construct + construct_content = token_value + end_token_type = token_type.replace("_begin", "_end") + j = i + 1 + while j < len(tokens) and tokens[j][1] != end_token_type: + construct_content += tokens[j][2] + j += 1 + + if j < len(tokens): # Found the end token + construct_content += tokens[j][2] + i = j # Skip to the end token + + # Check for control structure keywords for indentation + stripped_content = construct_content.strip() + instr = block_start + 1 + while tokens[instr][1] == "whitespace": + instr = instr + 1 + + instruction_token = tokens[instr][2] + start_control_tokens = ["if", "for", "macro", "call", "block"] + end_control_tokens = ["end" + t for t in start_control_tokens] + is_control_start = any( + instruction_token.startswith(kw) for kw in start_control_tokens + ) + is_control_end = any( + instruction_token.startswith(kw) for kw in end_control_tokens + ) + + # Adjust indentation for control structures + # For control end blocks, decrease indent BEFORE adding the content + if is_control_end: + indent_level = max(0, indent_level - 1) + + # Remove all previous whitespace before this block + result = result.rstrip() + + # Add proper indent, but only if this is not the first token + added_newline = False + if result: # Only add newline and indent if there's already content + result += ( + "\n" + " " * indent_level + ) # Use 2 spaces per indent level + added_newline = True + else: # For the first token, don't add any indent + result += "" + + # Add the block content + result += stripped_content + + # Add '-' after '%' if it wasn't there and we added a newline or indent + if ( + added_newline + and stripped_content.startswith("{%") + and not stripped_content.startswith("{%-") + ): + # Add '-' at the beginning + result = ( + result[: result.rfind("{%")] + + "{%-" + + result[result.rfind("{%") + 2 :] + ) + if stripped_content.endswith("%}") and not stripped_content.endswith( + "-%}" + ): + # Only add '-' if this is not the last token or if there's content after + if i + 1 < len(tokens) and tokens[i + 1][1] != "eof": + result = result[:-2] + "-%}" + + # For control start blocks, increase indent AFTER adding the content + if is_control_start: + indent_level += 1 + else: + # Malformed template, just add the token + result += token_value + elif token_type == "variable_begin": + # Collect all tokens for this variable construct + construct_content = token_value + end_token_type = token_type.replace("_begin", "_end") + j = i + 1 + while j < len(tokens) and tokens[j][1] != end_token_type: + construct_content += tokens[j][2] + j += 1 + + if j < len(tokens): # Found the end token + construct_content += tokens[j][2] + i = j # Skip to the end token + + # For variable constructs, leave them alone + # Do not add indent or whitespace before or after them + result += construct_content + else: + # Malformed template, just add the token + result += token_value + elif token_type == "data": + # Handle data (text between Jinja constructs) + # For data content, preserve it as is + result += token_value + else: + # Handle any other tokens + result += token_value + + i += 1 + + # Clean up trailing newlines and spaces + result = result.rstrip() + + # Copy the newline / space count from the original + if (trailing_length := len(template_content) - len(tc_rstrip)): + result += template_content[-trailing_length:] + + return result + + +# ------------------------ +# Line Number Widget +# ------------------------ +class LineNumberArea(QWidget): + def __init__(self, editor): + super().__init__(editor) + self.code_editor = editor + + def sizeHint(self): + return QSize(self.code_editor.line_number_area_width(), 0) + + def paintEvent(self, event): + self.code_editor.line_number_area_paint_event(event) + + +class CodeEditor(QPlainTextEdit): + def __init__(self): + super().__init__() + self.line_number_area = LineNumberArea(self) + + self.blockCountChanged.connect(self.update_line_number_area_width) + self.updateRequest.connect(self.update_line_number_area) + self.cursorPositionChanged.connect(self.highlight_current_line) + + self.update_line_number_area_width(0) + self.highlight_current_line() + + def line_number_area_width(self): + digits = len(str(self.blockCount())) + space = 3 + self.fontMetrics().horizontalAdvance("9") * digits + return space + + def update_line_number_area_width(self, _): + self.setViewportMargins(self.line_number_area_width(), 0, 0, 0) + + def update_line_number_area(self, rect, dy): + if dy: + self.line_number_area.scroll(0, dy) + else: + self.line_number_area.update( + 0, rect.y(), self.line_number_area.width(), rect.height() + ) + + if rect.contains(self.viewport().rect()): + self.update_line_number_area_width(0) + + def resizeEvent(self, event): + super().resizeEvent(event) + cr = self.contentsRect() + self.line_number_area.setGeometry( + QRect(cr.left(), cr.top(), self.line_number_area_width(), cr.height()) + ) + + def line_number_area_paint_event(self, event): + from PySide6.QtGui import QPainter + + painter = QPainter(self.line_number_area) + painter.fillRect(event.rect(), QColorConstants.LightGray) + + block = self.firstVisibleBlock() + block_number = block.blockNumber() + top = int( + self.blockBoundingGeometry(block).translated(self.contentOffset()).top() + ) + bottom = top + int(self.blockBoundingRect(block).height()) + + while block.isValid() and top <= event.rect().bottom(): + if block.isVisible() and bottom >= event.rect().top(): + number = str(block_number + 1) + painter.setPen(QColorConstants.Black) + painter.drawText( + 0, + top, + self.line_number_area.width() - 2, + self.fontMetrics().height(), + Qt.AlignmentFlag.AlignRight, + number, + ) + block = block.next() + top = bottom + bottom = top + int(self.blockBoundingRect(block).height()) + block_number += 1 + + def highlight_current_line(self): + extra_selections = [] + if not self.isReadOnly(): + selection = QTextEdit.ExtraSelection() + line_color = QColorConstants.Yellow.lighter(160) + selection.format.setBackground(line_color) # pyright: ignore[reportAttributeAccessIssue] + selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True) # pyright: ignore[reportAttributeAccessIssue] + selection.cursor = self.textCursor() # pyright: ignore[reportAttributeAccessIssue] + selection.cursor.clearSelection() # pyright: ignore[reportAttributeAccessIssue] + extra_selections.append(selection) + self.setExtraSelections(extra_selections) + + def highlight_position(self, lineno: int, col: int, color: QColor): + block = self.document().findBlockByLineNumber(lineno - 1) + if block.isValid(): + cursor = QTextCursor(block) + text = block.text() + start = block.position() + max(0, col - 1) + cursor.setPosition(start) + if col <= len(text): + cursor.movePosition( + QTextCursor.MoveOperation.NextCharacter, + QTextCursor.MoveMode.KeepAnchor, + ) + + extra = QTextEdit.ExtraSelection() + extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] + extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] + + self.setExtraSelections(self.extraSelections() + [extra]) + + def highlight_line(self, lineno: int, color: QColor): + block = self.document().findBlockByLineNumber(lineno - 1) + if block.isValid(): + cursor = QTextCursor(block) + cursor.select(QTextCursor.SelectionType.LineUnderCursor) + + extra = QTextEdit.ExtraSelection() + extra.format.setBackground(color.lighter(160)) # pyright: ignore[reportAttributeAccessIssue] + extra.cursor = cursor # pyright: ignore[reportAttributeAccessIssue] + + self.setExtraSelections(self.extraSelections() + [extra]) + + def clear_highlighting(self): + self.highlight_current_line() + + +# ------------------------ +# Main App +# ------------------------ +class JinjaTester(QMainWindow): + def __init__(self): + super().__init__() + self.setWindowTitle("Jinja Template Tester") + self.resize(1200, 800) + + central = QWidget() + main_layout = QVBoxLayout(central) + + # -------- Top input area -------- + input_layout = QHBoxLayout() + + # Template editor with label + template_layout = QVBoxLayout() + template_label = QLabel("Jinja2 Template") + template_layout.addWidget(template_label) + self.template_edit = CodeEditor() + template_layout.addWidget(self.template_edit) + input_layout.addLayout(template_layout) + + # JSON editor with label + json_layout = QVBoxLayout() + json_label = QLabel("Context (JSON)") + json_layout.addWidget(json_label) + self.json_edit = CodeEditor() + self.json_edit.setPlainText(""" +{ + "add_generation_prompt": true, + "bos_token": "", + "eos_token": "", + "messages": [ + { + "role": "user", + "content": "What is the capital of Poland?" + } + ] +} + """.strip()) + json_layout.addWidget(self.json_edit) + input_layout.addLayout(json_layout) + + main_layout.addLayout(input_layout) + + # -------- Rendered output area -------- + output_label = QLabel("Rendered Output") + main_layout.addWidget(output_label) + self.output_edit = QPlainTextEdit() + self.output_edit.setReadOnly(True) + main_layout.addWidget(self.output_edit) + + # -------- Render button and status -------- + btn_layout = QHBoxLayout() + + # Load template button + self.load_btn = QPushButton("Load Template") + self.load_btn.clicked.connect(self.load_template) + btn_layout.addWidget(self.load_btn) + + # Format template button + self.format_btn = QPushButton("Format") + self.format_btn.clicked.connect(self.format_template) + btn_layout.addWidget(self.format_btn) + + self.render_btn = QPushButton("Render") + self.render_btn.clicked.connect(self.render_template) + btn_layout.addWidget(self.render_btn) + main_layout.addLayout(btn_layout) + + # Status label below buttons + self.status_label = QLabel("Ready") + main_layout.addWidget(self.status_label) + + self.setCentralWidget(central) + + def render_template(self): + self.template_edit.clear_highlighting() + self.output_edit.clear() + + template_str = self.template_edit.toPlainText() + json_str = self.json_edit.toPlainText() + + # Parse JSON context + try: + context = json.loads(json_str) if json_str.strip() else {} + except Exception as e: + self.status_label.setText(f"❌ JSON Error: {e}") + return + + def raise_exception(text: str) -> str: + raise RuntimeError(text) + + env = ImmutableSandboxedEnvironment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[jinja2_ext.loopcontrols], + ) + env.filters["tojson"] = ( + lambda x, + indent=None, + separators=None, + sort_keys=False, + ensure_ascii=False: json.dumps( + x, + indent=indent, + separators=separators, + sort_keys=sort_keys, + ensure_ascii=ensure_ascii, + ) + ) + env.globals["strftime_now"] = lambda format: datetime.now().strftime(format) + env.globals["raise_exception"] = raise_exception + try: + template = env.from_string(template_str) + output = template.render(context) + self.output_edit.setPlainText(output) + self.status_label.setText("✅ Render successful") + except TemplateSyntaxError as e: + self.status_label.setText(f"❌ Syntax Error (line {e.lineno}): {e.message}") + if e.lineno: + self.template_edit.highlight_line(e.lineno, QColor("red")) + except Exception as e: + # Catch all runtime errors + # Try to extract template line number + lineno = None + tb = e.__traceback__ + while tb: + frame = tb.tb_frame + if frame.f_code.co_filename == "