diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index c6d51fb0c2..3d4f837e24 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -159,31 +159,15 @@ jobs:
       - name: Dawn Dependency
         id: dawn-depends
         run: |
-          ARTIFACTS_JSON=$(curl -s -L \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-            -H "X-GitHub-Api-Version: 2022-11-28" \
-            "https://api.github.com/repos/google/dawn/actions/artifacts")
-          echo "Finding latest macos-latest-Release artifact..."
-          DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
-            | sort_by(.created_at)
-            | reverse
-            | map(select(.name | test("macos-latest-Release$")))
-            | .[0].archive_download_url')
-          if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
-            echo "No suitable Dawn artifact found!"
-            exit 1
-          fi
-          echo "Downloading from: $DOWNLOAD_URL"
-          curl -L \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-            -o artifact.zip "$DOWNLOAD_URL"
-          unzip artifact.zip
+          DAWN_VERSION="v1.0.0"
+          DAWN_OWNER="reeselevine"
+          DAWN_REPO="dawn"
+          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
+          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
+          curl -L -o artifact.tar.gz \
+            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar_file=$(find . -name '*.tar.gz' | head -n 1)
-          echo "Extracting: $tar_file"
-          tar -xvf "$tar_file" -C dawn --strip-components=1
+          tar -xvf artifact.tar.gz -C dawn --strip-components=1

       - name: Build
         id: cmake_build
@@ -433,31 +417,15 @@ jobs:
         id: dawn-depends
         run: |
           sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
-          ARTIFACTS_JSON=$(curl -s -L \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-            -H "X-GitHub-Api-Version: 2022-11-28" \
-            "https://api.github.com/repos/google/dawn/actions/artifacts")
-          echo "Finding latest ubuntu-latest-Release artifact..."
-          DOWNLOAD_URL=$(echo "$ARTIFACTS_JSON" | jq -r '.artifacts
-            | sort_by(.created_at)
-            | reverse
-            | map(select(.name | test("ubuntu-latest-Release$")))
-            | .[0].archive_download_url')
-          if [ "$DOWNLOAD_URL" = "null" ] || [ -z "$DOWNLOAD_URL" ]; then
-            echo "No suitable Dawn artifact found!"
-            exit 1
-          fi
-          echo "Downloading from: $DOWNLOAD_URL"
-          curl -L \
-            -H "Accept: application/vnd.github+json" \
-            -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
-            -o artifact.zip "$DOWNLOAD_URL"
-          unzip artifact.zip
+          DAWN_VERSION="v1.0.0"
+          DAWN_OWNER="reeselevine"
+          DAWN_REPO="dawn"
+          DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
+          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
+          curl -L -o artifact.tar.gz \
+            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
           mkdir dawn
-          tar_file=$(find . -name '*.tar.gz' | head -n 1)
-          echo "Extracting: $tar_file"
-          tar -xvf "$tar_file" -C dawn --strip-components=1
+          tar -xvf artifact.tar.gz -C dawn --strip-components=1

       - name: Build
         id: cmake_build
diff --git a/.github/workflows/pre-tokenizer-hashes.yml b/.github/workflows/pre-tokenizer-hashes.yml
new file mode 100644
index 0000000000..dff998e239
--- /dev/null
+++ b/.github/workflows/pre-tokenizer-hashes.yml
@@ -0,0 +1,45 @@
+name: Check Pre-Tokenizer Hashes
+
+on:
+  push:
+    paths:
+      - 'convert_hf_to_gguf.py'
+      - 'convert_hf_to_gguf_update.py'
+  pull_request:
+    paths:
+      - 'convert_hf_to_gguf.py'
+      - 'convert_hf_to_gguf_update.py'
+
+jobs:
+  pre-tokenizer-hashes:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install Python dependencies
+        run: |
+          python3 -m venv .venv
+          .venv/bin/pip install -r requirements/requirements-convert_hf_to_gguf_update.txt
+
+      - name: Update pre-tokenizer hashes
+        run: |
+          cp convert_hf_to_gguf.py /tmp
+          .venv/bin/python convert_hf_to_gguf_update.py --check-missing
+
+      - name: Check if committed pre-tokenizer hashes matches generated version
+        run: |
+          if ! diff -q convert_hf_to_gguf.py /tmp/convert_hf_to_gguf.py; then
+            echo "Model pre-tokenizer hashes (in convert_hf_to_gguf.py) do not match generated hashes (from convert_hf_to_gguf_update.py)."
+            echo "To fix: run ./convert_hf_to_gguf_update.py and commit the updated convert_hf_to_gguf.py along with your changes"
+            echo "Differences found:"
+            diff convert_hf_to_gguf.py /tmp/convert_hf_to_gguf.py || true
+            exit 1
+          fi
+          echo "Model pre-tokenizer hashes are up to date."
diff --git a/README.md b/README.md
index 9b2e0f851c..954fff83da 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ LLM inference in C/C++

 ## Hot topics

+- Support for the `gpt-oss` model with native MXFP4 format has been added | [PR](https://github.com/ggml-org/llama.cpp/pull/15091) | [Collaboration with NVIDIA](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss) | [Comment](https://github.com/ggml-org/llama.cpp/discussions/15095)
 - Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
 - Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
 - VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
diff --git a/common/arg.cpp b/common/arg.cpp
index cd85311913..0f01bb3145 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -24,6 +24,7 @@
 #include 
 #include 
 #include 
+#include <list>
 #include 
 #include 
 #include 
@@ -2375,20 +2376,35 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                 }
                 throw std::invalid_argument("unknown buffer type");
             }
-            // FIXME: this leaks memory
-            params.tensor_buft_overrides.push_back({strdup(tensor_name.c_str()), buft_list.at(buffer_type)});
+            // keep strings alive and avoid leaking memory by storing them in a static vector
+            static std::list<std::string> buft_overrides;
+            buft_overrides.push_back(tensor_name);
+            params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), buft_list.at(buffer_type)});
         }
     }
     ));
     add_opt(common_arg(
-        {"--cpu-moe"},
-        "use CPU for Mixture of Experts (MoE) weights",
+        {"--cpu-moe", "-cmoe"},
+        "keep all Mixture of Experts (MoE) weights in the CPU",
         [](common_params & params) {
-
params.tensor_buft_overrides.push_back({"\\.ffn_up_exps\\.weight$", ggml_backend_cpu_buffer_type()}); - params.tensor_buft_overrides.push_back({"\\.ffn_down_exps\\.weight$", ggml_backend_cpu_buffer_type()}); - params.tensor_buft_overrides.push_back({"\\.ffn_gate_exps\\.weight$", ggml_backend_cpu_buffer_type()}); + params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()}); } ).set_env("LLAMA_ARG_CPU_MOE")); + add_opt(common_arg( + {"--n-cpu-moe", "-ncmoe"}, "N", + "keep the Mixture of Experts (MoE) weights of the first N layers in the CPU", + [](common_params & params, int value) { + if (value < 0) { + throw std::invalid_argument("invalid value"); + } + for (int i = 0; i < value; ++i) { + // keep strings alive and avoid leaking memory by storing them in a static vector + static std::list buft_overrides; + buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i)); + params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()}); + } + } + ).set_env("LLAMA_ARG_N_CPU_MOE")); add_opt(common_arg( {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", "number of layers to store in VRAM", @@ -2647,6 +2663,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.n_out_freq = value; } ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--output-format"}, "{gguf,dat}", + string_format("output format for imatrix file (default: %s)", params.imat_dat > 0 ? "dat" : "gguf"), + [](common_params & params, const std::string & value) { + /**/ if (value == "gguf") { params.imat_dat = -1; } + else if (value == "dat") { params.imat_dat = 1; } + else { throw std::invalid_argument("invalid output format"); } + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); add_opt(common_arg( {"--save-frequency"}, "N", string_format("save an imatrix copy every N iterations (default: %d)", params.n_save_freq), @@ -2922,11 +2947,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" "- none: leaves thoughts unparsed in `message.content`\n" "- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n" - "(default: deepseek)", + "(default: auto)", [](common_params & params, const std::string & value) { /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; } else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; } else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; } + else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; } else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK")); diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 18a30e49aa..96ba8f533e 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -55,7 +55,15 @@ bool common_chat_msg_parser::add_tool_call(const std::string & name, const std:: bool common_chat_msg_parser::add_tool_call(const json & tool_call) { std::string name = tool_call.contains("name") ? tool_call.at("name") : ""; std::string id = tool_call.contains("id") ? tool_call.at("id") : ""; - std::string arguments = tool_call.contains("arguments") ? 
tool_call.at("arguments") : ""; + std::string arguments = ""; + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else { + arguments = tool_call.at("arguments"); + } + } + return add_tool_call(name, id, arguments); } diff --git a/common/chat.cpp b/common/chat.cpp index 0c777d7a78..316bd24170 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -126,6 +126,8 @@ std::vector common_chat_msg_diff::compute_diffs(const comm typedef minja::chat_template common_chat_template; struct common_chat_templates { + bool add_bos; + bool add_eos; bool has_explicit_template; // Model had builtin template or template overridde was specified. std::unique_ptr template_default; // always set (defaults to chatml) std::unique_ptr template_tool_use; @@ -143,6 +145,8 @@ struct templates_params { bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); json extra_context; + bool add_bos; + bool add_eos; }; common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { @@ -445,6 +449,8 @@ std::string common_chat_format_single( common_chat_templates_inputs inputs; inputs.use_jinja = use_jinja; + inputs.add_bos = tmpls->add_bos; + inputs.add_eos = tmpls->add_eos; std::string fmt_past_msg; if (!past_msg.empty()) { @@ -469,6 +475,8 @@ std::string common_chat_format_single( std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) { common_chat_templates_inputs inputs; inputs.use_jinja = use_jinja; + inputs.add_bos = tmpls->add_bos; + inputs.add_eos = tmpls->add_eos; auto add_simple_msg = [&](auto role, auto content) { common_chat_msg msg; msg.role = role; @@ -546,6 +554,8 @@ common_chat_templates_ptr common_chat_templates_init( } std::string token_bos = bos_token_override; std::string token_eos = eos_token_override; + bool add_bos = false; + bool add_eos = false; if (model) { const auto * vocab = llama_model_get_vocab(model); const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { @@ -560,9 +570,13 @@ common_chat_templates_ptr common_chat_templates_init( }; token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); + add_bos = llama_vocab_get_add_bos(vocab); + add_eos = llama_vocab_get_add_eos(vocab); } common_chat_templates_ptr tmpls(new common_chat_templates()); tmpls->has_explicit_template = has_explicit_template; + tmpls->add_bos = add_bos; + tmpls->add_eos = add_eos; try { tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); } catch (const std::exception & e) { @@ -592,6 +606,8 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; + case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; + case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; default: throw std::runtime_error("Unknown chat format"); } @@ -600,8 +616,10 @@ const char * common_chat_format_name(common_chat_format format) { const char * common_reasoning_format_name(common_reasoning_format format) { switch (format) { case COMMON_REASONING_FORMAT_NONE: return "none"; + case COMMON_REASONING_FORMAT_AUTO: return "auto"; case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; case 
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; + case COMMON_REASONING_FORMAT_GRANITE: return "granite"; default: throw std::runtime_error("Unknown reasoning format"); } @@ -748,10 +766,10 @@ static std::string apply( // instead of using `chat_template_options.use_bos_token = false`, since these tokens // may be needed inside the template / between messages too. auto result = tmpl.apply(tmpl_inputs, tmpl_opts); - if (string_starts_with(result, tmpl.bos_token())) { + if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) { result = result.substr(tmpl.bos_token().size()); } - if (string_ends_with(result, tmpl.eos_token())) { + if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) { result = result.substr(0, result.size() - tmpl.eos_token().size()); } return result; @@ -1289,6 +1307,26 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { tool_calls_end); } +static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + auto prompt = apply(tmpl, inputs); + + data.prompt = prompt; + data.format = COMMON_CHAT_FORMAT_GPT_OSS; + + // TODO: support tool calls in GPT-OSS? + + return data; +} +static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { + // TODO @ngxson : this won't work with --special enabled, we should fix that + builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>"); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } +} + static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; @@ -1646,7 +1684,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { "|" // match 5 (function name again) ); - if (auto res = builder.try_find_regex(open_regex)) { + while (auto res = builder.try_find_regex(open_regex)) { const auto & block_start = res->groups[1]; std::string block_end = block_start.empty() ? 
"" : "```"; @@ -1668,7 +1706,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { builder.consume_literal(block_end); builder.consume_spaces(); } - builder.add_content(builder.consume_rest()); } else { throw common_chat_msg_partial_exception("failed to parse tool call"); } @@ -1693,7 +1730,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { builder.consume_spaces(); } } - builder.add_content(builder.consume_rest()); + } + } + + builder.add_content(builder.consume_rest()); +} + +static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // Pass thinking context for Granite template + json additional_context = { + {"thinking", inputs.enable_thinking}, + }; + + data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context); + data.format = COMMON_CHAT_FORMAT_GRANITE; + + if (string_ends_with(data.prompt, "\n") || string_ends_with(data.prompt, "")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + if (!inputs.tools.is_null()) { + // Granite uses <|tool_call|> followed by JSON list + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name + +"-args", { + {"type", "object"}, + {"properties", { + {"name", {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + }))); + }); + + auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")); + auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\""); + + if (data.thinking_forced_open) { + builder.add_rule("root", "\"\" space \"\" space [^<]* \"\" space \"<|tool_call|>\" space " + tool_list); + } else { + builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list); + } + + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + "<|tool_call|>" + }); + + data.preserved_tokens = { + "", + "", + "", + "", + "<|tool_call|>", + }; + }); + } else { + // Handle thinking tags for non-tool responses + if (data.thinking_forced_open && inputs.enable_thinking) { + data.grammar_lazy = false; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + builder.add_rule("root", "\"\" space \"\" space .* \"\" space"); + }); + data.preserved_tokens = { + "", + "", + "", + "", + }; + } + } + + return data; +} + +static void common_chat_parse_granite(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("", ""); + + // Parse response tags using regex + static const common_regex response_regex("([\\s\\S]*?)"); + if (auto res = builder.try_find_regex(response_regex)) { + // Extract the content between the tags (capture group 1) + auto content = builder.str(res->groups[1]); + builder.add_content(content); + builder.move_to(res->groups[0].end); + } + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + 
return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("<|tool_call|>")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + if (!builder.add_tool_calls(tool_calls_data.json)) { + builder.add_content("<|tool_call|>" + tool_calls_data.json.dump()); + } + } else { + builder.add_content("<|tool_call|>" + tool_calls_data.json.dump()); } } else { builder.add_content(builder.consume_rest()); @@ -1733,6 +1887,8 @@ static common_chat_params common_chat_templates_apply_jinja( params.enable_thinking = inputs.enable_thinking; params.grammar = inputs.grammar; params.now = inputs.now; + params.add_bos = inputs.add_bos; + params.add_eos = inputs.add_eos; params.extra_context = json::object(); for (auto el : inputs.chat_template_kwargs) { @@ -1769,11 +1925,21 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_command_r7b(tmpl, params); } + // Granite (IBM) - detects thinking / tools support + if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) { + return common_chat_params_init_granite(tmpl, params); + } + // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) if (src.find("") != std::string::npos && params.json_schema.is_null()) { return common_chat_params_init_hermes_2_pro(tmpl, params); } + // GPT-OSS + if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) { + return common_chat_params_init_gpt_oss(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -1824,6 +1990,7 @@ static common_chat_params common_chat_templates_apply_legacy( int alloc_size = 0; std::vector chat; std::vector contents; + for (const auto & msg : inputs.messages) { auto content = msg.content; for (const auto & part : msg.content_parts) { @@ -1925,6 +2092,12 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_COMMAND_R7B: common_chat_parse_command_r7b(builder); break; + case COMMON_CHAT_FORMAT_GRANITE: + common_chat_parse_granite(builder); + break; + case COMMON_CHAT_FORMAT_GPT_OSS: + common_chat_parse_gpt_oss(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index ca807c145e..eb628d8bc2 100644 --- a/common/chat.h +++ b/common/chat.h @@ -109,6 +109,8 @@ enum common_chat_format { COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_HERMES_2_PRO, COMMON_CHAT_FORMAT_COMMAND_R7B, + COMMON_CHAT_FORMAT_GRANITE, + COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; @@ -127,6 +129,8 @@ struct common_chat_templates_inputs { bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); std::map chat_template_kwargs; + bool add_bos = false; + bool add_eos = false; }; struct common_chat_params { diff --git a/common/common.h b/common/common.h index b8b01a7e99..5eab199af5 100644 --- a/common/common.h +++ b/common/common.h @@ -236,8 +236,10 @@ struct common_params_diffusion { enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, + COMMON_REASONING_FORMAT_AUTO, COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // 
Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in tags in stream mode COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. + COMMON_REASONING_FORMAT_GRANITE, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. }; struct common_params { @@ -394,7 +396,7 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; int reasoning_budget = -1; bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response @@ -439,6 +441,7 @@ struct common_params { int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations int32_t i_chunk = 0; // start processing from this chunk + int8_t imat_dat = 0; // whether the legacy imatrix.dat format should be output (gguf <= 0 < dat) bool process_output = false; // collect data for the output tensor bool compute_ppl = true; // whether to compute perplexity diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index feef03d1ce..b8c7d97a78 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -678,6 +678,9 @@ class TextModel(ModelBase): if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": # ref: https://huggingface.co/THUDM/glm-4-9b-hf res = "glm4" + if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": + # ref: https://huggingface.co/zai-org/GLM-4.5-Air + res = "glm4" if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 res = "minerva-7b" @@ -702,6 +705,9 @@ class TextModel(ModelBase): if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890": # ref: https://huggingface.co/moonshotai/Kimi-K2-Base res = "kimi-k2" + if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c": + # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B + res = "qwen2" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -849,6 +855,9 @@ class TextModel(ModelBase): if chkhsh == "2085e1638f6c377a0aa4ead21b27bb4cb941bf800df86ed391011769c1758dfb": # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B res = "exaone4" + if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756": + # ref: https://huggingface.co/JetBrains/Mellum-4b-base + res = "mellum" if res is None: logger.warning("\n") @@ -3319,7 +3328,13 @@ class Qwen25OmniModel(Qwen2VLVisionModel): @ModelBase.register("InternVisionModel") class InternVisionModel(MmprojModel): def set_gguf_parameters(self): + assert self.hparams_vision is not None + if isinstance(self.hparams_vision['image_size'], list): + self.hparams_vision['image_size'] = self.hparams_vision['image_size'][0] + if isinstance(self.hparams_vision['patch_size'], list): + self.hparams_vision['patch_size'] = self.hparams_vision['patch_size'][0] super().set_gguf_parameters() + hparams = self.hparams self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL) 
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) @@ -3343,14 +3358,30 @@ class InternVisionModel(MmprojModel): return gguf.GGMLQuantizationType.F32 return False + def _mapping_interns1_name(self, name): + names_map = { + "model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias", + "model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight", + "model.multi_modal_projector.linear_1.bias": "mlp1.1.bias", + "model.multi_modal_projector.linear_1.weight": "mlp1.1.weight", + "model.multi_modal_projector.linear_2.bias": "mlp1.3.bias", + "model.multi_modal_projector.linear_2.weight": "mlp1.3.weight", + } + if name in names_map: + name = names_map[name] + return name + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - if name.startswith("vision_model") or name.startswith("mlp"): + vision_prefix = ['vision_model', 'mlp', 'model.vision_tower', 'model.multi_modal_projector'] + # deal with intern-s1 special case + name = self._mapping_interns1_name(name) + if any([name.startswith(prefix) for prefix in vision_prefix]): # process visual tensors # correct name if name.startswith("vision_model"): name = "vision_tower." + name - if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"): + if (".ls" in name or ".lambda_" in name or "position_embedding" in name) and not name.endswith(".weight"): name += ".weight" # split QKV tensors if needed if ".qkv." in name: @@ -3436,6 +3467,10 @@ class Qwen2MoeModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately + name = name.replace("language_model.", "") # InternVL + if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"): + # skip visual tensors + return [] if name.find("experts") != -1: n_experts = self.hparams["num_experts"] assert bid is not None @@ -3489,6 +3524,85 @@ class Qwen3Model(Qwen2Model): class Qwen3MoeModel(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.QWEN3MOE + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + hparams = ModelBase.load_hparams(self.dir_model) + self.origin_hf_arch = hparams.get('architectures', [None])[0] + + def set_vocab(self): + # deal with intern-s1 + if self.origin_hf_arch == 'InternS1ForConditionalGeneration': + self._set_vocab_interns1() + return + + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def _set_vocab_interns1(self): + tokens: list[str] = [] + toktypes: list[int] = [] + + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab()) + vocab_size = self.hparams.get("vocab_size", len(vocab)) + assert max(vocab.values()) < vocab_size + + tokpre = self.get_vocab_base_pre(tokenizer) + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()} + added_vocab = tokenizer.get_added_vocab() + + added_tokens_decoder = tokenizer.added_tokens_decoder + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token: str = reverse_vocab[i] + if token in added_vocab: + # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized. 
+ # To avoid unexpected issues - we make sure to normalize non-normalized tokens + if not added_tokens_decoder[i].normalized: + previous_token = token + token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) + if previous_token != token: + logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") + + if added_tokens_decoder[i].special or self.does_token_look_special(token): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + toktypes.append(gguf.TokenType.NORMAL) + tokens.append(token) + + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_tokens_map_file = self.dir_model / 'special_tokens_map.json' + additional_special_tokens = [] + if special_tokens_map_file.is_file(): + with open(special_tokens_map_file, encoding = 'utf-8') as f: + additional_special_tokens = json.load(f).get('additional_special_tokens', []) + tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json' + if tokenizer_cfg_file.is_file(): + with open(tokenizer_cfg_file, encoding = 'utf-8') as f: + added_tokens_decoder = json.load(f).get('added_tokens_decoder', {}) + token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']} + for token in additional_special_tokens: + if token in token2ids_map: + special_vocab._set_special_token(token, token2ids_map[token]) + special_vocab._set_special_token('eos', 151645) + special_vocab._set_special_token("bos", 151643) + special_vocab.add_to_gguf(self.gguf_writer) + @ModelBase.register("GPT2LMHeadModel") class GPT2Model(TextModel): @@ -6056,6 +6170,7 @@ class DeepseekModel(TextModel): @ModelBase.register("DeepseekV2ForCausalLM") @ModelBase.register("DeepseekV3ForCausalLM") +@ModelBase.register("KimiVLForConditionalGeneration") class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -6158,6 +6273,13 @@ class DeepseekV2Model(TextModel): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # skip vision tensors and remove "language_model." 
for Kimi-VL + if "vision_tower" in name or "multi_modal_projector" in name: + return [] + + if name.startswith("language_model."): + name = name.replace("language_model.", "") + # rename e_score_correction_bias tensors if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") @@ -6682,6 +6804,139 @@ class Glm4Model(TextModel): return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Glm4MoeForCausalLM") +class Glm4MoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.GLM4_MOE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # GLM4_MOE has num_hidden_layers + 1 actual layers (including NextN layer) + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + # Special tokens + # Note: Using <|endoftext|> (151329) for eot causes endless generation + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + + # Patch broken chat template + if isinstance(special_vocab.chat_template, str) and "visible_text(m.content).endswith" in special_vocab.chat_template: + special_vocab.chat_template = special_vocab.chat_template.replace( + """{{ visible_text(m.content) }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}}""", + """{% set content = visible_text(m.content) %}{{ content }}\n{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not content.endswith("/nothink")) else '' -}}""") + + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = ( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) + self.gguf_writer.add_rope_dimension_count( + int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) + ) + + # MoE parameters - Use only routed expert count (shared experts handled separately) + if (n_routed_experts := self.hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_routed_experts) + if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + if (n_shared_experts := self.hparams.get("n_shared_experts")) is not None: + self.gguf_writer.add_expert_shared_count(n_shared_experts) + if (first_k_dense_replace := self.hparams.get("first_k_dense_replace")) is not None: + self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace) + + # Expert gating function (sigmoid for GLM4_MOE) + 
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + # Routed scaling factor + if (routed_scaling_factor := self.hparams.get("routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + + # Normalise topk probabilities + if (norm_topk_prob := self.hparams.get("norm_topk_prob")) is not None: + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) + + # NextN/MTP prediction layers + if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + if name.startswith("model.visual."): # ignore visual part + return [] + elif name.startswith("model.language_model."): + name = name.replace("language_model.", "") # for multimodal variants + + # Handle main token embedding (but not layer-specific NextN embeddings) + if name == "model.embed_tokens.weight" and ".layers." not in name: + return [(self.map_tensor_name("token_embd.weight"), data_torch)] + + # Handle routed experts + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + new_name = self.map_tensor_name(name) + + return [(new_name, data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") class ChatGLMModel(TextModel): model_arch = gguf.MODEL_ARCH.CHATGLM @@ -7800,6 +8055,130 @@ class SmolLM3Model(LlamaModel): self.gguf_writer.add_chat_template(chat_template) +@ModelBase.register("GptOssForCausalLM") +class GptOssModel(TextModel): + model_arch = gguf.MODEL_ARCH.GPT_OSS + + def transform_nibble_layout(self, tensor): + assert tensor.dtype == torch.uint8 + assert tensor.shape[-1] == 16 + # swap nibbles + t_lo = tensor & 0x0F + t_hi = tensor & 0xF0 + t_swapped = (t_lo << 4) | (t_hi >> 4) + tensor = t_swapped + # transform aaaa...bbbb... to abababab... 
+ blk_a, blk_b = tensor.chunk(2, dim=-1) + # get a_ + blk_a0 = (blk_a & 0xF0).view(-1, 1) + blk_a1 = (blk_a << 4).view(-1, 1) + blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape) + # get _b + blk_b0 = (blk_b >> 4).view(-1, 1) + blk_b1 = (blk_b & 0x0F).view(-1, 1) + blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape) + # swap once more + out = blk_a | blk_b + out_h = out & 0xF0 + out_l = out & 0x0F + out = (out_h >> 4) | (out_l << 4) + return out + + def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor): + assert blocks.dtype == torch.uint8 + assert scales.dtype == torch.uint8 + scales = scales.unsqueeze(-1) + assert len(blocks.shape) == 4 + assert len(scales.shape) == 4 + blocks = self.transform_nibble_layout(blocks) + new_data = torch.concat((scales, blocks), dim=-1) + new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32] + logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4") + # flatten last dim + new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3]) + new_data = new_data.numpy() + self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + blocks0: Tensor = torch.zeros(1) + blocks1: Tensor = torch.zeros(1) + # we assume that tensors are loaded in the correct order + for name, data_torch in self.get_tensors(): + if "mlp.experts.down_proj_blocks" in name: + blocks0 = data_torch + elif "mlp.experts.down_proj_scales" in name: + new_name = self.map_tensor_name(name.replace("_scales", ".weight")) + self.repack_mxfp4(new_name, blocks0, data_torch) + elif "mlp.experts.gate_up_proj_blocks" in name: + blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :] + elif "mlp.experts.gate_up_proj_scales" in name: + scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :] + new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight")) + new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight")) + self.repack_mxfp4(new_name_gate, blocks0, scales0) + self.repack_mxfp4(new_name_up, blocks1, scales1) + return [] + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if "sinks" in name: + name += ".weight" + + # correct naming for down_proj + if "down_proj" in name: + if name.endswith("_bias"): + name = name.replace("down_proj_bias", "down_proj.bias") + elif "_blocks" not in name and "_scales" not in name: + logger.warning(f"{name} is not in MXFP4, performance may be degraded") + name = name.replace("down_proj", "down_proj.weight") + data_torch = data_torch.transpose(-1, -2) + else: + # otherwise, it should already be repacked to ggml MXFP4 format + return [] + + # split the gate_up into gate and up + if "gate_up_proj" in name: + if name.endswith("_bias"): + name_up = name.replace("gate_up_proj_bias", "up_proj.bias") + name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias") + gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2] + return [ + (self.map_tensor_name(name_gate), gate_proj_bias), + (self.map_tensor_name(name_up), up_proj_bias) + ] + elif "_blocks" not in name and "_scales" not in name: + logger.warning(f"{name} is not in MXFP4, performance may be degraded") + name_up = name.replace("gate_up_proj", "up_proj.weight") + name_gate = name.replace("gate_up_proj", 
"gate_proj.weight") + data_torch = data_torch.transpose(-1, -2) + gate_proj_weight, up_proj_weight = data_torch[:, ::2, :], data_torch[:, 1::2, :] + return [ + (self.map_tensor_name(name_gate), gate_proj_weight), + (self.map_tensor_name(name_up), up_proj_weight) + ] + else: + # otherwise, it should already be repacked to ggml MXFP4 format + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"]) + + rope_scaling = self.hparams.get("rope_scaling") or {} + rope_type = rope_scaling.get("rope_type", rope_scaling.get("type")) + assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}" + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096)) + + @ModelBase.register("Lfm2ForCausalLM") @ModelBase.register("LFM2ForCausalLM") class LFM2Model(TextModel): @@ -7939,6 +8318,7 @@ class LazyTorchTensor(gguf.LazyBase): _dtype_map: dict[torch.dtype, type] = { torch.float16: np.float16, torch.float32: np.float32, + torch.uint8: np.uint8, } # used for safetensors slices diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index c4904b5393..575e05e193 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -59,6 +59,10 @@ parser.add_argument( "--full", action="store_true", help="download full list of models - make sure you have access to all of them", ) +parser.add_argument( + "--check-missing", action="store_true", + help="only check for missing pre-tokenizer hashes", +) parser.add_argument( "hf_token", help="optional HF token", @@ -70,6 +74,10 @@ hf_token = args.hf_token if args.hf_token is not None else hf_token if hf_token is None: logger.warning("HF token not found. You can provide it as an argument or set it in ~/.cache/huggingface/token") +if args.check_missing and args.full: + logger.warning("Downloading full list of models requested, ignoring --check-missing!") + args.check_missing = False + # TODO: this string has to exercise as much pre-tokenizer functionality as possible # will be updated with time - contributions welcome CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? 
We\'Ve a\'lL' @@ -130,6 +138,7 @@ models = [ {"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", }, {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"}, {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", }, + {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, ] # some models are known to be broken upstream, so we will skip them as exceptions @@ -138,6 +147,7 @@ pre_computed_hashes = [ {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"}, {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, {"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"}, @@ -147,6 +157,7 @@ pre_computed_hashes = [ {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"}, {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"}, {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"}, + {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"}, ] @@ -221,12 +232,13 @@ if not args.full: all_models = models.copy() models = [model for model in all_models if model["name"] not in existing_models] -logging.info(f"Downloading {len(models)} models...") -for model in models: - try: - download_model(model) - except Exception as e: - logger.error(f"Failed to download model {model['name']}. Error: {e}") +if not args.check_missing: + logging.info(f"Downloading {len(models)} models...") + for model in models: + try: + download_model(model) + except Exception as e: + logger.error(f"Failed to download model {model['name']}. 
Error: {e}") # generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function: diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 89e5588931..81ce0f0057 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -39,8 +39,9 @@ if (WIN32) set(CMAKE_SHARED_MODULE_PREFIX "") endif() -option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT}) -option(GGML_BACKEND_DL "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF) +option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT}) +option(GGML_BACKEND_DL "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF) +set(GGML_BACKEND_DIR "" CACHE PATH "ggml: directory to load dynamic backends from (requires GGML_BACKEND_DL") # # option list @@ -175,6 +176,7 @@ option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF) option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) +option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF) option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF) option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index 2322c6cd9d..91c9d5cd34 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -125,54 +125,56 @@ if(NOT TARGET ggml::ggml) IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") set(_ggml_all_targets "") - foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) - string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}") - string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx) + if (NOT GGML_BACKEND_DL) + foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) + string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}") + string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx) - find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend} - REQUIRED - HINTS ${GGML_LIB_DIR} - NO_CMAKE_FIND_ROOT_PATH) + find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend} + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) - message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}") + message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}") - add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED) - set_target_properties(ggml::${_ggml_backend} - PROPERTIES - INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}" - IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" - IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}" - INTERFACE_COMPILE_FEATURES c_std_90 - POSITION_INDEPENDENT_CODE ON) - - string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}") - if(is_cpu_variant) - list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") - set_target_properties(ggml::${_ggml_backend} - PROPERTIES - INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}") - - if(GGML_CPU_INTERFACE_LINK_OPTIONS) - set_target_properties(ggml::${_ggml_backend} - PROPERTIES - INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}") - endif() - - else() - list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") + add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED) set_target_properties(ggml::${_ggml_backend} PROPERTIES - INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}") + 
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}" + INTERFACE_COMPILE_FEATURES c_std_90 + POSITION_INDEPENDENT_CODE ON) - if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS) + string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}") + if(is_cpu_variant) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}") + + if(GGML_CPU_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}") + endif() + + else() + list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") set_target_properties(ggml::${_ggml_backend} PROPERTIES - INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}") - endif() - endif() + INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}") - list(APPEND _ggml_all_targets ggml::${_ggml_backend}) - endforeach() + if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}") + endif() + endif() + + list(APPEND _ggml_all_targets ggml::${_ggml_backend}) + endforeach() + endif() list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}") set_target_properties(ggml::ggml diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 8a8775be36..2f06e1e39b 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -304,6 +304,16 @@ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ GGML_TENSOR_LOCALS(size_t, nb, dst, nb) +#define GGML_TENSOR_TERNARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \ + GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + #define GGML_TENSOR_BINARY_OP_LOCALS01 \ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ @@ -395,7 +405,8 @@ extern "C" { // GGML_TYPE_IQ4_NL_4_4 = 36, // GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_8_8 = 38, - GGML_TYPE_COUNT = 39, + GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) + GGML_TYPE_COUNT = 40, }; // precision @@ -430,6 +441,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors + GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors }; // available tensor operations: @@ -438,6 +450,7 @@ extern "C" { GGML_OP_DUP, GGML_OP_ADD, + GGML_OP_ADD_ID, GGML_OP_ADD1, GGML_OP_ACC, GGML_OP_SUB, @@ -557,6 +570,7 @@ extern "C" { GGML_GLU_OP_REGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_SWIGLU_OAI, GGML_GLU_OP_GEGLU_ERF, GGML_GLU_OP_GEGLU_QUICK, @@ -831,6 +845,13 @@ extern "C" { struct ggml_tensor * b, enum ggml_type type); + // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]] + GGML_API struct ggml_tensor * ggml_add_id( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * ids); + GGML_API struct ggml_tensor * ggml_add1( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1198,6 +1219,13 @@ extern "C" { struct ggml_tensor * a, struct 
ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_swiglu_oai( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float alpha, + float limit); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, @@ -1570,6 +1598,10 @@ extern "C" { float scale, float max_bias); + GGML_API void ggml_soft_max_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks); + GGML_API struct ggml_tensor * ggml_soft_max_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, @@ -2052,6 +2084,10 @@ extern "C" { GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec( const struct ggml_tensor * a); + GGML_API void ggml_flash_attn_ext_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index b50e8e2955..937aa2c620 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -214,6 +214,13 @@ add_library(ggml ggml-backend-reg.cpp) add_library(ggml::ggml ALIAS ggml) +if (GGML_BACKEND_DIR) + if (NOT GGML_BACKEND_DL) + message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL") + endif() + target_compile_definitions(ggml PUBLIC GGML_BACKEND_DIR="${GGML_BACKEND_DIR}") +endif() + target_link_libraries(ggml PUBLIC ggml-base) if (CMAKE_SYSTEM_NAME MATCHES "Linux") @@ -227,7 +234,11 @@ function(ggml_add_backend_library backend) set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL) add_dependencies(ggml ${backend}) - install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR}) + if (GGML_BACKEND_DIR) + install(TARGETS ${backend} LIBRARY DESTINATION ${GGML_BACKEND_DIR}) + else() + install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() else() add_library(${backend} ${ARGN}) target_link_libraries(ggml PUBLIC ${backend}) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index fcc552da51..8b6e602836 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_DIAG_MASK_ZERO: case GGML_OP_DIAG_MASK_INF: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_SUB: case GGML_OP_MUL: diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index a1cddb806c..1256bb1424 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -505,6 +505,9 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, std::vector search_paths; if (user_search_path == nullptr) { +#ifdef GGML_BACKEND_DIR + search_paths.push_back(fs::u8path(GGML_BACKEND_DIR)); +#endif // default search paths: executable directory, current directory search_paths.push_back(get_executable_path()); search_paths.push_back(fs::current_path()); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index eaf41e5a6c..1b9d29e911 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1071,6 +1071,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg } } } + // if the node is still unassigned, assign it to the first backend that supports it + for (int b = 0; b < sched->n_backends && *cur_backend_id == -1; b++) { + ggml_backend_sched_set_if_supported(sched, node, b, cur_backend_id); + } + 
GGML_ASSERT(*cur_backend_id != -1); } // pass 5: split graph, find tensors that need to be copied @@ -1098,7 +1103,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg const int node_backend_id = tensor_backend_id(node); - assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback + GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback // check if we should start a new split based on the sources of the current node bool need_new_split = false; @@ -1156,7 +1161,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg size_t src_id = hash_id(src); const int src_backend_id = sched->hv_tensor_backend_ids[src_id]; - assert(src_backend_id != -1); // all inputs should be assigned by now + GGML_ASSERT(src_backend_id != -1); // all inputs should be assigned by now if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) { if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) { diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index ec158dfac6..aeac2e5744 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -281,10 +281,10 @@ ggml_backend_t ggml_backend_blas_init(void) { ggml_backend_blas_context * ctx = new ggml_backend_blas_context; ggml_backend_t backend = new ggml_backend { - /* .guid = */ ggml_backend_blas_guid(), - /* .interface = */ blas_backend_i, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0), - /* .context = */ ctx, + /* .guid = */ ggml_backend_blas_guid(), + /* .iface = */ blas_backend_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0), + /* .context = */ ctx, }; #if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP) diff --git a/ggml/src/ggml-cann/CMakeLists.txt b/ggml/src/ggml-cann/CMakeLists.txt index 7742b39153..aee5e7b06e 100755 --- a/ggml/src/ggml-cann/CMakeLists.txt +++ b/ggml/src/ggml-cann/CMakeLists.txt @@ -31,6 +31,13 @@ string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}") set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}") string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION) message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}") +option(USE_ACL_GRAPH "Enable CANN graph execution (ACL graph mode)" OFF) + +if(USE_ACL_GRAPH AND (SOC_TYPE_MAJOR_SN STREQUAL "310P" OR SOC_TYPE_COMPILE_OPTION STREQUAL "ASCEND_310P")) + message(FATAL_ERROR + "CANN Graph (ACL graph mode) is not supported on 310P devices. " + "Please build with -DUSE_ACL_GRAPH=OFF or use a supported SOC.") +endif() if (CANN_INSTALL_DIR) # Only Support Linux. 
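+    # Illustrative configure invocation (flags other than USE_ACL_GRAPH are the pre-existing CANN build options, not part of this change):
+    #   cmake -B build -DGGML_CANN=ON -DUSE_ACL_GRAPH=ON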
@@ -68,6 +75,13 @@ if (CANN_INSTALL_DIR) target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}") + if (USE_ACL_GRAPH) + target_compile_definitions(ggml-cann PRIVATE USE_ACL_GRAPH) + message(STATUS "CANN: USE_ACL_GRAPH is enabled.") + else() + message(STATUS "CANN: USE_ACL_GRAPH is disabled.") + endif() + message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}") message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}") else() diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 8dfe3b061c..9d294f72b6 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -337,6 +337,29 @@ private: int32_t device_; }; +#ifdef USE_ACL_GRAPH +struct ggml_graph_node_properties { + void * node_address; + ggml_op node_op; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; +}; + +struct ggml_cann_graph { + ~ggml_cann_graph() { + if (graph != nullptr) { + aclmdlRIDestroy(graph); + } + } + + aclmdlRI graph = nullptr; + + std::vector ggml_graph_properties; +}; +#endif // USE_ACL_GRAPH + /** * @brief Context for managing CANN backend operations. */ @@ -345,8 +368,13 @@ struct ggml_backend_cann_context { std::string name; /**< Name of the device. */ std::string description; /**< Description of the device. */ aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ +#ifdef USE_ACL_GRAPH + /// Cached CANN ACL graph used for executing the current ggml computation graph. + std::unique_ptr cann_graph; +#endif cann_task_queue task_queue; bool async_mode; + bool support_set_rows; aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */ @@ -362,6 +390,14 @@ struct ggml_backend_cann_context { async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or("")); GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF"); + + support_set_rows = parse_bool(get_env("LLAMA_SET_ROWS").value_or("")); + GGML_LOG_INFO("%s: LLAMA_SET_ROWS is %s\n", __func__, support_set_rows ? "ON" : "OFF"); + + if (!support_set_rows) { + GGML_LOG_INFO("%s: CANN Graph currently only supports execution when LLAMA_SET_ROWS is ON. " + "Falling back to eager mode.\n", __func__); + } } /** diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 8eb8b1470b..cf575b3675 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2075,6 +2075,160 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) { ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); } +#ifdef USE_ACL_GRAPH +/** + * @brief Populate the internal CANN graph node properties from the ggml computation graph. + * + * This function copies all node attributes (operation type, dimensions, strides, input sources, + * and operation parameters) into the cached CANN graph structure for later reuse or comparison. + * + * @param cann_ctx The CANN backend context. + * @param cgraph The ggml computational graph. 
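+ *
+ * One property record is kept per node, so that a later call to
+ * is_cann_graph_update_required() can detect changes in topology, shapes or buffer addresses.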
+ */ +static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { + for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { + ggml_tensor * node = cgraph->nodes[node_idx]; + cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data; + cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op; + + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim]; + cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim]; + } + for (int src = 0; src < GGML_MAX_SRC; src++) { + cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] = + node->src[src] ? node->src[src]->data : nullptr; + } + memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS); + } +} + +/** + * @brief Check if a ggml tensor node matches a previously captured CANN graph node. + * + * This function compares all relevant fields (address, op type, shape, source inputs, op params) + * to determine whether the current node matches a previously recorded version. + * + * @param node The current ggml tensor node. + * @param graph_node_properties The stored properties of a CANN graph node. + * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise. + */ +static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { + if (node->data != graph_node_properties->node_address && + node->op != GGML_OP_VIEW) { + return false; + } + if (node->op != graph_node_properties->node_op) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (node->ne[i] != graph_node_properties->ne[i]) { + return false; + } + if (node->nb[i] != graph_node_properties->nb[i]) { + return false; + } + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i] && + node->src[i]->data != graph_node_properties->src_address[i] && + node->op != GGML_OP_VIEW + ) { + return false; + } + } + if (node->op == GGML_OP_SCALE && + memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + return false; + } + return true; +} + +/** + * @brief Determine if the CANN graph needs to be rebuilt due to graph changes. + * + * This checks whether the number or properties of ggml graph nodes have changed + * compared to the last captured CANN graph. If so, the CANN graph must be re-captured. + * + * @param cann_ctx The CANN backend context. + * @param cgraph The current ggml computation graph. + * @return true if an update is required; false otherwise. + */ +static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { + // The number of nodes is different, so the graph needs to be reconstructed. + if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes); + return true; + } + + // The number of nodes is the same; iterate over each node to check whether they match. + for (int i = 0; i < cgraph->n_nodes; i++) { + bool has_matching_properties = ggml_graph_node_has_matching_properties( + cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]); + if(!has_matching_properties) { + return true; + } + } + return false; +} +#endif // USE_ACL_GRAPH + +/** + * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API. 
+ * + * If CANN graph execution is enabled and graph capture is required, this function begins + * graph capture, runs the graph, ends capture, and stores the captured graph. + * + * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher. + * + * @param cann_ctx The CANN backend context. + * @param cgraph The ggml computation graph. + * @param use_cann_graph Whether to use CANN graph execution. + * @param cann_graph_update_required Whether graph capture is needed due to graph changes. + */ +static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph, + bool & use_cann_graph, bool & cann_graph_update_required) { +#ifdef USE_ACL_GRAPH + if (use_cann_graph && cann_graph_update_required) { + if (cann_ctx->cann_graph->graph != nullptr) { + ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph)); + cann_ctx->cann_graph->graph = nullptr; + } + ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL)); + } +#endif // USE_ACL_GRAPH + + // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. + // With the use of CANN graphs, the execution will be performed by the graph launch. + if (!use_cann_graph || cann_graph_update_required) { + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + bool ok = ggml_cann_compute_forward(*cann_ctx, node); + if (!ok) { + GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + } + } + +#ifdef USE_ACL_GRAPH + if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture + ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph)); + } + + if (use_cann_graph) { + // Execute graph + ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream())); + } +#endif // USE_ACL_GRAPH +} + + /** * @brief Computes a computational graph using a CANN backend. * @@ -2091,27 +2245,38 @@ static enum ggml_status ggml_backend_cann_graph_compute( ggml_backend_t backend, ggml_cgraph* cgraph) { ggml_backend_cann_context* cann_ctx = (ggml_backend_cann_context*)backend->context; - ggml_cann_set_device(cann_ctx->device); - //release temp buffer create by set tensor. 
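+    // release temporary buffers created by set_tensor before evaluating the graph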
release_nz_workspace(); +#ifdef USE_ACL_GRAPH + bool use_cann_graph = true; + bool cann_graph_update_required = false; - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor* node = cgraph->nodes[i]; - - if (ggml_is_empty(node) || node->op == GGML_OP_NONE) { - continue; - } - - bool ok = ggml_cann_compute_forward(*cann_ctx, node); - - if (!ok) { - GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, - node->name, ggml_op_name(node->op)); - } - GGML_ASSERT(ok); + // check environment LLAMA_SET_ROWS + if (!cann_ctx->support_set_rows) { + use_cann_graph = false; } + if (use_cann_graph) { + if (cann_ctx->cann_graph == nullptr) { + cann_ctx->cann_graph.reset(new ggml_cann_graph()); + cann_graph_update_required = true; + } + + cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph); + set_ggml_graph_node_properties(cann_ctx, cgraph); + } +#else + bool use_cann_graph = false; + bool cann_graph_update_required = false; +#endif // USE_ACL_GRAPH + + evaluate_and_capture_cann_graph( + cann_ctx, + cgraph, + use_cann_graph, + cann_graph_update_required + ); + return GGML_STATUS_SUCCESS; } @@ -2226,12 +2391,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // only support F32 and F16. return false; } - - if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) { - // unsupport dst is not contiguous. - return false; - } - return true; } break; case GGML_OP_CONT: { @@ -2340,6 +2499,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, memcpy(&bias, (float*)op->op_params + 1, sizeof(float)); return bias == 0.0f; // TODO: support bias != 0.0f case GGML_OP_SOFT_MAX: + // TODO: support attention sinks [TAG_ATTN_SINKS] + if (op->src[2]) { + return false; + } // TODO: support broadcast // ref: https://github.com/ggml-org/llama.cpp/pull/14435 return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); @@ -2354,6 +2517,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){ return false; } + // TODO: support attention sinks [TAG_ATTN_SINKS] + if (op->src[4]) { + return false; + } if (op->src[1]->ne[0] != op->src[2]->ne[0]) { // different head sizes of K and V are not supported yet return false; diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index fbb04426ab..93ab7ea446 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -99,6 +99,9 @@ typedef sycl::half2 ggml_half2; #define QI4_1 (QK4_1 / (4 * QR4_1)) #define QR4_1 2 +#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4)) +#define QR_MXFP4 2 + #define QI5_0 (QK5_0 / (4 * QR5_0)) #define QR5_0 2 @@ -184,6 +187,13 @@ typedef struct { } block_q4_1; static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); +#define QK_MXFP4 32 +typedef struct { + uint8_t e; // E8M0 + uint8_t qs[QK_MXFP4/2]; +} block_mxfp4; +static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding"); + #define QK5_0 32 typedef struct { ggml_half d; // delta @@ -1074,10 +1084,17 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, GGML_TABLE_END() +// TODO: fix name to kvalues_iq4_nl GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16) -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, GGML_TABLE_END() +// e2m1 values (doubled) +// ref: 
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf +GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16) + 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12, +GGML_TABLE_END() + #define NGRID_IQ1S 2048 #define IQ1S_DELTA 0.125f #define IQ1M_DELTA 0.125f diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index f02cfe8fa5..b62e3158ee 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -13,6 +13,7 @@ #define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0 #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 +#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -68,6 +69,7 @@ #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K +#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -90,6 +92,7 @@ #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K +#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -120,6 +123,7 @@ #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K +#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -149,6 +153,7 @@ #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K +#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -179,6 +184,7 @@ #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K +#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index c6d1d852e9..aadbb487ec 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -589,6 +589,67 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = sumf; } +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * 
GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_mxfp4); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + uint8x16x2_t q4bits; + int8x16x4_t q4b; + int8x16x4_t q8b; + int32x4_t prod_1; + int32x4_t prod_2; + + for (; ib + 1 < nb; ib += 2) { + q4bits.val[0] = vld1q_u8(x[ib + 0].qs); + q4bits.val[1] = vld1q_u8(x[ib + 1].qs); + q8b.val[0] = vld1q_s8(y[ib + 0].qs); + q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); + q8b.val[2] = vld1q_s8(y[ib + 1].qs); + q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); + + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + sumf += + GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + + GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); + } + +#endif + for (; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); + int sumi1 = 0; + int sumi2 = 0; + for (int j = 0; j < QK_MXFP4/2; ++j) { + sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 30fd59f702..cb49320a67 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) { } #if defined(__AVX2__) || defined(__AVX512F__) +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = _mm256_sign_epi8(x, x); + const __m256i sy = _mm256_sign_epi8(y, x); + return _mm256_maddubs_epi16(ax, sy); +} + // spread 32 bits to 32 bytes { 0x00, 0xFF } static inline __m256i bytes_from_bits_32(const uint8_t * x) { uint32_t x32; @@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)), _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0))); } + +static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) { + return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), + _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); +} #endif #elif defined(__SSSE3__) // horizontally add 4x4 floats @@ -746,6 +757,91 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi 
#endif } +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + +#if defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4); + const __m128i m4b = _mm_set1_epi8(0x0f); + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs); + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs); + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); + const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); + accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)), + _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)), + _mm256_cvtepi32_ps(p_2), accum2); + } + + sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); + +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + + const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1); + const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); + +#endif + for (; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); + int sumi1 = 0; + int sumi2 = 0; + 
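+        // each packed byte holds two 4-bit indices into kvalues_mxfp4: the low nibble pairs with the
+        // first half of the q8 block, the high nibble with the second half; d already folds in the
+        // halved E8M0 scale, which compensates for the doubled e2m1 table values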
for (int j = 0; j < QK_MXFP4/2; ++j) { + sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -3206,14 +3302,6 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined(__AVX2__) -static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { - const __m256i ax = _mm256_sign_epi8(x, x); - const __m256i sy = _mm256_sign_epi8(y, x); - return _mm256_maddubs_epi16(ax, sy); -} -#endif - void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index c5271b7757..d89cd8f4ef 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -253,6 +253,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, }, + [GGML_TYPE_MXFP4] = { + .from_float = quantize_row_mxfp4, + .vec_dot = ggml_vec_dot_mxfp4_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q2_K] = { .from_float = quantize_row_q2_K, .vec_dot = ggml_vec_dot_q2_K_q8_K, @@ -1670,6 +1676,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_add(params, tensor); } break; + case GGML_OP_ADD_ID: + { + ggml_compute_forward_add_id(params, tensor); + } break; case GGML_OP_ADD1: { ggml_compute_forward_add1(params, tensor); @@ -1924,7 +1934,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_FLASH_ATTN_EXT: { - ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + ggml_compute_forward_flash_attn_ext(params, tensor); } break; case GGML_OP_FLASH_ATTN_BACK: { @@ -2111,6 +2121,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_DUP: case GGML_OP_CONT: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_ACC: { @@ -2172,6 +2183,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: { @@ -2673,6 +2685,7 @@ struct ggml_cplan ggml_graph_plan( } } break; case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: { if (ggml_is_quantized(node->src[0]->type)) { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index c9daa4c39e..8dacd36714 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -35,7 +35,7 @@ // ggml-backend interface -std::vector& ggml_backend_cpu_get_extra_buffers_type() { +std::vector & ggml_backend_cpu_get_extra_buffer_types() { static std::vector bufts = []() { std::vector bufts; @@ -57,8 +57,6 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type } #endif - bufts.push_back(NULL); - return bufts; }(); @@ -66,14 +64,20 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type } static ggml_backend_buffer_type_t * 
ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) { - return ggml_backend_cpu_get_extra_buffers_type().data(); + static std::vector extra_bufts = [] { + std::vector bufts = ggml_backend_cpu_get_extra_buffer_types(); + bufts.push_back(nullptr); + return bufts; + }(); + + return extra_bufts.data(); GGML_UNUSED(device); } static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) { - for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) { - if (extra && extra == buft) { + for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) { + if (extra == buft) { return true; } } @@ -210,10 +214,10 @@ ggml_backend_t ggml_backend_cpu_init(void) { ctx->abort_callback_data = NULL; ggml_backend_t cpu_backend = new ggml_backend { - /* .guid = */ ggml_backend_cpu_guid(), - /* .interface = */ ggml_backend_cpu_i, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), - /* .context = */ ctx, + /* .guid = */ ggml_backend_cpu_guid(), + /* .iface = */ ggml_backend_cpu_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ ctx, }; if (cpu_backend == NULL) { @@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st return true; } - // extra_buffer_op? - for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { - if (extra) { - auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context; - if (buf_extra && buf_extra->supports_op(dev, op)) { - return true; - } - } - } - - // the other case need host buffer. - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) { - return false; + // check extra buffer types + // note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary + for (int i = 0; i < 4; i++) { + if (op->src[i] && op->src[i]->buffer && + ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) { + auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context; + return buf_extra->supports_op(dev, op); } } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6581d27add..854f1c2b49 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8,6 +8,7 @@ #include "vec.h" #include +#include // ggml_compute_forward_dup @@ -1283,6 +1284,7 @@ void ggml_compute_forward_add( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1309,6 +1311,77 @@ void ggml_compute_forward_add( } } +// ggml_compute_forward_add_id + +static void ggml_compute_forward_add_id_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_TERNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + 
dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + // src1 indices + const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21); + + GGML_ASSERT(i11 >= 0 && i11 < ne11); + + ggml_vec_add_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + (float *) ((char *) src1->data + i11*nb11)); + } +} + +void ggml_compute_forward_add_id( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_id_f32(params, dst); + } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type)); + } + } +} + // ggml_compute_forward_add1 static void ggml_compute_forward_add1_f32( @@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu( } } +// ggml_compute_forward_swiglu_oai + +static void ggml_compute_forward_swiglu_oai_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + const float alpha = ggml_get_op_params_f32(dst, 2); + const float limit = ggml_get_op_params_f32(dst, 3); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1])); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 
0 : nc; + } + + for (int k = 0; k < nc; k++) { + const float x = std::min(src0_p[k], limit); + const float y = std::clamp(src1_p[k], -limit, limit); + const float out_glu = x / (1.f + expf(alpha * (-x))); + dst_p[k] = out_glu * (y + 1.f); + } + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = dst_p[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu_oai( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_swiglu_oai_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_geglu_erf static void ggml_compute_forward_geglu_erf_f32( @@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4873,6 +5036,7 @@ void ggml_compute_forward_set( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32( const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); @@ -5557,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32( const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + // sinks + const float * sk = src2 ? 
(float *)((char *) src2->data) : nullptr; + for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { @@ -5599,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32( float max = -INFINITY; ggml_vec_max_f32(ne00, &max, wp); + // if we have sinks, make a correction as if they were included in the softmax + if (sk) { + max = MAX(max, sk[i02]); + } + ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max); assert(sum > 0.0); + if (sk) { + sum += (ggml_float) expf(sk[i02] - max); + } + sum = 1.0/sum; ggml_vec_scale_f32(ne00, dp, sum); @@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: + case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -7989,12 +8168,14 @@ void ggml_compute_forward_argsort( static void ggml_compute_forward_flash_attn_ext_f16( const ggml_compute_params * params, - const ggml_tensor * q, - const ggml_tensor * k, - const ggml_tensor * v, - const ggml_tensor * mask, ggml_tensor * dst) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) GGML_TENSOR_LOCALS(size_t, nbq, q, nb) GGML_TENSOR_LOCALS(int64_t, nek, k, ne) @@ -8189,6 +8370,23 @@ static void ggml_compute_forward_flash_attn_ext_f16( } } + // sinks + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + ggml_vec_scale_f32(DV, VKQ32, ms); + } else { + vs = expf(s - M); + } + + S = S*ms + vs; + } + // V /= S const float S_inv = 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); @@ -8208,17 +8406,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( void ggml_compute_forward_flash_attn_ext( const ggml_compute_params * params, - const ggml_tensor * q, - const ggml_tensor * k, - const ggml_tensor * v, - const ggml_tensor * mask, ggml_tensor * dst) { switch (dst->op_params[3]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { // uses F32 accumulators - ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + ggml_compute_forward_flash_attn_ext_f16(params, dst); } break; default: { @@ -9080,6 +9274,10 @@ void ggml_compute_forward_glu( { ggml_compute_forward_swiglu(params, dst); } break; + case GGML_GLU_OP_SWIGLU_OAI: + { + ggml_compute_forward_swiglu_oai(params, dst); + } break; case GGML_GLU_OP_GEGLU_ERF: { ggml_compute_forward_geglu_erf(params, dst); diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 3a32ec20db..f154afb462 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -29,6 +29,7 @@ extern "C" { void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -82,13 +83,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru void 
ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); -void ggml_compute_forward_flash_attn_ext( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const struct ggml_tensor * mask, - struct ggml_tensor * dst); +void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_flash_attn_back( const struct ggml_compute_params * params, const bool masked, diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index ee35ab42fd..365cb36d2d 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI quantize_row_q8_1_ref(x, y, k); } +void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_mxfp4_ref(x, y, k); +} + // // 2-6 bit quantization in super-blocks // @@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c *s = sumf; } +void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + + for (; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); + + int sumi1 = 0; + int sumi2 = 0; + for (int j = 0; j < QK_MXFP4/2; ++j) { + sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index dc4342c87f..d83eb1b144 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -19,6 +19,8 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -39,6 +41,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, 
size_t by, int nrc); void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -67,8 +71,12 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-cpu/traits.cpp b/ggml/src/ggml-cpu/traits.cpp index 139fa59641..4f32f10255 100644 --- a/ggml/src/ggml-cpu/traits.cpp +++ b/ggml/src/ggml-cpu/traits.cpp @@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {} } // namespace ggml::cpu bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) { - for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { + for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) { if (extra && extra->context) { auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context; auto tensor_traits = buf_extra->get_tensor_traits(op); @@ -23,7 +23,7 @@ bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct } bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) { - for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { + for (auto extra : ggml_backend_cpu_get_extra_buffer_types()) { if (extra && extra->context) { auto buf_extra = 
(ggml::cpu::extra_buffer_type *) extra->context; auto tensor_traits = buf_extra->get_tensor_traits(op); diff --git a/ggml/src/ggml-cpu/traits.h b/ggml/src/ggml-cpu/traits.h index 99a6186b1d..f4e0990ddf 100644 --- a/ggml/src/ggml-cpu/traits.h +++ b/ggml/src/ggml-cpu/traits.h @@ -33,6 +33,6 @@ class extra_buffer_type { } // namespace ggml::cpu // implemented in ggml-cpu.cpp. -std::vector & ggml_backend_cpu_get_extra_buffers_type(); +std::vector & ggml_backend_cpu_get_extra_buffer_types(); #endif diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index d18783a00a..2250d93cb0 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -55,7 +55,22 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } + +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { + int i = 0; +#if defined(__AVX2__) + for (; i + 7 < n; i += 8) { + __m256 vx = _mm256_loadu_ps(x + i); + __m256 vy = _mm256_loadu_ps(y + i); + __m256 vz = _mm256_add_ps(vx, vy); + _mm256_storeu_ps(z + i, vz); + } +#endif + for (; i < n; ++i) { + z[i] = x[i] + y[i]; + } +} + inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { for (int i = 0; i < n; ++i) { z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i])); @@ -992,9 +1007,9 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { for (int i = 0; i < n; ++i) { - float v = GGML_CPU_FP16_TO_FP32(x[i]); - float w = GGML_CPU_FP16_TO_FP32(g[i]); - y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w); + float xi = GGML_CPU_FP16_TO_FP32(x[i]); + float gi = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi); } } diff --git a/ggml/src/ggml-cuda/add-id.cu b/ggml/src/ggml-cuda/add-id.cu new file mode 100644 index 0000000000..8bed62ac9d --- /dev/null +++ b/ggml/src/ggml-cuda/add-id.cu @@ -0,0 +1,58 @@ +#include "add-id.cuh" + +static __global__ void add_id_kernel( + const float * src0, const float * src1, const int32_t * src2, float * dst, + int64_t ne0, int64_t ne1, + size_t nb01, size_t nb02, + size_t nb11, + size_t nb21 + ) { + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.y; + + const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21); + + const size_t nb1 = ne0 * sizeof(float); + const size_t nb2 = ne1 * nb1; + + float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2); + const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02); + const float * src1_row = (const float *)((char *)src1 + i11*nb11); + + for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + dst_row[i0] = src0_row[i0] + src1_row[i0]; + } +} + +void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + GGML_TENSOR_TERNARY_OP_LOCALS + + 
GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb20 == sizeof(int32_t)); + + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + const int32_t * src2_d = (const int32_t *)src2->data; + float * dst_d = (float *)dst->data; + + int threads = std::min((int)ne00, 768); // cols + dim3 blocks(ne01, ne02); // n_experts_used, n_tokens + add_id_kernel<<>>( + src0_d, src1_d, src2_d, dst_d, + ne0, ne1, + nb01, nb02, + nb11, + nb21 + ); +} diff --git a/ggml/src/ggml-cuda/add-id.cuh b/ggml/src/ggml-cuda/add-id.cuh new file mode 100644 index 0000000000..30b1721ac3 --- /dev/null +++ b/ggml/src/ggml-cuda/add-id.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 7fb04d51b7..2e5d48797f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1,6 +1,7 @@ #pragma once #include "ggml.h" +#include "ggml-impl.h" #include "ggml-cuda.h" #include @@ -232,9 +233,13 @@ typedef float2 dfloat2; #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING -#define NEW_MMA_AVAILABLE +#define TURING_MMA_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define AMPERE_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #define CP_ASYNC_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE @@ -302,10 +307,14 @@ static bool amd_mfma_available(const int cc) { } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. 
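+// turing_mma_available() and ampere_mma_available() below are the host-side counterparts of the
+// TURING_MMA_AVAILABLE and AMPERE_MMA_AVAILABLE macros defined earlier in this header.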
-static bool new_mma_available(const int cc) { +static bool turing_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; } +static bool ampere_mma_available(const int cc) { + return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; +} + static bool cp_async_available(const int cc) { return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; } @@ -549,6 +558,24 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #endif // defined(GGML_USE_HIP) } +static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { +#if CUDART_VERSION >= 12080 + const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); + return (float) e; +#else + uint32_t bits; + if (x == 0) { + bits = 0x00400000; + } else { + bits = (uint32_t) x << 23; + } + + float result; + memcpy(&result, &bits, sizeof(float)); + return result; +#endif // CUDART_VERSION >= 12050 +} + typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); static __device__ __forceinline__ float get_alibi_slope( @@ -607,6 +634,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI8_0; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_MXFP4; + static constexpr int qr = QR_MXFP4; + static constexpr int qi = QI_MXFP4; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 15c927861f..e3beddbc1b 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst } } +template +static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4); + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[ib].qs + 4*il; + const float d = ggml_cuda_e8m0_to_fp32(x[ib].e); + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f; + y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f; + } +} + template static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, @@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq4_xs<<>>(vx, y); } +template +static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_mxfp4<<>>(vx, y); +} + template static __global__ void convert_unary( const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, @@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; case GGML_TYPE_F32: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq4_xs_cuda; case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; + case GGML_TYPE_MXFP4: + 
return dequantize_row_mxfp4_cuda; case GGML_TYPE_F16: return convert_unary_cont_cuda; case GGML_TYPE_BF16: diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index b6db446c6f..e46f0e2081 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, @@ -736,7 +737,8 @@ void launch_fattn( GGML_ASSERT(V || is_mla); - const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; ggml_tensor * KQV = dst; @@ -940,6 +942,7 @@ void launch_fattn( K_data, V_data, mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *) sinks->data) : nullptr, KV_max.ptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index a86b95428f..39731baaeb 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float * const __restrict__ KQ_max, float * const __restrict__ KQ_rowsum, const int kb0) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE typedef fattn_mma_f16_config c; #ifdef CP_ASYNC_AVAILABLE @@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); GGML_UNUSED(kb0); GGML_UNUSED(tile_Q); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } template @@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const half2 * const __restrict__ K_h2, const half2 * const __restrict__ V_h2, const half2 * const __restrict__ mask_h2, + const float * const __restrict__ sinks_f, float2 * const __restrict__ dstk, float2 * const __restrict__ dstk_fixup, const float scale, @@ -800,7 +801,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jt, const int kb0_start, const int kb0_stop) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. typedef fattn_mma_f16_config c; @@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } + // If attention sinks are used, potentially re-scale if KQ_max is small. + // Also add the sink as a value to KQ_rowsum; this is done after the synchronization of KQ_rowsum, + // so it is done unconditionally for every thread. + if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented"); + const int jc = ntiles == 1 ?
2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col); + const float sink = sinks_f[jc % ncols2]; + + const float KQ_max_new = fmaxf(KQ_max[col], sink); + const float KQ_max_diff = KQ_max[col] - KQ_max_new; + KQ_max_scale[col] = expf(KQ_max_diff); + KQ_max[col] = KQ_max_new; + + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + + const float KQ_max_add = expf(sink - KQ_max_new); + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; + } + + if (ntiles == 1) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { + VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; + } + } + } + } + } + // Combine VKQ accumulator values if np > 1. // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. @@ -1196,7 +1243,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } template @@ -1206,6 +1253,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, @@ -1222,7 +1270,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) +#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1267,20 +1315,24 @@ static __global__ void flash_attn_ext_f16( // kb0 == k start index when in the output tile. int kb0_start = kbc % iter_k; int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == iter_k) { const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. + const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. 
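// NOTE: kbc is a flattened work-unit index of the form ((sequence*(ne02/ncols2) + zt)*iter_j + jt)*iter_k + kb0,
// so the integer divisions above recover the sequence, the head group zt and the tile index jt. head0 = zt*ncols2
// below is the first of the ncols2 query heads handled by one CUDA block; those heads map to the same K/V head,
// hence the head0 / gqa_ratio indexing for K and V.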
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); + const int head0 = zt * ncols2; + + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1293,12 +1345,12 @@ static __global__ void flash_attn_ext_f16( if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1314,18 +1366,21 @@ static __global__ void flash_attn_ext_f16( } const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. + const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 + const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2)); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio)); + const int head0 = zt * ncols2; + + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1); - float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio)); + const half2 * V_h2 = mla ? 
K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; int kb0_stop_kernel = kb0_stop * kb_niter; @@ -1337,10 +1392,10 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile - (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, + (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); @@ -1352,7 +1407,7 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) +#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE) } template diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 9d0b24ae7e..0fcfaa32ea 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, @@ -272,7 +273,7 @@ static __global__ void flash_attn_tile_ext_f16( } } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index be72f76fb6..23550cbbd9 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, @@ -37,7 +38,7 @@ static __global__ void flash_attn_tile_ext_f32( return; #endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); diff --git 
a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index a2df2f66be..b05f682cd3 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, @@ -61,7 +62,8 @@ static __global__ void flash_attn_vec_ext_f16( K += nb13*sequence + nb12*(head / gqa_ratio); V += nb23*sequence + nb22*(head / gqa_ratio); - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const float * sinksf = (const float *) (sinks); const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -75,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16( half2 * KQ2 = (half2 *) KQ; half kqmax[ncols]; + half kqsum[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { kqmax[j] = -HALF_MAX_HALF; + kqsum[j] = 0.0f; } - half kqsum[ncols] = {0.0f}; __shared__ half kqmax_shared[ncols][WARP_SIZE]; __shared__ half kqsum_shared[ncols][WARP_SIZE]; @@ -283,6 +286,39 @@ static __global__ void flash_attn_vec_ext_f16( __syncthreads(); } + if (sinksf && blockIdx.y == 0) { + const half sink = __float2half(sinksf[head]); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const half val = hexp(sink - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale; + + if (tid == 0) { + kqsum[j] += val; + } + + VKQ[j] *= __half2half2(KQ_max_scale); + } + + __syncthreads(); + } + #pragma unroll for (int j = 0; j < ncols; ++j) { kqsum[j] = warp_reduce_sum((float)kqsum[j]); @@ -313,7 +349,7 @@ static __global__ void flash_attn_vec_ext_f16( dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 9ab0fc133b..d6d0bfb744 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, @@ -72,7 +73,8 @@ static __global__ void flash_attn_vec_ext_f32( K += nb13*sequence + nb12*(head / gqa_ratio); V += nb23*sequence + nb22*(head / gqa_ratio); - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const float * sinksf = (const 
float *) (sinks); const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); @@ -88,11 +90,12 @@ static __global__ void flash_attn_vec_ext_f32( } float kqmax[ncols]; + float kqsum[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { kqmax[j] = -FLT_MAX/2.0f; + kqsum[j] = 0.0f; } - float kqsum[ncols] = {0.0f}; __shared__ float kqmax_shared[ncols][WARP_SIZE]; __shared__ float kqsum_shared[ncols][WARP_SIZE]; @@ -279,6 +282,39 @@ static __global__ void flash_attn_vec_ext_f32( __syncthreads(); } + if (sinksf && blockIdx.y == 0) { + const float sink = sinksf[head]; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const float val = expf(sink - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale; + + if (tid == 0) { + kqsum[j] += val; + } + + VKQ[j] *= KQ_max_scale; + } + + __syncthreads(); + } + #pragma unroll for (int j = 0; j < ncols; ++j) { kqsum[j] = warp_reduce_sum(kqsum[j]); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 40554204b6..93d4d81021 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16( const char * __restrict__ K, const char * __restrict__ V, const char * __restrict__ mask, + const char * __restrict__ sinks, const int * __restrict__ KV_max, float * __restrict__ dst, float2 * __restrict__ dst_meta, @@ -423,7 +424,7 @@ static __global__ void flash_attn_ext_f16( dst_meta[j_dst_unrolled] = dst_meta_val; } #else - GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); + GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index a51136f6b8..6c1185deac 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -269,17 +269,28 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS] + if (sinks && !fp16_mma_available(cc)) { + if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); + } + return; + } + #if 
defined(GGML_HIP_ROCWMMA_FATTN) if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); @@ -315,8 +326,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; - const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && - (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion; + const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192); + const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion && + (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000); const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { @@ -328,7 +340,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: - if (fp16_mma_available(cc) && !new_mma_available(cc)) { + if (fp16_mma_available(cc) && !turing_mma_available(cc)) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); return; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5179279467..d9110491ec 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4,6 +4,7 @@ #include "ggml-cuda/common.cuh" #include "ggml-cuda/acc.cuh" +#include "ggml-cuda/add-id.cuh" #include "ggml-cuda/arange.cuh" #include "ggml-cuda/argmax.cuh" #include "ggml-cuda/argsort.cuh" @@ -21,8 +22,9 @@ #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" +#include "ggml-cuda/mmf.cuh" #include "ggml-cuda/mmq.cuh" -#include "ggml-cuda/mmv.cuh" +#include "ggml-cuda/mmvf.cuh" #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" @@ -1852,6 +1854,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct ggml_cuda_pool_alloc src0_alloc(ctx.pool()); ggml_cuda_pool_alloc src1_alloc(ctx.pool()); + bool is_src0_cont_2 = ggml_is_contiguous_2(src0); + bool is_src1_cont_2 = ggml_is_contiguous_2(src1); + // Handle src0 src0_ptr = (const cuda_t *) src0->data; @@ -1870,6 +1875,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct s11 = ne10; s12 = ne11*s11; s13 = ne12*s12; + + is_src1_cont_2 = true; } // Setup destination buffer @@ -1918,15 +1925,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; - if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) { + // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3: + const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00; + const int64_t smb = ne12 == 1 ? 
s13 : s12; + // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx CUBLAS_CHECK( cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA - src1_ptr, cu_data_type_b, s11, s12, // strideB - beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC + alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA + src1_ptr, cu_data_type_b, s11, smb, // strideB + beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC ne12*ne13, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); @@ -1998,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; - bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) + bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + bool use_mul_mat_f = !ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 @@ -2018,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = ggml_cuda_info().devices[id].warp_size; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]); + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } } else { const int cc = ggml_cuda_info().devices[ctx.device].cc; + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]); + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } @@ -2038,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); //TODO update for generic tensor parallelism - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16); bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc); bool use_batched_cublas_f32 = 
src0->type == GGML_TYPE_F32; - if (!split && use_mul_mat_vec) { + if (!split && use_mul_mat_vec_f) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) - ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst); + ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst); + } else if (!split && use_mul_mat_f) { + ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst); } else if (!split && use_mul_mat_vec_q) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst); } else if (!split && use_mul_mat_q) { @@ -2055,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // general KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); - } else if (use_mul_mat_vec) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr); + } else if (use_mul_mat_vec_f) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr); } else if (use_mul_mat_vec_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { @@ -2084,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * if (ggml_is_quantized(src0->type)) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); } else { - ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst); + ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); } return; } @@ -2250,6 +2269,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ADD1: // TODO: more efficient implementation ggml_cuda_op_add(ctx, dst); break; + case GGML_OP_ADD_ID: + ggml_cuda_op_add_id(ctx, dst); + break; case GGML_OP_SUB: ggml_cuda_op_sub(ctx, dst); break; @@ -2324,6 +2346,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_GLU_OP_SWIGLU: ggml_cuda_op_swiglu(ctx, dst); break; + case GGML_GLU_OP_SWIGLU_OAI: + ggml_cuda_op_swiglu_oai(ctx, dst); + break; case GGML_GLU_OP_GEGLU_ERF: ggml_cuda_op_geglu_erf(ctx, dst); break; @@ -2598,6 +2623,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; + const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; + const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; + const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2620,7 +2648,13 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #endif } - if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) { + if (node->op == GGML_OP_ADD && + node->src[1] && node->src[1]->ne[1] > 1 && + (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && + (node->src[1] ? 
node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && + strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) { // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation // by means of matching node names. See // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and @@ -3218,6 +3252,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous_1(op->src[0]); @@ -3268,6 +3303,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -3414,6 +3450,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ADD1: case GGML_OP_SUB: case GGML_OP_MUL: @@ -3488,12 +3525,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g #endif // FLASH_ATTN_AVAILABLE if (op->src[1]->ne[0] != op->src[2]->ne[0]) { const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; - if (!new_mma_available(cc)) { + if (!turing_mma_available(cc)) { return false; } const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0; } + // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS] + if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) + && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { + return false; + } if (op->src[0]->ne[0] == 192) { return false; } @@ -3757,10 +3799,10 @@ ggml_backend_t ggml_backend_cuda_init(int device) { } ggml_backend_t cuda_backend = new ggml_backend { - /* .guid = */ ggml_backend_cuda_guid(), - /* .interface = */ ggml_backend_cuda_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), - /* .context = */ ctx, + /* .guid = */ ggml_backend_cuda_guid(), + /* .iface = */ ggml_backend_cuda_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), + /* .context = */ ctx, }; return cuda_backend; diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 5bb85b4807..16bb9bec97 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -1,65 +1,76 @@ #include "im2col.cuh" +#define MAX_GRIDDIM_Z 65535 + template static __global__ void im2col_kernel( - const float * x, T * dst, int64_t batch_offset, - int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, + const float * x, T * dst, + int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, + int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW, int s0, int s1, int p0, int p1, int d0, int d1) { const int64_t i = threadIdx.x + blockIdx.x 
* blockDim.x; - if (i >= pelements) { + if (i >= IC_KH_KW) { return; } - const int64_t ksize = OW * KH; - const int64_t kx = i / ksize; - const int64_t kd = kx * ksize; - const int64_t ky = (i - kd) / OW; - const int64_t ix = i % OW; + const int64_t iic = i / (KH_KW); + const int64_t rem = i - iic * KH_KW; + const int64_t ikh = rem / KW; + const int64_t ikw = rem - ikh * KW; - const int64_t oh = blockIdx.y; - const int64_t batch = blockIdx.z / IC; - const int64_t ic = blockIdx.z % IC; + const int64_t iow = blockIdx.y; + for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = oh * s1 + ky * d1 - p1; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = ic * offset_delta + batch * batch_offset; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } } + + GGML_UNUSED(IC); + GGML_UNUSED(KH); } +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] template static void im2col_cuda(const float * x, T* dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC, - int64_t batch, int64_t batch_offset, int64_t offset_delta, + int64_t N, int64_t IC_IH_IW, int64_t IH_IW, int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { - const int parallel_elements = OW * KW * KH; - const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OH, batch * IC); - im2col_kernel<<>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); + const int64_t IC_KH_KW = IC * KH * KW; + const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + const int64_t N_OH = N * OH; + const int64_t KH_KW = KW*KH; + dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + im2col_kernel<<>>(x, dst, IC, IW, IH, OH, OW, KW, KH, + IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, + s0, s1, p0, p1, d0, d1); } static void im2col_cuda_f16(const float * x, half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC, - int64_t batch, int64_t batch_offset, int64_t offset_delta, + int64_t N, int64_t IC_IH_IW, int64_t IH_IW, int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { - im2col_cuda(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream); + im2col_cuda(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } static void im2col_cuda_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC, - int64_t batch, int64_t batch_offset, int64_t offset_delta, + int64_t N, int64_t IC_IH_IW, int64_t IH_IW, int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { - im2col_cuda(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, 
p1, d0, d1, stream); + im2col_cuda(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -91,13 +102,13 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t OH = is_2D ? dst->ne[2] : 1; const int64_t OW = dst->ne[1]; - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - const int64_t batch = src1->ne[is_2D ? 3 : 2]; - const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 + const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t N = src1->ne[is_2D ? 3 : 2]; + const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 if(dst->type == GGML_TYPE_F16) { - im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } else { - im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index a86365c6a0..83ee16b27d 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -23,13 +23,13 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { int ret = 0; -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(ret) : "r"(x)); #else GGML_UNUSED(x); NO_DEVICE_CODE; -#endif // defined(NEW_MMA_AVAILABLE) +#endif // defined(TURING_MMA_AVAILABLE) return ret; } @@ -167,6 +167,38 @@ namespace ggml_cuda_mma { } }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 4) { + return l * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 8) { + return l * 4 + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 4 + threadIdx.x % 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { tile ret; @@ -209,7 +241,7 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void load_ldmatrix( tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J; asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" @@ -217,13 +249,13 @@ namespace ggml_cuda_mma { 
: "l"(xs)); #else load_generic(t, xs0, stride); -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } template static __device__ __forceinline__ void load_ldmatrix( tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" @@ -232,13 +264,13 @@ namespace ggml_cuda_mma { #else load_generic(xs0, stride); GGML_UNUSED(t); -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } template static __device__ __forceinline__ void load_ldmatrix( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { -#if defined(NEW_MMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) int * xi = (int * ) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" @@ -246,13 +278,13 @@ namespace ggml_cuda_mma { : "l"(xs)); #else load_generic(t, xs0, stride); -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } template static __device__ __forceinline__ void load_ldmatrix_trans( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE int * xi = (int * ) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" @@ -263,12 +295,12 @@ namespace ggml_cuda_mma { GGML_UNUSED(xs0); GGML_UNUSED(stride); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) @@ -287,12 +319,12 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) @@ -317,12 +349,12 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -344,12 +376,12 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ 
-380,12 +412,29 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) { +#ifdef AMPERE_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -407,12 +456,29 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) { +#ifdef AMPERE_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -443,7 +509,7 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } static __device__ __forceinline__ void mma( diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu new file mode 100644 index 0000000000..1437367e87 --- /dev/null +++ b/ggml/src/ggml-cuda/mmf.cu @@ -0,0 +1,431 @@ +#include "ggml.h" +#include "common.cuh" +#include "mma.cuh" +#include "mmf.cuh" + +using namespace ggml_cuda_mma; + +#define MMF_ROWS_PER_BLOCK 32 + +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = 
warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + const int channel_dst = blockIdx.y; + const int channel_x = channel_dst / channel_ratio; + const int channel_y = channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; + y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; + dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded); + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; + } + } else if constexpr (std::is_same_v || std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + const float2 tmp = j < cols_per_block ? 
y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + } + } + + float * buf_iw = (float *) data_mmv; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + dst[j*stride_col_dst + row0 + threadIdx.x] = sum; + } +#else + NO_DEVICE_CODE; + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst); + GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst); + GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst); + GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst); +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + +template +static void mul_mat_f_cuda( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + GGML_ASSERT(!ids && "mul_mat_id not implemented"); + + GGML_ASSERT(ncols_x % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t nwarps_best = 1; + int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); + int64_t max_block_size = 256; + for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { + const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); + if (niter < niter_best) { + niter_best = niter; + nwarps_best = nwarps; + } + } 
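// NOTE: the search above keeps the smallest nwarps (block size at most max_block_size threads) that minimizes
// niter_best, the estimated number of main-loop iterations over the ncols_x columns; a larger block is only
// selected when it strictly reduces that iteration count.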
+ + constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; + const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); + const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst); + const dim3 block_dims(warp_size, nwarps_best, 1); + switch (nwarps_best) { + case 1: { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 2: { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 3: { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 4: { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 5: { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 6: { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 7: { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 8: { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +template +static void mul_mat_f_switch_cols_per_block( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + switch (ncols_dst) { + case 1: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, 
nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 2: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 3: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 4: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 5: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 6: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 7: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 8: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 9: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 10: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 11: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 12: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 13: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + 
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 14: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 15: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 16: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT(ne13 == ne3); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + GGML_ASSERT( nb0 == ts_dst); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s13 = src1->nb[3] / ts_src1; + const int64_t s3 = dst->nb[3] / ts_dst; + + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? 
s11 : s12; + + GGML_ASSERT(!ids || ncols_dst == 1); + + switch (src0->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0->data; + constexpr int vals_per_T = 1; + mul_mat_f_switch_cols_per_block( + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, + ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + } break; + case GGML_TYPE_F16: { + const half2 * src0_d = (const half2 *) src0->data; + constexpr int vals_per_T = 2; + mul_mat_f_switch_cols_per_block( + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, + ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; + constexpr int vals_per_T = 2; + mul_mat_f_switch_cols_per_block( + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, + ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } +} + +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) { + if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) { + return false; + } + if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { + return false; + } + if (ne11 > 16) { + return false; + } + switch (type) { + case GGML_TYPE_F32: + return ampere_mma_available(cc); + case GGML_TYPE_F16: + return turing_mma_available(cc); + case GGML_TYPE_BF16: + return ampere_mma_available(cc); + default: + return false; + } +} diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh new file mode 100644 index 0000000000..785f9f211c --- /dev/null +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index d4954fbe69..384ee7615f 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -20,6 +20,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con case GGML_TYPE_Q8_0: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_MXFP4: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_Q2_K: mul_mat_q_case(ctx, args, stream); break; @@ -282,6 +285,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -306,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - if (new_mma_available(cc)) { + if (turing_mma_available(cc)) { return true; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 04a8d80e12..96129bd831 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { 
return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_Q8_0: return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_MXFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q2_K: return MMQ_Q8_1_DS_LAYOUT_D2S6; case GGML_TYPE_Q3_K: @@ -90,7 +92,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 : + return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -100,9 +102,9 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) #if defined(GGML_USE_HIP) return 64; @@ -119,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } static int get_mmq_y_host(const int cc) { @@ -170,6 +172,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; @@ -206,6 +209,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; @@ -229,7 +233,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { static int mmq_get_granularity_host(const int mmq_x, const int cc) { if (amd_mfma_available(cc)) { return mmq_x >= 128 ? 32 : 16; - } else if (new_mma_available(cc) && mmq_x >= 48) { + } else if (turing_mma_available(cc) && mmq_x >= 48) { return 16; } else { return 8; @@ -240,7 +244,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) { static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 128 ? 32 : 16; } -#elif defined(NEW_MMA_AVAILABLE) +#elif defined(TURING_MMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 48 ? 
16 : 8; } @@ -275,14 +279,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); constexpr int nrows = warp_size / threads_per_row; @@ -301,12 +305,12 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; @@ -323,11 +327,11 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -378,14 +382,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); constexpr int nrows = warp_size / threads_per_row; @@ -404,12 +408,12 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; @@ -426,11 +430,11 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -481,14 +485,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); constexpr int nrows = warp_size / threads_per_row; @@ -523,13 +527,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; @@ -546,11 +550,11 @@ template static __device__ __forceinline__ void loa const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -559,14 +563,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); constexpr int nrows = warp_size / 
threads_per_row; @@ -599,13 +603,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; @@ -622,11 +626,11 @@ template static __device__ __forceinline__ void loa const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -635,14 +639,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) // MMQ_ITER_K / (4 * QR8_0) == 64 required. 
but NV has only 32 threads per warp constexpr int threads_per_row = 32; @@ -661,13 +665,13 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; @@ -684,11 +688,76 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } +} + +template static __device__ __forceinline__ void load_tiles_mxfp4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI_MXFP4; + const int kqsx = txi % QI_MXFP4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx; + + const int aux_q4 = get_int_b1(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); + const int k0 = kbx * (2 * QI_MXFP4) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#else + x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -1109,7 +1178,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } } } -#elif defined(NEW_MMA_AVAILABLE) +#elif defined(TURING_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -1195,14 +1264,14 @@ template static __device__ __forceinline__ void loa const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; @@ -1226,11 +1295,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int sc_m = bxi->scales[kqsx]; @@ -1241,11 +1310,11 @@ template static __device__ __forceinline__ void loa const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = 
x_dm_ik; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -1383,7 +1452,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } } } -#elif defined(NEW_MMA_AVAILABLE) +#elif defined(TURING_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -1513,7 +1582,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1521,7 +1590,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); constexpr int nrows = warp_size / threads_per_row; @@ -1549,11 +1618,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -1580,7 +1649,7 @@ template static __device__ __forceinline__ void loa const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) const int8_t * sc8 = (const int8_t *) ≻ const float d = bxi->d; @@ -1590,10 +1659,10 @@ template static __device__ __forceinline__ void loa } #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; @@ -1606,7 +1675,7 @@ template static __device__ __forceinline__ void loa x_df[i] = bxi->d; } -#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) } template @@ -1659,7 +1728,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else @@ -1667,7 +1736,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) 
constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); constexpr int nrows = warp_size / threads_per_row; @@ -1684,15 +1753,15 @@ template static __device__ __forceinline__ void loa const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; const int qs0 = get_int_b4(bxi->qs, txi); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { @@ -1760,7 +1829,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } template @@ -1803,7 +1872,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1811,7 +1880,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); constexpr int nrows = warp_size / threads_per_row; @@ -1839,16 +1908,16 @@ template static __device__ __forceinline__ void loa const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { @@ -1917,7 +1986,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } template @@ -1960,7 +2029,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) 
|| defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); @@ -1969,7 +2038,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); constexpr int nrows = warp_size / threads_per_row; @@ -1996,13 +2065,13 @@ template static __device__ __forceinline__ void loa const int kq0 = 2*txi - txi % (QI6_K/2) + 0; const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } #pragma unroll @@ -2015,11 +2084,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 4; @@ -2033,11 +2102,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2130,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } } } -#elif defined(NEW_MMA_AVAILABLE) +#elif defined(TURING_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile< 8, 4, int> tile_B; @@ -2242,14 +2311,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || 
defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); constexpr int nrows = warp_size / threads_per_row; @@ -2268,16 +2337,16 @@ template static __device__ __forceinline__ void loa const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx; const int aux_q4 = get_int_b2(bxi->qs, kqsx); - const int2 v = get_int_from_table_16(aux_q4); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = kbx * (2 * QI4_NL) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; @@ -2294,11 +2363,11 @@ template static __device__ __forceinline__ void loa const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2307,14 +2376,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2345,22 +2414,22 @@ template static __device__ __forceinline__ void loa const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2369,14 +2438,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2403,24 +2472,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2429,14 +2498,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2470,24 +2539,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 
1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2496,14 +2565,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2532,22 +2601,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2556,14 +2625,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; 
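// [editor's aside, not part of the patch] The IQ2/IQ3 loaders above expand packed sign bits with
// __vcmpne4 (0xFF per byte where a sign bit is set, 0x00 otherwise) and then apply them with
// __vsub4(grid ^ signs, signs), which negates exactly those bytes whose mask is 0xFF. A host-side
// emulation of that conditional-negate step, for illustration only (name is an assumption):
auto apply_block_signs_sketch = [](unsigned int grid, unsigned int signs) {
    unsigned int out = 0;
    for (int k = 0; k < 4; ++k) {
        const unsigned char g = (unsigned char) (grid  >> (8*k));
        const unsigned char m = (unsigned char) (signs >> (8*k));
        const unsigned char r = (unsigned char) ((g ^ m) - m);   // m == 0xFF -> two's-complement negate
        out |= (unsigned int) r << (8*k);
    }
    return out;   // matches __vsub4(grid ^ signs, signs) byte for byte
};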
constexpr int nrows = warp_size / threads_per_row; @@ -2599,22 +2668,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2623,14 +2692,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); constexpr int nrows = warp_size / threads_per_row; @@ -2658,23 +2727,23 @@ template static __device__ __forceinline__ void loa const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2683,14 +2752,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); constexpr int nrows = warp_size / threads_per_row; @@ -2707,16 +2776,16 @@ template static __device__ __forceinline__ void loa const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; const int aux_q4 = get_int_b4(bxi->qs, kqsx); - const int2 v = get_int_from_table_16(aux_q4); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = 8 * (kqsx / 4) + kqsx % 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 8; @@ -2735,11 +2804,11 @@ template static __device__ __forceinline__ void loa const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2790,9 +2859,9 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { @@ -2863,6 +2932,14 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; @@ -2984,13 +3061,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( int * tile_y = data_mul_mat_q + mmq_x; int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -3457,7 +3534,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? 
mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } @@ -3642,6 +3719,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); +extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmvf.cu similarity index 86% rename from ggml/src/ggml-cuda/mmv.cu rename to ggml/src/ggml-cuda/mmvf.cu index e14c93516b..1ad4bc75ba 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -1,9 +1,9 @@ #include "ggml.h" #include "common.cuh" -#include "mmv.cuh" +#include "mmvf.cuh" template -static __global__ void mul_mat_vec( +static __global__ void mul_mat_vec_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, @@ -37,7 +37,7 @@ static __global__ void mul_mat_vec( float sumf[ncols_dst] = {0.0f}; - if constexpr (std::is_same::value) { + if constexpr (std::is_same_v) { const float2 * x2 = (const float2 *) x; for (int col2 = tid; col2 < ncols2; col2 += block_size) { @@ -50,10 +50,10 @@ static __global__ void mul_mat_vec( sumf[j] += tmpx.y*tmpy.y; } } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same_v) { const half2 * x2 = (const half2 *) x; - if (std::is_same::value) { + if (std::is_same_v) { for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = __half22float2(x2[col2]); @@ -86,7 +86,7 @@ static __global__ void mul_mat_vec( NO_DEVICE_CODE; #endif // FP16_AVAILABLE } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same_v) { const int * x2 = (const int *) x; for (int col2 = tid; col2 < ncols2; col2 += block_size) { const int tmpx = x2[col2]; @@ -98,7 +98,7 @@ static __global__ void mul_mat_vec( } } } else { - static_assert(std::is_same::value, "unsupported type"); + static_assert(std::is_same_v, "unsupported type"); } #pragma unroll @@ -126,7 +126,7 @@ static __global__ void mul_mat_vec( } template -static void launch_mul_mat_vec_cuda( +static void launch_mul_mat_vec_f_cuda( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, @@ -141,11 +141,9 @@ static void launch_mul_mat_vec_cuda( GGML_ASSERT( nsamples_dst % nsamples_x == 0); const int64_t channel_ratio = nchannels_dst / nchannels_x; const int64_t sample_ratio = nsamples_dst / nsamples_x; - int device; - int warp_size; - CUDA_CHECK(cudaGetDevice(&device)); - warp_size = ggml_cuda_info().devices[device].warp_size; + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; int64_t block_size_best = warp_size; int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); @@ -161,54 +159,54 @@ static void launch_mul_mat_vec_cuda( } } - const int smem = warp_size*sizeof(float); + const int nbytes_shared = warp_size*sizeof(float); const dim3 block_nums(nrows, 
nchannels_dst, nsamples_dst); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 64: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); @@ -220,7 +218,7 @@ static void launch_mul_mat_vec_cuda( } template -static void mul_mat_vec_cuda_switch_ncols_dst( +static void mul_mat_vec_f_cuda_switch_ncols_dst( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, @@ -230,49 +228,49 @@ static void mul_mat_vec_cuda_switch_ncols_dst( cudaStream_t stream) { switch (ncols_dst) { case 1: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 2: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 3: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, 
y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 4: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 5: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 6: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 7: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 8: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); @@ -284,7 +282,7 @@ static void mul_mat_vec_cuda_switch_ncols_dst( } template -static void mul_mat_vec_cuda( +static void mul_mat_vec_f_cuda( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, @@ -292,22 +290,22 @@ static void mul_mat_vec_cuda( const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, enum ggml_prec prec, cudaStream_t stream) { - if constexpr(std::is_same::value) { + if constexpr(std::is_same_v) { if (prec == GGML_PREC_DEFAULT) { - mul_mat_vec_cuda_switch_ncols_dst + mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); return; } } - mul_mat_vec_cuda_switch_ncols_dst + mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +void 
ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -355,19 +353,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; @@ -376,7 +374,7 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * } } -void ggml_cuda_op_mul_mat_vec( +void ggml_cuda_op_mul_mat_vec_f( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, @@ -414,19 +412,19 @@ void ggml_cuda_op_mul_mat_vec( switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, 
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; @@ -442,15 +440,15 @@ void ggml_cuda_op_mul_mat_vec( GGML_UNUSED(src1_padded_row_size); } -bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { +bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { if (src0_ne[0] % 2 != 0) { return false; } switch (type) { case GGML_TYPE_F32: if (GGML_CUDA_CC_IS_NVIDIA(cc)) { - if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { - return ne11 <= 8; + if (ampere_mma_available(cc)) { + return ne11 <= 3; } if (cc >= GGML_CUDA_CC_TURING) { return ne11 <= 4; @@ -466,6 +464,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ case GGML_TYPE_F16: if (GGML_CUDA_CC_IS_NVIDIA(cc)) { const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (ampere_mma_available(cc)) { + return src0_small && ne11 == 1; + } if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { return src0_small && ne11 <= 4; } @@ -486,6 +487,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ case GGML_TYPE_BF16: if (GGML_CUDA_CC_IS_NVIDIA(cc)) { const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (ampere_mma_available(cc)) { + return src0_small && ne11 == 1; + } if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { return src0_small && ne11 <= 4; } diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmvf.cuh similarity index 55% rename from ggml/src/ggml-cuda/mmv.cuh rename to ggml/src/ggml-cuda/mmvf.cuh index 1330bcb6a8..1da460992e 100644 --- a/ggml/src/ggml-cuda/mmv.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -1,11 +1,11 @@ #include "common.cuh" -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); -void ggml_cuda_op_mul_mat_vec( +void ggml_cuda_op_mul_mat_vec_f( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); -bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11); +bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index dc7adf509f..5c8e5c4a7e 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; + case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1; case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; @@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; + case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ; 
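The MMVQ dispatch above wires MXFP4 into the same int8 dot-product machinery as the other quantized types: `vec_dot_mxfp4_q8_1` (defined later in vecdotq.cuh) accumulates with `ggml_cuda_dp4a` against the quantized `block_q8_1` activations. As a reference for what that primitive computes, a portable C++ sketch (the name `dp4a_ref` is ours; on NVIDIA hardware this typically lowers to a single `__dp4a` instruction):

```cpp
#include <cstdint>
#include <cstring>

// c += dot(a, b), where a and b each pack four signed 8-bit lanes.
static int32_t dp4a_ref(int32_t a, int32_t b, int32_t c) {
    int8_t a8[4];
    int8_t b8[4];
    std::memcpy(a8, &a, sizeof(a8));
    std::memcpy(b8, &b, sizeof(b8));
    for (int i = 0; i < 4; ++i) {
        c += int32_t(a8[i]) * int32_t(b8[i]);
    }
    return c;
}
```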
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; @@ -384,6 +386,13 @@ static void mul_mat_vec_q_switch_type( nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; + case GGML_TYPE_MXFP4: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 14543e978c..eeacde0bdb 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -45,7 +45,7 @@ struct soft_max_params { #endif // __clang__ template static __global__ void soft_max_f32( - const float * x, const T * mask, float * dst, const soft_max_params p) { + const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) { const int ncols = ncols_template == 0 ? p.ncols : ncols_template; const int tid = threadIdx.x; @@ -77,7 +77,7 @@ static __global__ void soft_max_f32( // shared memory buffer to cache values between iterations: float * vals = use_shared ? buf_iw + WARP_SIZE : dst; - float max_val = -INFINITY; + float max_val = sinks ? sinks[i02] : -INFINITY; #pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { @@ -143,6 +143,10 @@ static __global__ void soft_max_f32( tmp = warp_reduce_sum(tmp); } + if (sinks) { + tmp += expf(sinks[i02] - max_val); + } + const float inv_sum = 1.0f / tmp; #pragma unroll @@ -183,7 +187,7 @@ static __global__ void soft_max_back_f32( } template -static void launch_soft_max_kernels(const float * x, const T * mask, float * dst, +static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared) { const int id = ggml_cuda_get_device(); @@ -196,7 +200,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst if (p.ncols == ncols) { CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); soft_max_f32<<>> - (x, mask, dst, p); + (x, mask, sinks, dst, p); return true; } return false; @@ -209,12 +213,12 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst //default case CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); - soft_max_f32<<>>(x, mask, dst, p); + soft_max_f32<<>>(x, mask, sinks, dst, p); } template -static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) { int nth = WARP_SIZE; const int64_t ncols_x = params.ncols; @@ -230,10 +234,10 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons if (nbytes_shared <= smpbo) { - launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared); + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, 
nbytes_shared); } else { const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, params); + soft_max_f32<<>>(x, mask, sinks, dst, params); } } @@ -249,9 +253,11 @@ static void soft_max_back_f32_cuda( void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; const float * src0_d = (const float *) src0->data; const void * src1_d = src1 ? (const void *) src1->data : nullptr; + const void * src2_d = src2 ? (const void *) src2->data : nullptr; float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); @@ -309,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { params.m1 = m1; if (use_f16) { - soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream); } else { - soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream); } } diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu new file mode 100644 index 0000000000..c14624c52c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_MXFP4); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 91c830c4da..5aff8a876a 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -300,6 +300,81 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_op_unary_gated(ctx, dst); } +// swiglu_oai + +template +static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) { + const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + // perform base op and multiply with gate (either offset in same tensor or a separate one) + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + + float xi = x[j0]; + float gi = g[j1]; + xi = fminf(xi, limit); + gi = fmaxf(fminf(gi, limit), -limit); + + float out_glu = xi / (1.0f + expf(-xi * alpha)); + out_glu = out_glu * (1.0f + gi); + + dst[i] = out_glu; +} + +template +static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) { + const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; + swiglu_oai_kernel<<>>(x, g, dst, k, n, o0, o1, alpha, limit); +} + +void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + void * dst_d = dst->data; + const int64_t nc = src1 ? 
src0->ne[0] : src0->ne[0] / 2; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src0->nb[0] == ggml_element_size(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } + + //const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + const float alpha = ggml_get_op_params_f32(dst, 2); + const float limit = ggml_get_op_params_f32(dst, 3); + + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); +} + /* silu_back */ static __device__ __forceinline__ float op_silu_back(float grad, float x) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index cb14d16f8f..da3caf1d89 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -67,6 +67,8 @@ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index ba195e1d10..d8f9aa5ba6 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1,8 +1,20 @@ #pragma once #include "common.cuh" + #include +static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) { + const uint8_t * x8 = (const uint8_t *) x; + + int x32 = x8[4*i32 + 0] << 0; + x32 |= x8[4*i32 + 1] << 8; + x32 |= x8[4*i32 + 2] << 16; + x32 |= x8[4*i32 + 3] << 24; + + return x32; +} + static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) { const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment @@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32 return ((const int *) x)[i32]; // assume at least 4 byte alignment } +static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) { + const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; + const int8_t * q0_8 = (const int8_t *) &q0_32; + const char4 val0_8 = make_char4( + table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]); + + const int q1_32 = (q4 >> 4) & 0x0F0F0F0F; + const int8_t * q1_8 = (const int8_t *) &q1_32; + const char4 val1_8 = make_char4( + table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]); + + return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); +} + // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q @@ -211,6 +237,30 @@ template static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_ return d8_1*sumf; } 
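The hunk that follows introduces the MXFP4 vec-dot constants and kernel. For reference, a plain C++ dequantization sketch for one block, assuming the `block_mxfp4` layout of one shared E8M0 scale byte plus 16 bytes packing 32 4-bit indices (low nibbles first, as in the Metal `dequantize_mxfp4` later in this patch). The float table mirrors `kvalues_mxfp4_f` from ggml-metal.metal; the CUDA path instead uses a doubled int8 table and folds a compensating `0.5f` into the scale, which matches the factor visible in `vec_dot_mxfp4_q8_1` below. Names with the `_ref` suffix are ours:

```cpp
#include <cstdint>
#include <cstring>

// True MXFP4 (E2M1) values; mirrors kvalues_mxfp4_f in ggml-metal.metal.
static const float kvalues_mxfp4_ref[16] = {
    0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

// E8M0 scale: 0 encodes 2^(-127) via a denormal bit pattern, otherwise 2^(x-127).
static float e8m0_to_fp32_ref(uint8_t x) {
    const uint32_t bits = x == 0 ? 0x00400000u : (uint32_t) x << 23;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

struct block_mxfp4_ref {
    uint8_t e;      // shared E8M0 scale
    uint8_t qs[16]; // 32 packed 4-bit indices, low nibbles first
};

static void dequantize_mxfp4_ref(const block_mxfp4_ref & b, float out[32]) {
    const float d = e8m0_to_fp32_ref(b.e);
    for (int i = 0; i < 16; ++i) {
        out[i]      = d * kvalues_mxfp4_ref[b.qs[i] & 0x0F]; // low nibble
        out[i + 16] = d * kvalues_mxfp4_ref[b.qs[i] >> 4];   // high nibble
    }
}
```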
+#define VDR_MXFP4_Q8_1_MMVQ 2 +#define VDR_MXFP4_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_mxfp4_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx; + + const int * q8 = (const int *) bq8_1->qs + iqs; + + int sumi = 0; +#pragma unroll + for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) { + const int aux_q4 = get_int_b1(bq4->qs, iqs + l); + const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); + + sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi); + sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi); + } + + const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds); + return d * sumi; +} + #define VDR_Q2_K_Q8_1_MMVQ 1 #define VDR_Q2_K_Q8_1_MMQ 4 @@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1); } -static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) { - const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; - const int8_t * q0_8 = (const int8_t *) &q0_32; - const char4 val0_8 = make_char4( - kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]); - - const int q1_32 = (q4 >> 4) & 0x0F0F0F0F; - const int8_t * q1_8 = (const int8_t *) &q1_32; - const char4 val1_8 = make_char4( - kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]); - - return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); -} - #define VDR_IQ4_NL_Q8_1_MMVQ 2 #define VDR_IQ4_NL_Q8_1_MMQ 4 @@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( #pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) { const int aux_q4 = get_int_b2(bq4->qs, iqs + l); - const int2 v = get_int_from_table_16(aux_q4); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi); sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi); @@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( #pragma unroll for (int j = 0; j < 4; ++j) { const int aux_q4 = get_int_b4(bq4->qs, iqs + j); - const int2 v = get_int_from_table_16(aux_q4); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0); const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4); diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index 1746b07320..3b3086778e 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,6 +6,10 @@ #include #include +#if CUDART_VERSION >= 12050 +#include +#endif // CUDART_VERSION >= 12050 + #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 8b172e60f4..c31f319232 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -200,6 +200,7 @@ #endif typedef hip_bfloat16 nv_bfloat16; +typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1989632024..8c55a2e4e5 100644 
--- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -137,4 +137,5 @@ #define cudaStreamEndCapture musaStreamEndCapture #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor -typedef mt_bfloat16 nv_bfloat16; +typedef __mt_bfloat16 nv_bfloat16; +typedef __mt_bfloat162 nv_bfloat162; diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index e92ec7faa3..852de97346 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -121,6 +121,10 @@ if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7 add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12) endif() +if (GGML_HIP_EXPORT_METRICS) + set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") +endif() + if (NOT GGML_CUDA_FA) add_compile_definitions(GGML_CUDA_NO_FA) endif() diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index a2e30994c4..19a7adb2d1 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -410,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) +static inline float ggml_e8m0_to_fp32(uint8_t x) { + uint32_t bits; // Stores the raw bit representation of the float + + // Handle special case for minimum exponent (denormalized float) + if (x == 0) { + // Bit pattern for 2^(-127): + // - Sign bit: 0 (positive) + // - Exponent: 0 (denormalized number) + // - Mantissa: 0x400000 (0.5 in fractional form) + // Value = 0.5 * 2^(-126) = 2^(-127) + bits = 0x00400000; + } + // note: disabled as we don't need to handle NaNs + //// Handle special case for NaN (all bits set) + //else if (x == 0xFF) { + // // Standard quiet NaN pattern: + // // - Sign bit: 0 + // // - Exponent: all 1s (0xFF) + // // - Mantissa: 0x400000 (quiet NaN flag) + // bits = 0x7FC00000; + //} + // Normalized values (most common case) + else { + // Construct normalized float by shifting exponent into position: + // - Exponent field: 8 bits (positions 30-23) + // - Mantissa: 0 (implicit leading 1) + // Value = 2^(x - 127) + bits = (uint32_t) x << 23; + } + + float result; // Final float value + // Safely reinterpret bit pattern as float without type-punning issues + memcpy(&result, &bits, sizeof(float)); + return result; +} + +// Equal to ggml_e8m0_to_fp32/2 +// Useful with MXFP4 quantization since the E2M1 values are doubled +static inline float ggml_e8m0_to_fp32_half(uint8_t x) { + uint32_t bits; + + // For x < 2: use precomputed denormal patterns + if (x < 2) { + // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127) + bits = 0x00200000 << x; + } + // For x >= 2: normalized exponent adjustment + else { + // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1) + bits = (uint32_t)(x - 1) << 23; + } + // Note: NaNs are not handled here + + float result; + memcpy(&result, &bits, sizeof(float)); + return result; +} + +#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x) +#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x) + /** * Converts brain16 to float32.
* diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 8424464d8c..fc6526d6d5 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -23,6 +23,9 @@ #define N_R0_Q8_0 4 #define N_SG_Q8_0 2 +#define N_R0_MXFP4 2 +#define N_SG_MXFP4 2 + #define N_R0_Q2_K 4 #define N_SG_Q2_K 2 @@ -129,6 +132,15 @@ typedef struct { uint64_t o1[8]; } ggml_metal_kargs_bin; +typedef struct { + int64_t ne0; + int64_t ne1; + size_t nb01; + size_t nb02; + size_t nb11; + size_t nb21; +} ggml_metal_kargs_add_id; + typedef struct { int32_t ne00; int32_t ne01; @@ -444,6 +456,8 @@ typedef struct{ uint64_t nb1; int32_t i00; int32_t i10; + float alpha; + float limit; } ggml_metal_kargs_glu; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 337f7985ba..cb8eff4a77 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -195,6 +195,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, GGML_METAL_KERNEL_TYPE_DIV, GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, + GGML_METAL_KERNEL_TYPE_ADD_ID, GGML_METAL_KERNEL_TYPE_REPEAT_F32, GGML_METAL_KERNEL_TYPE_REPEAT_F16, GGML_METAL_KERNEL_TYPE_REPEAT_I32, @@ -234,6 +235,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, + GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, @@ -286,6 +288,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, @@ -310,6 +313,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, @@ -351,6 +358,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, @@ -373,6 +381,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, @@ -397,6 +406,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, @@ -579,6 +589,7 @@ enum 
ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_REGLU, GGML_METAL_KERNEL_TYPE_GEGLU, GGML_METAL_KERNEL_TYPE_SWIGLU, + GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, GGML_METAL_KERNEL_TYPE_GEGLU_ERF, GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, GGML_METAL_KERNEL_TYPE_SUM_ROWS, @@ -1199,6 +1210,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); @@ -1238,6 +1250,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); @@ -1290,6 +1303,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); @@ -1314,6 +1328,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); 
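Alongside the MXFP4 kernel types, the enum above gains `GGML_METAL_KERNEL_TYPE_SWIGLU_OAI`, the Metal counterpart of the CUDA `swiglu_oai_kernel` earlier in this patch. Both backends compute the same clamped SwiGLU variant; a scalar C++ reference (function name ours) under that reading:

```cpp
#include <cmath>

// Clamped SwiGLU: one-sided clamp on x, symmetric clamp on the gate g,
// alpha-scaled sigmoid on x, then a (1 + g) linear term on the gate side.
static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
    x = std::fmin(x, limit);
    g = std::fmax(std::fmin(g, limit), -limit);
    const float silu = x / (1.0f + std::exp(-alpha * x));
    return silu * (1.0f + g);
}
```

In both backends `alpha` and `limit` travel in `op_params` (indices 2 and 3) and are read back with `ggml_get_op_params_f32`, as the CUDA and Metal hunks in this patch show.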
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); @@ -1355,6 +1373,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); @@ -1377,6 +1396,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); @@ -1401,6 +1422,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm); @@ -1583,6 +1605,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); @@ -1774,6 +1797,7 @@ static bool ggml_metal_supports_op(const struct
ggml_backend_metal_device_contex case GGML_GLU_OP_REGLU: case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; @@ -1791,6 +1815,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_ADD_ID: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: case GGML_OP_REPEAT: @@ -2042,6 +2067,7 @@ static int ggml_metal_encode_node( const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; size_t offs_src0 = 0; @@ -2291,6 +2317,38 @@ static int ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } } break; + case GGML_OP_ADD_ID: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(src2t == GGML_TYPE_I32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous_rows(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline; + + ggml_metal_kargs_add_id args = { + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb11 =*/ nb11, + /*.nb21 =*/ nb21, + + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_REPEAT: { id pipeline; @@ -2710,6 +2768,9 @@ static int ggml_metal_encode_node( case GGML_GLU_OP_SWIGLU: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; break; + case GGML_GLU_OP_SWIGLU_OAI: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline; + break; case GGML_GLU_OP_GEGLU_ERF: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline; break; @@ -2720,7 +2781,9 @@ static int ggml_metal_encode_node( GGML_ABORT("fatal error"); } - const int32_t swp = ((const int32_t *) dst->op_params)[1]; + const int32_t swp = ggml_get_op_params_i32(dst, 1); + const float alpha = ggml_get_op_params_f32(dst, 2); + const float limit = ggml_get_op_params_f32(dst, 3); const int32_t i00 = swp ? ne0 : 0; const int32_t i10 = swp ? 0 : ne0; @@ -2734,6 +2797,8 @@ static int ggml_metal_encode_node( /*.nb1 =*/ nb1, /*.i00 =*/ src1 ? 0 : i00, /*.i10 =*/ src1 ? 
0 : i10, + /*.alpha=*/ alpha, + /*.limit=*/ limit }; [encoder setComputePipelineState:pipeline]; @@ -2992,8 +3057,13 @@ static int ggml_metal_encode_node( } else { [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + if (id_src2) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:h_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:4]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; @@ -3291,6 +3361,7 @@ static int ggml_metal_encode_node( src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_IQ4_NL || false) && (ne11 >= 2 && ne11 <= 8) ) || @@ -3383,6 +3454,14 @@ static int ggml_metal_encode_node( case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break; default: GGML_ABORT("not implemented"); } break; + case GGML_TYPE_MXFP4: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; case GGML_TYPE_Q4_K: switch (r1ptg) { case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break; @@ -3481,6 +3560,7 @@ static int ggml_metal_encode_node( case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break; case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; @@ -3623,6 +3703,13 @@ static int ggml_metal_encode_node( nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + smem = 32*sizeof(float); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline; + } break; case GGML_TYPE_Q2_K: { nsg = N_SG_Q2_K; @@ -3756,8 +3843,6 @@ static int ggml_metal_encode_node( case GGML_OP_MUL_MAT_ID: { // src2 = ids - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); - GGML_ASSERT(src2t == GGML_TYPE_I32); GGML_ASSERT(!ggml_is_transposed(src0)); @@ -3883,6 +3968,7 @@ static int ggml_metal_encode_node( case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break; case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break; case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break; case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break; @@ -4018,6 +4104,13 @@ static int ggml_metal_encode_node( nr0 = N_R0_Q8_0; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + smem = 32*sizeof(float); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline; + } break; case GGML_TYPE_Q2_K: { nsg = N_SG_Q2_K; @@ -4170,6 +4263,7 @@ static int ggml_metal_encode_node( case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break; case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; @@ -4980,11 +5074,14 @@ static int ggml_metal_encode_node( GGML_ASSERT(ne11 == ne21); GGML_ASSERT(ne12 == ne22); - struct ggml_tensor * src3 = node->src[3]; + struct ggml_tensor * src3 = node->src[3]; // mask + struct ggml_tensor * src4 = node->src[4]; // sinks size_t offs_src3 = 0; + size_t offs_src4 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && @@ -5000,8 +5097,6 @@ static int ggml_metal_encode_node( const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); - const enum ggml_type src2t = src2 ? 
src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); - float scale; float max_bias; float logit_softcap; @@ -5389,7 +5484,12 @@ static int ggml_metal_encode_node( } else { [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; + if (id_src4) { + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 99a453090f..b35a3bbdc3 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -35,6 +35,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = { -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f }; +constexpr constant static float kvalues_mxfp4_f[16] = { + 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f +}; + static inline int best_index_int8(int n, constant float * val, float x) { if (x <= val[0]) return 0; if (x >= val[n-1]) return n-1; @@ -46,6 +50,18 @@ static inline int best_index_int8(int n, constant float * val, float x) { return x - val[mu-1] < val[mu] - x ? mu-1 : mu; } +static inline float e8m0_to_fp32(uint8_t x) { + uint32_t bits; + + if (x == 0) { + bits = 0x00400000; + } else { + bits = (uint32_t) x << 23; + } + + return as_type(bits); +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -242,6 +258,27 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) { } } +void quantize_q8_0(device const float * src, device block_q8_0 & dst) { +#pragma METAL fp math_mode(safe) + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst.qs[j] = round(x0); + } +} + void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) { #pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max @@ -462,25 +499,34 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re } } -void quantize_q8_0(device const float * src, device block_q8_0 & dst) { -#pragma METAL fp math_mode(safe) - float amax = 0.0f; // absolute max +template +void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) { + device const uint8_t * q2 = (device const uint8_t *)xb->qs; - for (int j = 0; j < QK8_0; j++) { - const float v = src[j]; - amax = MAX(amax, fabs(v)); + const float d = e8m0_to_fp32(xb->e); + const uint8_t shr = il >= 1 ? 4 : 0; + + for (int i = 0; i < 4; ++i) { + reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F]; + reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F]; + reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F]; + reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F]; } +} - const float d = amax / ((1 << 7) - 1); - const float id = d ? 
1.0f/d : 0.0f; +template +void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) { + device const uint8_t * q2 = (device const uint8_t *)xb->qs; - dst.d = d; + const float d = e8m0_to_fp32(xb->e); + const short il4 = il%4; - for (int j = 0; j < QK8_0; ++j) { - const float x0 = src[j]*id; + const uint8_t shr = il >= 4 ? 4 : 0; - dst.qs[j] = round(x0); - } + reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F]; + reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F]; + reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F]; + reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F]; } template @@ -960,6 +1006,32 @@ kernel void kernel_div( } } +kernel void kernel_add_id( + constant ggml_metal_kargs_add_id & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i1 = tgpig.x; + const int i2 = tgpig.y; + + const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21)); + + const size_t nb1 = args.ne0 * sizeof(float); + const size_t nb2 = args.ne1 * nb1; + + device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); + device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); + device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11); + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + dst_row[i0] = src0_row[i0] + src1_row[i0]; + } +} + template kernel void kernel_repeat( constant ggml_metal_kargs_repeat & args, @@ -1431,6 +1503,32 @@ kernel void kernel_swiglu( } } +kernel void kernel_swiglu_oai( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + float x0 = src0_row[i0]; + float x1 = src1_row[i0]; + + x0 = min(x0, args.limit); + x1 = max(min(x1, args.limit), -args.limit); + + float out_glu = x0 / (1.0f + exp(-x0 * args.alpha)); + out_glu = out_glu * (1.0f + x1); + + dst_row[i0] = out_glu; + } +} + kernel void kernel_geglu_erf( device const char * src0, device const char * src1, @@ -1534,6 +1632,7 @@ template kernel void kernel_soft_max( device const char * src0, device const char * src1, + device const char * src2, device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], @@ -1552,6 +1651,7 @@ kernel void kernel_soft_max( device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device const float * psrc2 = src2 != src0 ? 
(device const float *) (src2) : nullptr; device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; @@ -1567,7 +1667,7 @@ kernel void kernel_soft_max( } // parallel max - float lmax = -INFINITY; + float lmax = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); @@ -1623,6 +1723,10 @@ kernel void kernel_soft_max( sum = simd_sum(sum); } + if (psrc2) { + sum += exp(psrc2[i02] - max_val); + } + const float inv_sum = 1.0f/sum; for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { @@ -1634,6 +1738,7 @@ template kernel void kernel_soft_max_4( device const char * src0, device const char * src1, + device const char * src2, device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], @@ -1652,6 +1757,7 @@ kernel void kernel_soft_max_4( device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr; device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; @@ -1666,7 +1772,7 @@ kernel void kernel_soft_max_4( } // parallel max - float4 lmax4 = -INFINITY; + float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); @@ -1725,6 +1831,10 @@ kernel void kernel_soft_max_4( sum = simd_sum(sum); } + if (psrc2) { + sum += exp(psrc2[i02] - max_val); + } + const float inv_sum = 1.0f/sum; for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { @@ -3106,6 +3216,11 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>; template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>; +template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>; +template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>; +template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>; + template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>; template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>; template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>; @@ -4092,6 +4207,7 @@ kernel void kernel_flash_attn_ext( device const char * k, device const char * v, 
device const char * mask, + device const char * sinks, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -4407,6 +4523,35 @@ kernel void kernel_flash_attn_ext( } } + if (sinks != q && sgitg == 0) { + for (ushort j = 0; j < Q; ++j) { + const float m = M[j]; + const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; + + M[j] = simd_max(max(M[j], s)); + + const float ms = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms + simd_sum(vs); + + if (tiisg == j) { + ss[j*TS + 2*C + j] = ms; + } + } + + // O = diag(ms)*O + { + s8x8_t ms; + simdgroup_load(ms, ss + 2*C, TS, 0, false); + + #pragma unroll(DV8) + for (short i = 0; i < DV8; ++i) { + simdgroup_multiply(lo[i], ms, lo[i]); + } + } + } + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (short j = tiisg; j < Q; j += NW) { ss[j*TS + 0] = S[j]; @@ -4618,6 +4763,7 @@ kernel void kernel_flash_attn_ext_vec( device const char * k, device const char * v, device const char * mask, + device const char * sinks, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -4835,6 +4981,23 @@ kernel void kernel_flash_attn_ext_vec( } } + if (sinks != q && sgitg == 0) { + const float m = M; + const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + +#pragma unroll(DV4/NL) + for (short ii = 0; ii < DV4; ii += NL) { + lo[ii/NL] *= ms; + } + } + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) if (tiisg == 0) { ss[0] = (s_t) S; @@ -6940,6 +7103,95 @@ kernel void kernel_mul_mv_iq4_xs_f32( kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template +void kernel_mul_mv_mxfp4_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + const int nb = args.ne00/QK_MXFP4; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + const short ix = tiisg/2; // 0...15 + const short it = tiisg%2; // 0 or 1 + + shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[nr0]={0.f}; + + device const float * yb = y + ix * QK_MXFP4 + it * 8; + + for (int ib = ix; ib < nb; ib += 16) { + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; + +#pragma unroll(nr0) + for (short row = 0; row < nr0; row++) { + device const block_mxfp4 & xb = x[row*nb + ib]; + device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); + + float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); + float4 acc2 = 
yl[1]*float4(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]); + float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]); + float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]); + + acc1 = (acc1 + acc3) + (acc2 + acc4); + + sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3])); + } + + yb += 16 * QK_MXFP4; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_mxfp4_f32")]] +kernel void kernel_mul_mv_mxfp4_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, @@ -7475,6 +7727,7 @@ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; @@ -7527,6 +7780,7 @@ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -7558,6 +7812,7 @@ template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -7703,6 +7958,8 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t 
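Aside on the MXFP4 kernels above: each block stores one shared E8M0 exponent byte plus 32 packed 4-bit indices into the signed table {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}. Below is a minimal scalar C sketch of the decode, following the Metal variant (kvalues_mxfp4_f plus e8m0_to_fp32); the type and helper names are illustrative, not part of the patch. Low nibbles produce the first 16 values of the block, high nibbles the last 16, as in dequantize_mxfp4.

#include <stdint.h>
#include <string.h>

#define QK_MXFP4 32

typedef struct {
    uint8_t e;                /* shared scale, E8M0 (biased power of two)   */
    uint8_t qs[QK_MXFP4/2];   /* 32 packed 4-bit indices, two per byte      */
} block_mxfp4_ref;

/* 4-bit code -> value, same table as kvalues_mxfp4_f above */
static const float mxfp4_values[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
   -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

/* decode the E8M0 exponent by placing it in the float exponent field: 2^(e-127) */
static float e8m0_to_fp32_ref(uint8_t e) {
    const uint32_t bits = (e == 0) ? 0x00400000u : ((uint32_t) e << 23);
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

static void dequantize_block_mxfp4(const block_mxfp4_ref * x, float * y) {
    const float d = e8m0_to_fp32_ref(x->e);
    for (int j = 0; j < QK_MXFP4/2; ++j) {
        y[j]              = d * mxfp4_values[x->qs[j] & 0x0F];  /* low nibble  */
        y[j + QK_MXFP4/2] = d * mxfp4_values[x->qs[j] >>   4];  /* high nibble */
    }
}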
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 3adea83615..d8290faa46 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -55,6 +55,7 @@ endfunction() set(GGML_OPENCL_KERNELS add + add_id argsort clamp cpy diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 150842f366..8ba1e00df7 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -345,6 +345,7 @@ struct ggml_backend_opencl_context { cl_command_queue queue; cl_program program_add; + cl_program program_add_id; cl_program program_clamp; cl_program program_cpy; cl_program program_cvt; @@ -404,6 +405,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16; cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16; cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16; + cl_kernel kernel_add_id; cl_kernel kernel_scale; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; @@ -412,7 +414,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_relu; cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; cl_kernel kernel_clamp; - cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick, + cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick, kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; cl_kernel kernel_norm; cl_kernel kernel_rms_norm, kernel_rms_norm_mul; @@ -600,6 +602,7 @@ struct ggml_backend_opencl_context { if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING write_profiling_info(); + profiling_info.clear(); #endif } } @@ -681,6 +684,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // add_id + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "add_id.cl.h" + }; +#else + const std::string kernel_src = read_file("add_id.cl"); +#endif + backend_ctx->program_add_id = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_add_id = clCreateKernel(backend_ctx->program_add_id, "kernel_add_id", &err), err)); + GGML_LOG_CONT("."); + } + // clamp { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -787,6 +806,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err)); CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err)); CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err)); + CL_CHECK((backend_ctx->kernel_swiglu_oai = 
clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_oai", &err), err)); CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err)); CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err)); CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err)); @@ -2046,8 +2066,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); backend_ctx->has_vector_subgroup_broadcast = - backend_ctx->adreno_cl_compiler_version.major >= 47 || - backend_ctx->adreno_cl_compiler_version.major == 17; + (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) || + (backend_ctx->adreno_cl_compiler_version.type == DX && backend_ctx->adreno_cl_compiler_version.major >= 17); GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); @@ -2467,6 +2487,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return (op->src[0]->type == op->src[1]->type) && (op->src[0]->type == op->type) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); + case GGML_OP_ADD_ID: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: @@ -2488,6 +2510,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); @@ -2601,10 +2624,10 @@ ggml_backend_t ggml_backend_opencl_init(void) { ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); ggml_backend_t backend = new ggml_backend { - /* .guid = */ ggml_backend_opencl_guid(), - /* .interface = */ ggml_backend_opencl_i, - /* .device = */ dev, - /* .context = */ backend_ctx + /* .guid = */ ggml_backend_opencl_guid(), + /* .iface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx }; return backend; @@ -3822,6 +3845,75 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const } } +static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const ggml_tensor * src2 = dst->src[2]; + GGML_ASSERT(src2); + GGML_ASSERT(src2->extra); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous_rows(src0)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + + const cl_ulong nb11 = src1->nb[1]; + + const cl_ulong nb21 = src2->nb[1]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl 
*)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel = backend_ctx->kernel_add_id; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + + int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel)); + size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 }; + size_t local_work_size[] = { (size_t)nth, 1, 1 }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -6500,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c GGML_ASSERT(src1->extra); } + const ggml_tensor * src2 = dst->src[2]; + if (src2) { + GGML_ASSERT(src2->extra); + } + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr; + ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr; cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; + cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -6578,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? 
&extra1->data_device : &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2)); size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; @@ -7003,6 +7104,9 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const kernel = backend_ctx->kernel_swiglu_f16; } break; + case GGML_GLU_OP_SWIGLU_OAI: + kernel = backend_ctx->kernel_swiglu_oai; + break; case GGML_GLU_OP_GEGLU_ERF: if (dst->type == GGML_TYPE_F32) { kernel = backend_ctx->kernel_geglu_erf; @@ -7038,7 +7142,10 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const const cl_ulong nb1 = dst->nb[1]; - const int swp = ((const int32_t *) dst->op_params)[1]; + const int swp = ggml_get_op_params_i32(dst, 1); + const float alpha = ggml_get_op_params_f32(dst, 2); + const float limit = ggml_get_op_params_f32(dst, 3); + const int 
ne00_off = src1 ? 0 : (swp ? ne0 : 0); const int ne10_off = src1 ? 0 : (swp ? 0 : ne0); @@ -7055,6 +7162,11 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne00_off)); CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10_off)); + if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) { + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &limit)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &alpha)); + } + const size_t nrows = ggml_nrows(src0); size_t nth = 512; size_t global_work_size[] = {nrows*nth, 1, 1}; @@ -7111,6 +7223,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_add; break; + case GGML_OP_ADD_ID: + if (!any_on_device) { + return false; + } + func = ggml_cl_add_id; + break; case GGML_OP_MUL: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/add_id.cl b/ggml/src/ggml-opencl/kernels/add_id.cl new file mode 100644 index 0000000000..e9c6d55e6a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/add_id.cl @@ -0,0 +1,42 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// add_id +//------------------------------------------------------------------------------ +kernel void kernel_add_id( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * src2, + ulong offset2, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb02, + ulong nb11, + ulong nb21, + int ne0, + int ne1 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + src2 = (global char*)((global char*)src2 + offset2); + dst = (global char*)((global char*)dst + offsetd); + + int i1 = get_group_id(0); + int i2 = get_group_id(1); + + const int i11 = *((global const int *) (src2 + i1*sizeof(int) + i2*nb21)); + + const size_t nb1 = ne0 * sizeof(float); + const size_t nb2 = ne1 * nb1; + + global float * dst_row = (global float *)((global char *)dst + i1*nb1 + i2*nb2); + global float * src0_row = (global float *)((global char *)src0 + i1*nb01 + i2*nb02); + global float * src1_row = (global float *)((global char *)src1 + i11*nb11); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + dst_row[i0] = src0_row[i0] + src1_row[i0]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/glu.cl b/ggml/src/ggml-opencl/kernels/glu.cl index 7cca16e6a9..059a4bbf1b 100644 --- a/ggml/src/ggml-opencl/kernels/glu.cl +++ b/ggml/src/ggml-opencl/kernels/glu.cl @@ -202,6 +202,47 @@ kernel void kernel_swiglu_f16( } } +//------------------------------------------------------------------------------ +// swiglu_oai +//------------------------------------------------------------------------------ +kernel void kernel_swiglu_oai( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off, + float limit, + float alpha +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + 
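The add_id kernels above (Metal and OpenCL) implement GGML_OP_ADD_ID: for each row i1 of each token slot i2, an index read from src2 selects which bias row of src1 to add to the corresponding src0 row. A CPU-side sketch of the same indexing follows, with strides in bytes as in the kernels; it assumes dst is a contiguous [ne02][ne01][ne0] float tensor, and the function name is illustrative.

#include <stdint.h>
#include <stddef.h>

/* dst[i2][i1][:] = src0[i2][i1][:] + src1[src2[i2][i1]][:]
 * nb01/nb02: byte strides of src0; nb11: row byte stride of src1;
 * nb21: byte stride of the int32 index tensor src2 over i2 */
static void add_id_ref(const char * src0, const char * src1, const char * src2,
                       float * dst, int ne0, int ne01, int ne02,
                       size_t nb01, size_t nb02, size_t nb11, size_t nb21) {
    for (int i2 = 0; i2 < ne02; ++i2) {
        for (int i1 = 0; i1 < ne01; ++i1) {
            const int32_t i11 = *(const int32_t *)(src2 + i1*sizeof(int32_t) + i2*nb21);

            const float * s0 = (const float *)(src0 + i1*nb01 + i2*nb02);
            const float * s1 = (const float *)(src1 + (size_t) i11*nb11);
            float       * d  = dst + ((size_t) i2*ne01 + i1)*ne0;

            for (int i0 = 0; i0 < ne0; ++i0) {
                d[i0] = s0[i0] + s1[i0];
            }
        }
    }
}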
get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + float x0 = src0_row[i0]; + float x1 = src1_row[i0]; + + x0 = min(x0, limit); + x1 = max(min(x1, limit), -limit); + + float out_glu = x0 / (1.0f + exp(-x0 * alpha)); + out_glu = out_glu * (1.0f + x1); + + dst_row[i0] = out_glu; + } +} + //------------------------------------------------------------------------------ // geglu_erf //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl index a6d8ede670..571d16507c 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16( global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16( } // parallel max - float4 lmax4 = -INFINITY; + float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); } @@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16( } float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { pdst4[i00] /= sum; diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl index 35b5573b46..1f944b2201 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max_4( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max_4( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max_4( global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max_4( } // parallel max - float4 lmax4 = -INFINITY; + float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? 
slope*pmask[i00] : 0.0f)); } @@ -92,7 +96,11 @@ kernel void kernel_soft_max_4( } float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { pdst4[i00] /= sum; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl index 9d292b5746..4baa6c28e4 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16( global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16( } // parallel max - float lmax = -INFINITY; + float lmax = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } @@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16( pdst[i00] = exp_psrc0; } - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { pdst[i00] /= sum; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl index 7c53dfbe5a..d503190b47 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max( global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max( } // parallel max - float lmax = -INFINITY; + float lmax = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? 
slope*pmask[i00] : 0.0f)); } @@ -91,7 +95,11 @@ kernel void kernel_soft_max( pdst[i00] = exp_psrc0; } - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { pdst[i00] /= sum; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 9a7d1b22d7..a57d2a16d6 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -21,6 +21,17 @@ #define UNUSED GGML_UNUSED +static inline int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -246,6 +257,53 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST } } +static inline int best_index_mxfp4(float x, float e) { + int best_index = 0; + float best_err = fabsf(kvalues_mxfp4[0]*e - x); + for (int i = 1; i < 16; i++) { + float err = fabsf(kvalues_mxfp4[i]*e - x); + if (err < best_err) { + best_index = i; + best_err = err; + } + } + return best_index; +} + +void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK_MXFP4; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (amax < fabsf(v)) { + amax = fabsf(v); + } + } + + const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127); + + const float d = GGML_E8M0_TO_FP32_HALF(e); + + y[i].e = e; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d); + const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d); + + y[i].qs[j] = x0; + y[i].qs[j] |= x1 << 4; + } + } +} + void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -356,6 +414,26 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI } } +void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK_MXFP4; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_E8M0_TO_FP32_HALF(x[i].e); + + for (int j = 0; j < qk/2; ++j) { + const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F]; + const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4]; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + // // 2-6 bit quantization in super-blocks // @@ -2014,6 +2092,12 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * row_size; } +size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_UNUSED(quant_weights); + quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row); +} + // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) { 
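For the encode side, quantize_row_mxfp4_ref above picks one power-of-two scale per 32-value block from the block's absolute maximum and then snaps each value to the nearest table entry. The sketch below mirrors that logic but uses the float value table and full E8M0 scale from the Metal shader instead of the halved-scale integer table (the two are numerically equivalent); it reuses block_mxfp4_ref, mxfp4_values, and e8m0_to_fp32_ref from the decode sketch earlier, and adds a guard for an all-zero block that the reference omits.

#include <math.h>
#include <stdint.h>

/* pick the table entry (scaled by d) closest to x */
static int best_index_mxfp4_ref(float x, float d) {
    int best = 0;
    float best_err = fabsf(mxfp4_values[0]*d - x);
    for (int i = 1; i < 16; ++i) {
        const float err = fabsf(mxfp4_values[i]*d - x);
        if (err < best_err) { best = i; best_err = err; }
    }
    return best;
}

/* quantize one block of 32 floats: shared power-of-two scale from the block's
 * absolute max, then nearest-value search in the 16-entry table */
static void quantize_block_mxfp4(const float * x, block_mxfp4_ref * y) {
    float amax = 0.0f;
    for (int j = 0; j < QK_MXFP4; ++j) {
        amax = fmaxf(amax, fabsf(x[j]));
    }

    /* E8M0 exponent chosen so the largest table value (6) roughly covers amax */
    const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
    const float   d = e8m0_to_fp32_ref(e);

    y->e = e;
    for (int j = 0; j < QK_MXFP4/2; ++j) {
        const uint8_t lo = (uint8_t) best_index_mxfp4_ref(x[j],              d);
        const uint8_t hi = (uint8_t) best_index_mxfp4_ref(x[j + QK_MXFP4/2], d);
        y->qs[j] = lo | (hi << 4);
    }
}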
@@ -4551,17 +4635,6 @@ size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, // ============================ 4-bit non-linear quants -static inline int best_index_int8(int n, const int8_t * val, float x) { - if (x <= val[0]) return 0; - if (x >= val[n-1]) return n-1; - int ml = 0, mu = n-1; - while (mu-ml > 1) { - int mav = (ml+mu)/2; - if (x < val[mav]) mu = mav; else ml = mav; - } - return x - val[mu-1] < val[mu] - x ? mu-1 : mu; -} - static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x, ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l, float * scales, float * weight, uint8_t * L, @@ -4961,6 +5034,15 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) { return true; } +static bool validate_e_e8m0(uint8_t e, size_t i) { + if (e == 0xff) { + fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i); + return false; + } + + return true; +} + #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \ const type * q = (const type *) (data); \ for (size_t i = 0; i < (nb); ++i) { \ @@ -4977,6 +5059,14 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) { } \ } +#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_e_e8m0(q[i].e, i)) { \ + return false; \ + } \ + } + #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \ const type * q = (const type *) (data); \ for (size_t i = 0; i < (nb); ++i) { \ @@ -5130,6 +5220,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb); } break; + case GGML_TYPE_MXFP4: + { + VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb); + } break; case GGML_TYPE_Q2_K: { VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index d09173e111..3b688f31c2 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -21,6 +21,8 @@ GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); + GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); @@ -45,6 +47,8 @@ GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GG GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * 
GGML_RESTRICT y, int64_t k); @@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + GGML_API void iq2xs_init_impl(enum ggml_type type); GGML_API void iq2xs_free_impl(enum ggml_type type); GGML_API void iq3xs_init_impl(int grid_size); diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 29bc421d58..df6ba54076 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -823,10 +823,10 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { }; ggml_backend_t backend = new ggml_backend { - /* .guid = */ ggml_backend_rpc_guid(), - /* .interface = */ ggml_backend_rpc_interface, - /* .device = */ ggml_backend_rpc_add_device(endpoint), - /* .context = */ ctx + /* .guid = */ ggml_backend_rpc_guid(), + /* .iface = */ ggml_backend_rpc_interface, + /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .context = */ ctx }; return backend; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2acdef98a6..3992dad01d 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2609,6 +2609,8 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src1->ne[1] == 1); + GGML_ASSERT(src1->ne[3] == 1); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -2688,6 +2690,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons const size_t type_size_src0 = ggml_type_size(src0->type); const size_t type_size_src1 = ggml_type_size(src1->type); + bool is_src0_cont_2 = ggml_is_contiguous_2(src0); + bool is_src1_cont_2 = ggml_is_contiguous_2(src1); + // SRC1 strides int64_t s11 = nb11 / type_size_src1; int64_t s12 = nb12 / type_size_src1; @@ -2737,6 +2742,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons s11 = ne10; s12 = ne11 * s11; s13 = ne12 * s12; + + is_src1_cont_2 = true; } ggml_sycl_pool_alloc dst_f16(ctx.pool()); @@ -2852,12 +2859,16 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons else #endif { - if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) { + // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3: + const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00; + const int64_t smb = ne12 == 1 ? 
s13 : s12; + // there is no broadcast and src0, src1 are contiguous across dims 2, 3 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, - src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, - src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf, + src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma, + src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf, mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); } else { const int ne23 = ne12 * ne13; @@ -3187,7 +3198,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // The kernel from the if path is faster for that specific case, but does not support all mul mats. ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } - } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1 && src1->ne[3] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) { @@ -4182,15 +4193,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { - struct ggml_tensor * a; - struct ggml_tensor * b; - if (op->op == GGML_OP_MUL_MAT) { - a = op->src[0]; - b = op->src[1]; - } else { - a = op->src[2]; - b = op->src[1]; - } + struct ggml_tensor * a = op->src[0]; + struct ggml_tensor * b = op->src[1]; + if (a->ne[3] != b->ne[3]) { return false; } @@ -4205,7 +4210,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g } } ggml_type src0_type = op->src[0]->type; - if (src0_type == GGML_TYPE_BF16) { + if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) { + // TODO: support MXFP4 + // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added return false; } return true; @@ -4350,6 +4357,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[3] != 1) { return false; } + // TODO: support attention sinks [TAG_ATTN_SINKS] + if (op->src[2]) { + return false; + } // TODO: support broadcast // ref: https://github.com/ggml-org/llama.cpp/pull/14435 return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); @@ -4575,10 +4586,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) { }; ggml_backend_t sycl_backend = new ggml_backend { - /* .guid = */ ggml_backend_sycl_guid(), - /* .interface = */ ggml_backend_sycl_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device), - /* .context = */ ctx + /* .guid = */ ggml_backend_sycl_guid(), + /* .iface = */ ggml_backend_sycl_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device), + /* .context = */ ctx }; return sycl_backend; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 75b58c26fc..4070e248ba 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -222,6 +222,7 @@ enum vk_device_architecture { AMD_RDNA2, AMD_RDNA3, INTEL_XE2, + NVIDIA_PRE_TURING, }; // HSK x HSV @@ -315,10 +316,33 @@ static vk_device_architecture 
get_device_architecture(const vk::PhysicalDevice& // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html return vk_device_architecture::INTEL_XE2; } + } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool cooperative_matrix = false; + + // Detect "pre-turing" based on lack of coopmat support. + for (const auto& properties : ext_props) { + if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { + cooperative_matrix = true; + break; + } + } + + if (!cooperative_matrix) { + return vk_device_architecture::NVIDIA_PRE_TURING; + } } return vk_device_architecture::OTHER; } +enum vk_conv_shapes { + CONV_SHAPE_128x128, + CONV_SHAPE_64x32, + CONV_SHAPE_32x256, + CONV_SHAPE_COUNT, +}; + struct vk_device_struct { std::recursive_mutex mutex; @@ -425,6 +449,8 @@ struct vk_device_struct { vk_pipeline pipeline_div[2][2][2]; vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_add_id_f32; + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32; vk_pipeline pipeline_scale_f32; @@ -459,6 +485,7 @@ struct vk_device_struct { vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; vk_pipeline pipeline_swiglu[2]; + vk_pipeline pipeline_swiglu_oai[2]; vk_pipeline pipeline_geglu_erf[2]; vk_pipeline pipeline_geglu_quick[2]; @@ -483,8 +510,8 @@ struct vk_device_struct { vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; - vk_pipeline pipeline_conv2d_f32; - vk_pipeline pipeline_conv2d_f16_f32; + vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_dw_whcn_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32; @@ -507,6 +534,7 @@ struct vk_device_struct { ggml_backend_buffer_type buffer_type; bool disable_fusion; + bool disable_host_visible_vidmem; #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; @@ -681,6 +709,8 @@ struct vk_op_glu_push_constants { uint32_t ne00; uint32_t ne20; uint32_t mode; // 0: default, 1: swapped, 2: split + float alpha; // for swiglu_oai + float limit; }; struct vk_op_unary_push_constants { @@ -770,6 +800,15 @@ struct vk_op_binary_push_constants { float param1; float param2; int32_t param3; }; +struct vk_op_add_id_push_constants { + uint32_t ne0; + uint32_t ne1; + uint32_t s01; + uint32_t s02; + uint32_t s11; + uint32_t s21; +}; + struct vk_op_diag_mask_push_constants { uint32_t ncols; uint32_t rows_per_channel; @@ -811,6 +850,7 @@ struct vk_op_soft_max_push_constants { float m1; uint32_t n_head_log2; uint32_t nrows_x; + uint32_t has_sinks; }; struct vk_op_argsort_push_constants { @@ -908,8 +948,22 @@ struct vk_op_conv2d_push_constants { uint32_t nb1; uint32_t nb2; uint32_t nb3; + + // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; }; +template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { + // Compute magic values to divide by KW, KW*KH, OW, OW*OH + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); +} + struct 
vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -1751,6 +1805,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { } else if (device->uma) { // Fall back to host memory type buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } else if (device->disable_host_visible_vidmem) { + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal); } else { // use rebar if available, otherwise fallback to device only visible memory buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -1939,6 +1995,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec break; case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: lut_size = 4*16; break; default: @@ -2068,12 +2125,12 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul (Qi_K) - l_warptile_mmq_k = { 256, 64, 128, 64, 1 }; - m_warptile_mmq_k = { 256, 32, 64, 64, 0 }; - s_warptile_mmq_k = { 256, 32, 32, 128, 0 }; - l_mmq_wg_denoms_k = { 64, 128, 1 }; - m_mmq_wg_denoms_k = { 32, 64, 1 }; - s_mmq_wg_denoms_k = { 32, 32, 1 }; + l_warptile_mmq_k = { 256, 128, 256, 64, 1 }; + m_warptile_mmq_k = { 256, 128, 128, 64, 1 }; + s_warptile_mmq_k = { 256, 32, 64, 128, 0 }; + l_mmq_wg_denoms_k = { 128, 256, 1 }; + m_mmq_wg_denoms_k = { 128, 128, 1 }; + s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id l_warptile_mmqid = { 256, 128, 128, 16, 0 }; @@ -2229,14 +2286,14 @@ static void ggml_vk_load_shaders(vk_device& device) { }; #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \ @@ -2315,6 +2372,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) @@ -2341,6 +2399,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) #undef CREATE_MM #undef CREATE_MM2 } else @@ -2402,6 +2461,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -2423,6 +2483,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); @@ -2455,6 +2516,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } else { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -2476,6 +2538,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } #undef CREATE_MM2 #undef CREATE_MM @@ -2543,6 +2606,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, 
warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -2580,6 +2644,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM2 #undef CREATE_MMQ #undef CREATE_MM @@ -2634,6 +2699,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -2671,6 +2737,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); } // reusing CREATE_MM from the fp32 path if ((device->coopmat2 || device->coopmat_support) @@ -2729,6 +2796,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, 
true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); @@ -2752,6 +2820,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -2776,6 +2845,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -2798,6 +2868,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -2817,6 +2888,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -2835,9 +2907,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", 
get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { @@ -2847,7 +2920,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); } } - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 12 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2938,6 +3011,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_BINARY(div, _norepeat, {1}) #undef CREATE_BINARY + ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -2988,6 +3063,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_GLU(geglu) CREATE_GLU(reglu) CREATE_GLU(swiglu) + CREATE_GLU(swiglu_oai) CREATE_GLU(geglu_erf) CREATE_GLU(geglu_quick) #undef CREATE_GLU @@ -2997,10 +3073,10 @@ static void 
ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -3048,48 +3124,105 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); // conv2d - uint32_t conv2d_WG_SIZE = 256; - uint32_t conv2d_BS_K = 128; - uint32_t conv2d_BS_CRS = 16; - uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. - if (device->subgroup_shuffle && - device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316 - use_collectives = 1; - conv2d_BS_CRS = std::min( - device->subgroup_size, - conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used. 
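// The KWmp/KWL, KWKHmp/KWKHL, OWmp/OWL and OWOHmp/OWOHL fields added to vk_op_conv2d_push_constants
// above let the conv2d shader replace integer division by KW, KW*KH, OW and OW*OH with a
// multiply-high plus shift. A minimal, self-contained sketch of that round-up magic-number scheme
// follows; it assumes the same formulation ggml uses on its other backends, and the *_sketch helpers
// are illustrative names only, not the init_fastdiv_values/fastdiv definitions used in this file.
#include <cstdint>
#include <cstdio>

static void init_fastdiv_values_sketch(uint32_t d, uint32_t & mp, uint32_t & L) {
    // L = ceil(log2(d)), mp = floor(2^32 * (2^L - d) / d) + 1, valid for d >= 1
    L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        L++;
    }
    mp = (uint32_t)(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
}

static uint32_t fastdiv_sketch(uint32_t n, uint32_t mp, uint32_t L) {
    const uint32_t hi = (uint32_t)(((uint64_t)n * mp) >> 32);   // multiply-high of n and mp
    return (uint32_t)(((uint64_t)hi + n) >> L);                 // (hi + n) >> L == n / d
}

int main() {
    // e.g. recovering the output row from a flattened output index, as the shader does with OW
    const uint32_t OW = 56;
    uint32_t OWmp, OWL;
    init_fastdiv_values_sketch(OW, OWmp, OWL);
    const uint32_t samples[] = { 0, 55, 56, 1234, 100000 };
    for (uint32_t idx : samples) {
        printf("%u / %u = %u, fastdiv -> %u\n", idx, OW, idx / OW, fastdiv_sketch(idx, OWmp, OWL));
    }
    return 0;
}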
- } - uint32_t conv2d_BS_NPQ = 128; - uint32_t conv2d_TS_K = 8; - uint32_t conv2d_shmem_req = - (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float); - if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { - conv2d_BS_CRS = 8; - if (use_collectives) { - conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); - } - } + for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { + uint32_t conv2d_WG_SIZE = 256; + uint32_t conv2d_BS_K = 128; + uint32_t conv2d_BS_CRS = 16; + uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. + uint32_t conv2d_BS_NPQ = 128; + uint32_t conv2d_TS_K = 8; + uint32_t conv2d_SHMEM_PAD = 4; + bool conv2d_UNROLL = true; - if (use_collectives) { - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, - { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true); - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, - { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true); - } else { - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, - { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, - false); - ggml_vk_create_pipeline( - device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3, - sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, - { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, - false); +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + conv2d_SHMEM_PAD = 8; // 8 float16_t + } +#endif + + if (device->vendor_id == VK_VENDOR_ID_INTEL) { + conv2d_SHMEM_PAD = 0; + conv2d_UNROLL = false; + } else if (device->vendor_id == VK_VENDOR_ID_AMD) { + conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4; + } + + switch (s) { + default: + case CONV_SHAPE_128x128: + conv2d_BS_K = 128; + conv2d_BS_NPQ = 128; + conv2d_BS_CRS = 16; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) { + conv2d_UNROLL = false; + } + break; + case CONV_SHAPE_64x32: + conv2d_BS_K = 64; + conv2d_BS_NPQ = 32; + conv2d_BS_CRS = 32; + conv2d_TS_K = 4; + break; + case CONV_SHAPE_32x256: + conv2d_BS_K = 32; + conv2d_BS_NPQ = 256; + conv2d_BS_CRS = 16; + break; + } + + // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math. + bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA || + device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; + bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD || + device->architecture == vk_device_architecture::AMD_GCN; + + if (device->subgroup_shuffle && + device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316. 
+ allow_collectives_nv && + allow_collectives_amd) { + use_collectives = 1; + conv2d_BS_CRS = std::min( + device->subgroup_size, + conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. + } + + uint32_t conv2d_shmem_req = + (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float); + if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + conv2d_BS_CRS = 8; + if (use_collectives) { + conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + } + } + + std::array wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 }; + std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + } else +#endif + if (conv2d_UNROLL) { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_unroll_len, conv2d_f16_f32_unroll_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + } else { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); + } } ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -3135,6 +3268,9 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; + const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM"); + device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -4149,6 +4285,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -4219,6 +4356,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -4262,6 +4400,7 @@ static vk_pipeline 
ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -4316,6 +4455,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -4351,6 +4491,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return nullptr; @@ -4536,6 +4677,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); + GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; @@ -4943,26 +5085,37 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz ggml_vk_queue_command_pools_cleanup(dst->device); } -static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) { VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); uint32_t split_k = 1; - if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) { + if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) { // If k is 'large' and the SMs will fill less than halfway, use split_k. uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); - if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) { - split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); - // Clamp to 2 or 4 - split_k = std::min(split_k, 4u); - if (split_k == 3) { - split_k = 2; + + if (k >= 2048) { + if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) { + split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); + } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) { + split_k = 3; } - if (ctx->device->coopmat2) { - // coopmat2 shader expects splits to be aligned to 256 - while (split_k > 1 && ((k / split_k) % 256) != 0) { - split_k /= 2; + // Cap the split at 8x. Unless k is huge this is a lot of overhead. + split_k = std::min(split_k, 8u); + + // ggml_vk_matmul will align the splits to be a multiple of 256. + // If this rounded up size would cause the last split to be empty, + // then reduce the split count. 
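// The reduction rule described in the two comment lines above is easiest to see with concrete
// numbers; the loop that enforces it follows right after this aside. The standalone sketch below
// mirrors that loop and assumes CEIL_DIV(a, b) rounds the quotient up and ROUNDUP_POW2(x, 256)
// rounds x up to the next multiple of 256, as their names suggest.
#include <cstdint>
#include <cstdio>

static uint32_t reduce_split_k_sketch(uint32_t k, uint32_t split_k) {
    while (split_k > 1) {
        uint32_t k_split = (k + split_k - 1) / split_k;   // CEIL_DIV(k, split_k)
        k_split = ((k_split + 255) / 256) * 256;          // ROUNDUP_POW2(k_split, 256)
        if (k_split * (split_k - 1) < k) {
            break;      // the last split still gets some work, keep this split count
        }
        split_k--;      // the rounded-up splits already cover all of k, drop one split
    }
    return split_k;
}

int main() {
    // k = 4096 with 3 splits: each rounded split is 1536 and 1536*2 < 4096, so 3 splits are kept.
    printf("k=4096, split_k=3 -> %u\n", reduce_split_k_sketch(4096, 3));
    // k = 2304 with 8 splits: rounded splits of 512 would leave the tail splits empty, so the
    // count shrinks until 512*4 < 2304, i.e. split_k == 5.
    printf("k=2304, split_k=8 -> %u\n", reduce_split_k_sketch(2304, 8));
    return 0;
}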
+ while (true) { + if (split_k == 1) { + break; } + uint32_t k_split = CEIL_DIV(k, split_k); + k_split = ROUNDUP_POW2(k_split, 256); + if (k_split * (split_k - 1) < k) { + break; + } + split_k--; } } } @@ -4974,9 +5127,22 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); if (ctx->device->coopmat2) { + const uint32_t shader_core_count = ctx->device->shader_core_count; + const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); + const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]); + // Use large shader when the N dimension is greater than the medium shader's tile size uint32_t crossover_large = mmp->m->wg_denoms[1]; - if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + + // Prefer large over medium if either: + // - medium or large tiles would overfill the GPU + // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not + // (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead) + bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count || + // split_k==3 with large tiles likely better than medium tiles with no split_k. + (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); + + if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { return aligned ? 
mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size @@ -5020,7 +5186,11 @@ static void ggml_vk_matmul( GGML_ASSERT(batch_stride_d == m * n); - const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n }; + // Round the split size up to a multiple of 256 (k-quant alignment) + uint32_t k_split = CEIL_DIV(k, split_k); + k_split = ROUNDUP_POW2(k_split, 256); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; // Make sure enough workgroups get assigned for split k to work ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); ggml_vk_sync_buffers(subctx); @@ -5742,7 +5912,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t ne00 = src0->ne[0]; const uint64_t ne01 = src0->ne[1]; const uint64_t ne02 = src0->ne[2]; - // const uint64_t ne03 = src0->ne[3]; + const uint64_t ne03 = src0->ne[3]; const uint64_t nb01 = src0->nb[1]; const uint64_t nb02 = src0->nb[2]; @@ -5754,7 +5924,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t ne12 = src1->ne[2]; // const uint64_t ne13 = src1->ne[3]; + const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t)); + const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float)); + const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float)); + GGML_ASSERT(ne11 == 1); + GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; @@ -5770,7 +5945,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con src1_uma = d_Qy != nullptr; } - const uint64_t d_ne = ne01 * ne11 * ne12; + const uint64_t d_ne = ne01 * ne11 * ne12 * ne03; const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); @@ -5805,10 +5980,10 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; // compute - const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 }; ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, - { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + 
qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -6332,11 +6507,14 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co return supported; } -static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + if (sinks) { + std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3]; + } std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); GGML_TENSOR_LOCALS(int64_t, neq, q, ne) @@ -6535,10 +6713,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; - size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; + vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr; + size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0; - bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false; + bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false; if (ctx->device->uma) { ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); @@ -6553,6 +6731,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); M_uma = d_M != nullptr; } + if (sinks) { + ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset); + S_uma = d_S != nullptr; + } } @@ -6588,7 +6770,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } - uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2; + if (!S_uma) { + d_S = d_Q; + s_buf_offset = q_buf_offset; + if (sinks) { + ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context; + d_S = s_buf_ctx->dev_buffer; + s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs; + } + } + + uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, @@ -6612,6 +6804,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, }, // We only use split_k when group query attention is enabled, which means @@ -6621,10 +6814,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); ggml_vk_sync_buffers(subctx); - const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k }; + const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); @@ -6635,12 +6829,41 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, pc, { workgroups_x, workgroups_y, workgroups_z }); } } +static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 
= dst->src[1]; + + // src0 - kernel: [KW, KH, Cin, Cout] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[3]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; + return elements; +} + static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { switch (op) { case GGML_OP_GET_ROWS: @@ -6691,6 +6914,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const break; } return nullptr; + case GGML_OP_ADD_ID: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_add_id_f32; + } + return nullptr; case GGML_OP_CONCAT: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_concat_f32; @@ -6836,6 +7064,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; case GGML_GLU_OP_SWIGLU: return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU_OAI: + return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16]; case GGML_GLU_OP_GEGLU_ERF: return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16]; case GGML_GLU_OP_GEGLU_QUICK: @@ -6851,6 +7081,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; @@ -6970,10 +7201,30 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_CONV_2D: if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + auto elements = ggml_vk_get_conv_elements(dst); + vk_conv_shapes shape; + + uint32_t tiles[CONV_SHAPE_COUNT]; + for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) { + tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]); + } + + // We can't query number of shader cores on Intel, use 32 as a placeholder + // so small convolutions will still choose a smaller tile. + const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? 
ctx->device->shader_core_count : 32; + + if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) { + shape = CONV_SHAPE_128x128; + } else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) { + shape = CONV_SHAPE_32x256; + } else { + shape = CONV_SHAPE_64x32; + } + if (src0->type == GGML_TYPE_F32) { - return ctx->device->pipeline_conv2d_f32; + return ctx->device->pipeline_conv2d_f32[shape]; } else if (src0->type == GGML_TYPE_F16) { - return ctx->device->pipeline_conv2d_f16_f32; + return ctx->device->pipeline_conv2d_f16_f32[shape]; } } return nullptr; @@ -7001,6 +7252,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_ADD_ID: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SQR: @@ -7301,29 +7553,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } break; case GGML_OP_CONV_2D: { - // src0 - kernel: [KW, KH, Cin, Cout] - // src1 - input: [W, H, Cin, N] - // dst - result: [OW, OH, Cout, N] - - // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) - auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { - return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; - }; - // parallelize in {OW/BS_K, OH/BS_NPQ, 1} - int64_t W = src1->ne[0]; - int64_t H = src1->ne[1]; - int64_t KW = src0->ne[0]; - int64_t KH = src0->ne[1]; - int64_t Cout = src0->ne[3]; - int64_t N = src1->ne[3]; - int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); - int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); - int64_t NPQ = N * OW * OH; - - // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups - elements = { static_cast(Cout), static_cast(NPQ), 1 }; - } - break; + elements = ggml_vk_get_conv_elements(dst); + } break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -7368,6 +7599,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { ne, 1, 1 }; } } break; + case GGML_OP_ADD_ID: + { + elements = { (uint32_t)ne01, (uint32_t)ne02, 1 }; + } break; case GGML_OP_SET_ROWS: { uint32_t ne = ggml_nelements(src0); @@ -7407,8 +7642,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) { - // Empty src1 is possible in soft_max, but the shader needs a buffer + if (op == GGML_OP_GLU) { + // Empty src1 is possible in glu, but the shader needs a buffer vk_subbuffer subbuf_y; if (use_src1) { subbuf_y = { d_Y, y_buf_offset, y_sz }; @@ -7418,6 +7653,24 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_SOFT_MAX) { + // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; + } else { + subbuf_y = { d_X, 0, x_sz }; + } + + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, 
d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { // Empty src2 is possible in rope, but the shader needs a buffer vk_subbuffer subbuf_z; @@ -7546,6 +7799,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t src2_type_size = ggml_type_size(src2->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, { + (uint32_t)dst->ne[0], + (uint32_t)dst->ne[1], + (uint32_t)src0->nb[1] / src0_type_size, + (uint32_t)src0->nb[2] / src0_type_size, + (uint32_t)src1->nb[1] / src1_type_size, + (uint32_t)src2->nb[1] / src2_type_size, + }, dryrun); +} + static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) { GGML_ASSERT(version == 6 || version == 7); int num_srcs = version == 6 ? 6 : 7; @@ -7964,8 +8232,12 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con } static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const float * op_params_f = (const float *)dst->op_params; + const bool swapped = (bool)dst->op_params[1]; const bool split = src1 != nullptr; + const float alpha = op_params_f[2]; + const float limit = op_params_f[3]; GGML_ASSERT(ggml_is_contiguous(src0)); @@ -7979,7 +8251,15 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t mode = split ? 2 : (swapped ? 
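// In index terms, ADD_ID computes dst[:, i1, i2] = src0[:, i1, i2] + src1[:, src2[i1, i2]],
// i.e. each row of the MoE activation picks its expert bias row through the I32 id
// tensor. A rough scalar sketch of what the add_id.comp shader (further below) does,
// with names mirroring the push constants above (illustrative only):
//   for (i2, i1, i0):
//       dst[i0 + i1*ne0 + i2*ne0*ne1] = a[i0 + i1*s01 + i2*s02]
//                                     + b[i0 + idx[i1 + i2*s21]*s11];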
1 : 0); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, + { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], + (uint32_t)dst->ne[0], + mode, + alpha, + limit + }, dryrun); } static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -7987,7 +8267,7 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); } -static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; float scale = op_params[0]; @@ -8009,7 +8289,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], @@ -8019,6 +8299,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, m0, m1, n_head_log2, nrows_x, + src2 != nullptr }, dryrun); } @@ -9258,6 +9539,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: break; @@ -9269,6 +9551,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: case GGML_OP_ADD: + case GGML_OP_ADD_ID: case GGML_OP_ACC: case GGML_OP_SUB: case GGML_OP_MUL: @@ -9423,6 +9706,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_DIV: ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_ADD_ID: + ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + break; case GGML_OP_CONCAT: ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9520,6 +9807,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9533,7 +9821,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_SOFT_MAX: - ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun); break; case GGML_OP_SOFT_MAX_BACK: @@ -9606,7 +9894,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_FLASH_ATTN_EXT: - 
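// SWIGLU_OAI is the gpt-oss variant of SwiGLU: with x = min(a, limit) and
// g = clamp(b, -limit, limit), it computes (x * sigmoid(alpha * x)) * (1 + g),
// which is why the GLU push-constant block above grows by the two trailing floats
// {alpha, limit}; the other GLU variants simply ignore them.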
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); + ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun); break; @@ -9679,6 +9967,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_ADD_ID: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_SCALE: @@ -9748,6 +10037,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: buf = tensor->buffer; @@ -10477,10 +10767,10 @@ ggml_backend_t ggml_backend_vk_init(size_t dev_num) { ggml_vk_init(ctx, dev_num); ggml_backend_t vk_backend = new ggml_backend { - /* .guid = */ ggml_backend_vk_guid(), - /* .interface = */ ggml_backend_vk_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), - /* .context = */ ctx, + /* .guid = */ ggml_backend_vk_guid(), + /* .iface = */ ggml_backend_vk_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), + /* .context = */ ctx, }; return vk_backend; @@ -10597,6 +10887,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_GLU_OP_GEGLU: case GGML_GLU_OP_REGLU: case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: return ggml_is_contiguous(op->src[0]) && @@ -10642,6 +10933,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: break; default: return false; @@ -10679,6 +10971,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) { return false; } + if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) { + return false; + } if (op->src[0]->type != GGML_TYPE_F32) { return false; } @@ -10751,6 +11046,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: return true; default: return false; @@ -10849,6 +11145,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); + case GGML_OP_ADD_ID: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 && + op->type == GGML_TYPE_F32; case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM_BACK: case GGML_OP_SQR: @@ -11267,6 +11566,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); + if (src_clone[4]) { + ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]); + } } else if (tensor->op == GGML_OP_MUL_MAT) { tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL_MAT_ID) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp 
b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp new file mode 100644 index 0000000000..3ae8f0116c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#include "types.comp" + +layout (push_constant) uniform parameter +{ + uint ne0; + uint ne1; + uint s01; + uint s02; + uint s11; + uint s21; +} p; + +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer Z {int32_t data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i1 = gl_WorkGroupID.x; + const uint i2 = gl_WorkGroupID.y; + + const uint i11 = data_c[i1 + i2 * p.s21]; + + const uint s1 = p.ne0; + const uint s2 = p.ne0 * p.ne1; + + const uint d0 = i1 * s1 + i2 * s2; + const uint a0 = i1 * p.s01 + i2 * p.s02; + const uint b0 = i11 * p.s11; + + for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) { + data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0]; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 481940a52b..86bafba4a4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -1,14 +1,18 @@ #version 450 +#extension GL_EXT_control_flow_attributes : enable +#ifdef COOPMAT2 +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + #ifdef USE_COLLECTIVES # extension GL_KHR_shader_subgroup_shuffle : enable #endif #include "types.comp" -// Make spec constant -#define SHMEM_PAD 0 - // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j layout(binding = 0) readonly buffer A { A_TYPE knl_data[]; @@ -56,6 +60,12 @@ layout(push_constant) uniform parameter { uint32_t nb1; uint32_t nb2; uint32_t nb3; + + // fastdiv helper values + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; } p; @@ -68,6 +78,7 @@ layout(constant_id = 3) const uint BS_NPQ = 128; // Thread-tile sizes layout(constant_id = 4) const uint TS_K = 8; layout(constant_id = 5) const uint use_collectives = 1; +layout(constant_id = 6) const uint SHMEM_PAD = 4; uint32_t tid = gl_LocalInvocationID.x; const uint32_t WG_SIZE = gl_WorkGroupSize.x; @@ -85,6 +96,12 @@ uint32_t n_elems_out = K * NPQ; // Number of blocktiles per input uint32_t NB_CRS = splitWork(CRS, BS_CRS); +#ifdef COOPMAT2 +#define SHMEM_TYPE float16_t +#else +#define SHMEM_TYPE float +#endif + const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; @@ -94,8 +111,8 @@ const uint32_t Bsh_numel = BS_CRS * BS_NPQ; const uint32_t Ash_len = BS_K * Ash_stride; const uint32_t Bsh_len = BS_CRS * Bsh_stride; -shared float Ash[Ash_len]; // K x CRS -shared float Bsh[Bsh_len]; // CRS x NPQ +shared SHMEM_TYPE Ash[Ash_len]; // K x CRS +shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ // Threadtile sizes const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; @@ -104,10 +121,6 @@ const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; const uint32_t NT_K = BS_K / TS_K; const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; -float regA[TS_K]; -float regB[TS_NPQ]; -float 
regC[TS_K][TS_NPQ];
-
 /* Compute KxCRS @ CRSxNPQ = K x NPQ
@@ -131,12 +144,44 @@ uint32_t Br = tid / BS_NPQ;
 uint32_t Bc = tid % BS_NPQ;
 const uint32_t BrpWg = WG_SIZE / BS_NPQ;

+// see init_fastdiv_values in ggml-vulkan.cpp
+uint fastdiv(uint n, uint mp, uint L) {
+    uint msbs, lsbs;
+    // msbs = mulhi(n, mp)
+    umulExtended(n, mp, msbs, lsbs);
+    return (msbs + n) >> L;
+}
+
+#ifdef COOPMAT2
+#define ACC_TYPE float16_t
+
+ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
+{
+    uint32_t K_idx   = B_idx_K * BS_K + r;
+    uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
+    uint32_t N_idx   = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
+    uint32_t OH_idx  = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
+    uint32_t OW_idx  = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
+    uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
+    if (K_idx < K && NPQ_idx < NPQ) {
+        dst_data[dst_idx] = D_TYPE(elem);
+    }
+    return elem;
+}
+#endif
+
 void main() {
+#ifdef COOPMAT2
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
+    matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
+#else
+    float regC[TS_K][TS_NPQ];
     for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
         for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
             regC[T_ly][T_lx] = 0.0;
         }
     }
+#endif
     /* Advance block in CRS dim */
     for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
         uint32_t CRS_idx_a;
@@ -151,9 +196,9 @@ void main() {
         uint32_t cached_KW_idx;
         if (use_collectives == 1) {
             cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
-            cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
+            cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
             uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
-            cached_KH_idx = cached_CRS_remainder / p.KW;
+            cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
             cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;

             CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
@@ -162,16 +207,16 @@ void main() {
             KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
         } else {
             CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
-            Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+            Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
             uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
-            KH_idx_a = CRS_remainder / p.KW;
+            KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
             KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
         }
 #else
         CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
-        Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+        Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
        CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
-        KH_idx_a = CRS_remainder / p.KW;
+        KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
        KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
 #endif

@@ -185,16 +230,16 @@ void main() {
             if (K_idx >= K || CRS_idx_a >= CRS) {
                 val = 0.0;
             }
-            Ash[B_ly * Ash_stride + B_lx] = val;
+            Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
         }
         /* Load input to B_block: (BS_CRS x BS_NPQ) */
-        for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
+        UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
             uint32_t B_ly = r_offset + Br; /* Row index of B block */
             uint32_t B_lx = Bc;
             uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
-            uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
+            uint32_t N_idx = fastdiv(NPQ_idx,
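// Quick sanity check of the fastdiv scheme above, assuming the host computes the
// magic pair the usual way (L = ceil(log2(d)), mp = 2^32 * (2^L - d) / d + 1):
// for d = 3 this gives L = 2, mp = 1431655766, and then
//   fastdiv(9, mp, L): msbs = 3, (3 + 9) >> 2 = 3   == 9 / 3
//   fastdiv(8, mp, L): msbs = 2, (2 + 8) >> 2 = 2   == 8 / 3
// so the umulExtended high word plus the increment reproduces integer division.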
p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW; - uint32_t OH_idx = NPQ_remainder / p.OW; + uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW; uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW; uint32_t CRS_idx_b; @@ -209,16 +254,16 @@ void main() { KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); } else { CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ - Cin_idx_b = CRS_idx_b / (p.KW * p.KH); + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; - KH_idx_b = CRS_remainder / p.KW; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; KW_idx_b = CRS_remainder - KH_idx_b * p.KW; } #else CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ - Cin_idx_b = CRS_idx_b / (p.KW * p.KH); + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; - KH_idx_b = CRS_remainder / p.KW; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; KW_idx_b = CRS_remainder - KH_idx_b * p.KW; #endif @@ -230,36 +275,55 @@ void main() { if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) { val = 0.0; } - Bsh[B_ly * Bsh_stride + B_lx] = val; + Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); } barrier(); - for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { - for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { - regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; - } - for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { - regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; - } - for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { +#ifdef COOPMAT2 + coopmat matA; + coopmat matB; + + coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + matC = coopMatMulAdd(matA, matB, matC); +#else + if (T_y * TS_K < K) { + UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { + float regA[TS_K]; + float regB[TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; + } for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { - regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } } } } +#endif barrier(); } /* Save C* */ - for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { - for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { - uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; - uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; - uint32_t N_idx = NPQ_idx / (p.OH * p.OW); - uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW; - uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; - uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { - dst_data[dst_idx] = regC[T_ly][T_lx]; +#ifdef COOPMAT2 + coopMatPerElementNV(matC, matC, perElemOpStore); +#else + if (T_y * TS_K < K) { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; + uint32_t 
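// All of these fastdiv call sites decode the same row-major flattening:
// CRS_idx enumerates (Cin, KH, KW), so Cin = CRS_idx / (KW*KH),
// KH = remainder / KW, KW = remainder % KW -- with each true division replaced
// by its precomputed (mp, L) pair from the push constants.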
NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = regC[T_ly][T_lx]; + } } } } +#endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index dbc7daa332..978d430030 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -4,8 +4,8 @@ #include "generic_unary_head.comp" #include "dequant_funcs.comp" -#if defined(DATA_A_IQ4_NL) -// 16 invocations needed for init_iq4nl_shmem +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +// 16 invocations needed for init_iq_shmem layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; #else layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 0d9739d406..d3127fbd98 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -434,6 +434,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_MXFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + vec2 v0 = dequantize(ib, iqs, a_offset); + vec2 v1 = dequantize(ib, iqs + 1, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); @@ -455,6 +467,12 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_MXFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 9cb7da2daa..706540fd85 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor } #endif +#if defined(DATA_A_MXFP4) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 { + block_mxfp4 block; +}; + +float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float d = e8m0_to_fp32(bl.block.e); + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_mxfp4[qs] * d); + return ret; +} +#endif + #if defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 #elif defined(DATA_A_Q4_1) @@ -696,4 +715,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint 
blockCoor #define dequantFuncA dequantFuncIQ4_XS #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL +#elif defined(DATA_A_MXFP4) +#define dequantFuncA dequantFuncMXFP4 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp new file mode 100644 index 0000000000..ee496e9d56 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = e8m0_to_fp32(data_a[ib].e); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 45c6e7736a..d40848e15f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -305,6 +305,27 @@ void main() { return; } + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > Mf[r]) { + ms = exp(Mf[r] - sink); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] *= ms; + } + } else { + vs = exp(sink - Mf[r]); + } + + Lf[r] = Lf[r]*ms + vs; + } + } + float Lfrcp[Br]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Lfrcp[r] = 1.0 / Lf[r]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index 7defe72b40..b57c9dcfc4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -50,10 +50,13 @@ layout (push_constant) uniform parameter { uint32_t k_num; } p; +#define SINK_ENABLE_BIT (1<<24) #define MASK_ENABLE_BIT (1<<16) #define N_LOG2_MASK 0xFFFF -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; +layout (binding = 4) readonly buffer S {float data_s[];}; + +layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) #define BINDING_IDX_K 0 @@ -111,6 +114,14 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i return ACC_TYPE(pow(base, ACC_TYPE(exph))); } +// Load the sink value, indexed by Q's dimension 2. 
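// Conceptually a sink is one extra, value-less logit per head folded into the
// streaming softmax: if sink > running max M, the accumulators are rescaled by
// ms = exp(M - sink); otherwise it only adds vs = exp(sink - M) to the
// normalizer, giving L' = L*ms + vs and O' = O*ms -- the same shape of update
// the online-softmax loop already performs for each KV tile.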
+ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + return ACC_TYPE(data_s[h]); +} + uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, q_stride, k_stride, v_stride, m_stride; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 486735fe8b..230e815f22 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -329,6 +329,27 @@ void main() { return; } + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > Mf[r]) { + ms = exp(Mf[r] - sink); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] *= ACC_TYPE(ms); + } + } else { + vs = exp(sink - Mf[r]); + } + + Lf[r] = Lf[r]*ms + vs; + } + } + float Lfrcp[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lfrcp[r] = 1.0 / Lf[r]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 274f48fcab..b0564ca0bf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -248,6 +248,34 @@ void main() { // resize L by using smear/reduce coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + coopmat S; + coopMatPerElementNV(S, S, perElemOpGetSink, iq2); + + coopmat Mr; + + // resize M by using smear/reduce + coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce); + + // O, Ldiag, Mr all have the same type so all element locations match + [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) { + ACC_TYPE sink = S[i]; + + ACC_TYPE ms = ACC_TYPE(1.0f); + ACC_TYPE vs = ACC_TYPE(1.0f); + + if (sink > Mr[i]) { + ms = exp(Mr[i] - sink); + + O[i] *= ms; + } else { + vs = exp(sink - Mr[i]); + } + + Ldiag[i] = Ldiag[i]*ms + vs; + } + } + [[unroll]] for (int k = 0; k < Ldiag.length(); ++k) { Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 0a17a9df23..76ef4b6dfb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -7,13 +7,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {float data_a[];}; -layout (binding = 1) writeonly buffer D {float data_d[];}; +layout (binding = 1) readonly buffer B {float data_s[];}; +layout (binding = 2) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; uint N; uint ne3; uint k_num; + uint sinks; } p; shared float tmpsh[BLOCK_SIZE]; @@ -73,6 +75,22 @@ void main() { } L = tmpsh[0]; + float sink; + if (p.sinks != 0) { + sink = data_s[n]; + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > m_max) { + ms = exp(m_max - sink); + } else { + vs = exp(sink - m_max); + } + + L = L*ms + vs; + } + L = 1.0 / L; // D dimension is split across workgroups in the y dimension @@ -85,6 +103,13 @@ 
void main() { float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } + if (p.sinks != 0) { + if (sink > m_max) { + float ms = 1.0f; + ms = exp(m_max - sink); + O *= ms; + } + } O *= L; data_d[iq3 * D * N + D * n + d] = O; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp index 004a61fc16..51d70869d9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp @@ -14,4 +14,6 @@ layout (push_constant) uniform parameter uint ne00; uint ne20; uint mode; + float alpha; + float limit; } p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp index bc633369f9..638878d94c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -26,6 +26,9 @@ layout (push_constant) uniform parameter uint ne12; uint b_offset; uint d_offset; + uint nb03; + uint nb13; + uint nb23; } p; shared FLOAT_TYPE tmp[BLOCK_SIZE]; @@ -34,6 +37,7 @@ void main() { const uint tid = gl_LocalInvocationID.x; const uint row_x = gl_GlobalInvocationID.y; const uint channel = gl_GlobalInvocationID.z; + const uint i3 = gl_WorkGroupID.x; const uint channel_x = channel / p.channel_x_divisor; const uint channel_y = channel % p.ne12; @@ -41,7 +45,7 @@ void main() { const uint nrows_dst = p.nrows_x; const uint row_dst = row_x; - const uint idst = channel*nrows_dst + row_dst; + const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst; FLOAT_TYPE temp = 0.0f; @@ -58,8 +62,8 @@ void main() { const uint row_y = col_x; - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel_y*p.channel_stride_y + row_y; + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; const vec4 av4 = vec4(data_a_v4[ix / 4]); const vec4 bv4 = vec4(data_b_v4[iy / 4]); @@ -74,8 +78,8 @@ void main() { const uint row_y = col_x; - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel_y*p.channel_stride_y + row_y; + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; const vec4 av4 = vec4(data_a_v4[ix / 4]); const vec4 bv4 = vec4(data_b_v4[iy / 4]); @@ -91,8 +95,8 @@ void main() { const uint row_y = col_x; - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel_y*p.channel_stride_y + row_y; + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index f481549911..8c5114a79d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -747,6 +747,21 @@ void main() { buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; +#elif defined(DATA_A_MXFP4) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = 
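// MXFP4 recap for the branch below: a 17-byte block is one shared e8m0 scale d
// plus 16 bytes of packed 4-bit codes indexing kvalues_mxfp4. Each byte holds
// elements i (low nibble) and i+16 (high nibble); e.g. qs[i] = 0x7F decodes to
// d * kvalues_mxfp4[0xF] = -6*d and d * kvalues_mxfp4[0x7] = 6*d.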
(loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; + + const uint ib = idx / 8; + const uint iqs = (idx & 0x07) * 2; + + const float d = e8m0_to_fp32(data_a[ib].e); + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d); + buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d); + buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d); #endif } [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp index 63b15471bd..34e8db9770 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -92,6 +92,12 @@ FLOAT_TYPE get_d(uint ib) { } #endif +#if defined(DATA_A_MXFP4) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) FLOAT_TYPE_VEC2 get_dm(uint ib) { return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index 5bcd3b1e3d..5f20a1ee7d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -20,6 +20,7 @@ layout (push_constant) uniform parameter float m1; uint n_head_log2; uint nrows_x; + uint has_sinks; } p; #include "types.comp" @@ -29,7 +30,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; -layout (binding = 2) buffer D {D_TYPE data_d[];}; +layout (binding = 2) readonly buffer Z {float data_c[];}; +layout (binding = 3) buffer D {D_TYPE data_d[];}; shared FLOAT_TYPE vals[BLOCK_SIZE]; @@ -60,13 +62,13 @@ void soft_max(uint num_iters) { const uint h = (rowx / p.ne01) % p.ne02; // head index const float base = h < p.n_head_log2 ? p.m0 : p.m1; - const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; slope = pow(base, exp); } // Find max - FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; // Cache values while we compute the max, so we don't need to read them // again when we're ready to compute exp(x-max). 
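// With sinks enabled, the running max is seeded with the per-head sink value
// data_c[i02] rather than -inf, and exp(sink - max) is folded into the
// denominator below, so the sink absorbs probability mass without ever
// contributing to the output row.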
@@ -148,6 +150,10 @@ void soft_max(uint num_iters) { } sum = vals[0]; + if (p.has_sinks != 0) { + sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val)); + } + FLOAT_TYPE rcpdivisor = 1.0/sum; [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp new file mode 100644 index 0000000000..970750eec0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp @@ -0,0 +1,14 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + float xi = min(a, p.limit); + float gi = max(min(b, p.limit), -p.limit); + + float out_glu = xi / (1.0f + exp(-xi * p.alpha)); + out_glu = out_glu * (1.0f + gi); + return out_glu; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 3bde717832..a36c33e267 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -1337,6 +1337,29 @@ struct block_iq4_nl_packed16 #define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif +#define QUANT_K_MXFP4 32 +#define QUANT_R_MXFP4 2 + +struct block_mxfp4 +{ + uint8_t e; + uint8_t qs[QUANT_K_MXFP4/2]; +}; + +//struct block_mxfp4_packed16 +//{ +// uint8_t e; +// uint16_t qs[QUANT_K_MXFP4/2/2]; +//}; + +#if defined(DATA_A_MXFP4) +#define QUANT_K QUANT_K_MXFP4 +#define QUANT_R QUANT_R_MXFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_mxfp4 +//#define A_TYPE_PACKED16 block_mxfp4_packed16 +#endif + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), @@ -1356,6 +1379,25 @@ void init_iq_shmem(uvec3 wgsize) } #endif +#if defined(DATA_A_MXFP4) +const FLOAT_TYPE kvalues_mxfp4_const[16] = { + FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f), + FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f) +}; + +shared FLOAT_TYPE kvalues_mxfp4[16]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { + kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; + } + barrier(); +} +#endif + // returns the bfloat value in the low 16b. 
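// The e8m0_to_fp32 helper a few lines below decodes MXFP4's exponent-only scale:
// the byte x is placed straight into the float exponent field (x << 23), giving
// 2^(x - 127); e.g. x = 127 -> 1.0f, x = 129 -> 4.0f. x = 0 is special-cased to
// the denormal bit pattern 0x00400000 = 2^-127, which the plain shift cannot
// produce (it would yield 0.0f).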
// See ggml_compute_fp32_to_bf16 uint32_t fp32_to_bf16(float f) @@ -1370,4 +1412,17 @@ float bf16_to_fp32(uint32_t u) return uintBitsToFloat(u << 16); } +float e8m0_to_fp32(uint8_t x) { + uint32_t bits; + + if (x == 0) { + bits = 0x00400000; + } else { + bits = x; + bits = bits << 23; + } + + return uintBitsToFloat(bits); +} + #endif // !defined(GGML_TYPES_COMP) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index f9f0c95b8b..4cd94c51e3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -64,6 +64,7 @@ const std::vector type_names = { "iq3_s", "iq4_xs", "iq4_nl", + "mxfp4", "bf16", }; @@ -118,7 +119,7 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s CloseHandle(pi.hProcess); CloseHandle(pi.hThread); #else -int stdout_pipe[2]; + int stdout_pipe[2]; int stderr_pipe[2]; if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { @@ -362,7 +363,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool std::string load_vec_quant = "2"; if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl")) + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; if (tname == "bf16") { @@ -602,6 +603,8 @@ void process_shaders() { string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? 
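// mxfp4 joins the load_vec_quant = "4" group above, presumably because, like
// iq4_nl, each packed byte expands to two values (QUANT_R = 2), so a 4-wide
// quant load matches its dequantize4 helper.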
"1" : "0"}}); @@ -655,14 +658,24 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}}); - string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}}); + string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); + string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}}); + + string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); + string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}}); + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); + string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true); +#endif string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + for (auto &c : compiles) { c.wait(); } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c5abc69343..ba1addc8d9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1,34 +1,44 @@ +/* + WebGPU backend implementation. + Note: Use ClangFormat to format this file. 
+*/ + #include "ggml-webgpu.h" +#include "ggml-backend-impl.h" +#include "ggml-impl.h" +#include "ggml-wgsl-shaders.hpp" + #include -#include "ggml-impl.h" -#include "ggml-backend-impl.h" - -#include "ggml-wgsl-shaders.hpp" - +#include #include #include #include +#include #include #ifdef GGML_WEBGPU_DEBUG -#define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl +# define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl +# define WEBGPU_DEBUG_BUF_ELEMS 32 #else -#define WEBGPU_LOG_DEBUG(msg) ((void) 0) -#endif // GGML_WEBGPU_DEBUG +# define WEBGPU_LOG_DEBUG(msg) ((void) 0) +#endif // GGML_WEBGPU_DEBUG /* Constants */ -#define WEBGPU_MUL_MAT_WG_SIZE 64 -#define WEBGPU_MUL_MAT_PARAMS_SIZE (13 * sizeof(uint32_t)) // M, N, K, batch sizes, broadcasts -#define WEBGPU_CPY_PARAMS_SIZE (15 * sizeof(uint32_t)) // strides and offsets -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16 +#define WEBGPU_MUL_MAT_WG_SIZE 64 +#define WEBGPU_NUM_PARAM_BUFS 100 +#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters +#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32 +#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. -static void * const webgpu_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT +static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT // Always returns the base offset of a tensor, regardless of views. static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -40,100 +50,177 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { /* Struct definitions */ +// Forward reference +static void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label); + +struct webgpu_pool_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; +}; + +// Holds a pool of parameter buffers for WebGPU operations +struct webgpu_buf_pool { + std::vector free; + + std::mutex mutex; + + std::condition_variable cv; + + void init(wgpu::Device device, + int num_bufs, + size_t buf_size, + wgpu::BufferUsage dev_buf_usage, + wgpu::BufferUsage host_buf_usage) { + for (int i = 0; i < num_bufs; i++) { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); + free.push_back({ host_buf, dev_buf }); + } + } + + webgpu_pool_bufs alloc_bufs() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !free.empty(); }); + webgpu_pool_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + void free_bufs(std::vector bufs) { + std::lock_guard lock(mutex); + free.insert(free.end(), bufs.begin(), bufs.end()); + cv.notify_all(); + } + + void cleanup() { + std::lock_guard lock(mutex); + for (auto & bufs : free) { + bufs.host_buf.Destroy(); + bufs.dev_buf.Destroy(); + } + free.clear(); + } +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { wgpu::Instance instance; - wgpu::Adapter adapter; - wgpu::Device device; - wgpu::Queue queue; - wgpu::Limits limits; - wgpu::SupportedFeatures features; + wgpu::Adapter adapter; + wgpu::Device device; + 
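// A plausible pool lifecycle, for orientation (the real init call sites live
// further down in this file; the usage flags here are illustrative, not prescriptive):
//
//   pool.init(device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
//             wgpu::BufferUsage::Storage  | wgpu::BufferUsage::CopyDst,   // dev_buf
//             wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc);  // host_buf
//   webgpu_pool_bufs bufs = pool.alloc_bufs();  // blocks on the cv until a pair is free
//   ... write params into bufs.host_buf, copy to bufs.dev_buf, dispatch ...
//   pool.free_bufs({ bufs });                   // typically from a work-done callback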
wgpu::Queue queue; + wgpu::Limits limits; - std::mutex mutex; - bool device_initialized = false; + std::recursive_mutex mutex; + + bool device_init = false; + + webgpu_buf_pool param_buf_pool; + webgpu_buf_pool set_rows_error_buf_pool; - // pipelines and parameter buffers - // TODO: reuse params buffers for different pipelines when possible wgpu::ComputePipeline memset_pipeline; - wgpu::Buffer memset_params_dev_buf; - wgpu::Buffer memset_params_host_buf; wgpu::ComputePipeline mul_mat_pipeline; - wgpu::Buffer mul_mat_params_dev_buf; - wgpu::Buffer mul_mat_params_host_buf; + wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline cpy_pipeline; - wgpu::Buffer cpy_params_dev_buf; - wgpu::Buffer cpy_params_host_buf; size_t memset_bytes_per_thread; // Staging buffer for reading data from the GPU wgpu::Buffer get_tensor_staging_buf; + + // Command buffers which need to be submitted + std::vector staged_command_bufs; + + // Parameter buffers associated with the staged command buffers + std::vector staged_param_bufs; + // Buffers associated with set_rows operations, used to store potential errors + std::vector staged_set_row_error_bufs; + + std::vector callback_futures; + +#ifdef GGML_WEBGPU_DEBUG + wgpu::Buffer debug_host_buf; + wgpu::Buffer debug_dev_buf; +#endif }; typedef std::shared_ptr webgpu_context; struct ggml_backend_webgpu_reg_context { webgpu_context webgpu_ctx; - - size_t device_count; - const char * name; + size_t device_count; + const char * name; }; struct ggml_backend_webgpu_device_context { webgpu_context webgpu_ctx; - - std::string device_name; - std::string device_desc; + std::string device_name; + std::string device_desc; }; struct ggml_backend_webgpu_context { webgpu_context webgpu_ctx; - - std::string name; + std::string name; }; struct ggml_backend_webgpu_buffer_context { webgpu_context webgpu_ctx; - - wgpu::Buffer buffer; + wgpu::Buffer buffer; ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) : - webgpu_ctx(ctx), buffer(buf) { - } + webgpu_ctx(std::move(ctx)), + buffer(std::move(buf)) {} }; /* End struct definitions */ /* WebGPU object initializations */ -static void ggml_webgpu_create_pipeline(wgpu::Device &device, wgpu::ComputePipeline &pipeline, const char * shader_code, const char * label, const std::vector &constants = {}) { +static void ggml_webgpu_create_pipeline(wgpu::Device & device, + wgpu::ComputePipeline & pipeline, + const char * shader_code, + const char * label, + const std::vector & constants = {}) { WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()"); + wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code; + wgpu::ShaderModuleDescriptor shader_desc; shader_desc.nextInChain = &shader_source; + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); wgpu::ComputePipelineDescriptor pipeline_desc; - pipeline_desc.label = label; - pipeline_desc.compute.module = shader_module; - pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code - pipeline_desc.layout = nullptr; // nullptr means auto layout + pipeline_desc.label = label; + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout if (constants.size() > 0) { - pipeline_desc.compute.constants = constants.data(); + pipeline_desc.compute.constants = constants.data(); pipeline_desc.compute.constantCount = constants.size(); } pipeline = device.CreateComputePipeline(&pipeline_desc); } -static void 
ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer, size_t size, wgpu::BufferUsage usage, const char* label) { +static void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label) { WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()"); wgpu::BufferDescriptor buffer_desc; - buffer_desc.size = size; - buffer_desc.usage = usage; - buffer_desc.label = label; + buffer_desc.size = size; + buffer_desc.usage = usage; + buffer_desc.label = label; buffer_desc.mappedAtCreation = false; + // TODO: error handling buffer = device.CreateBuffer(&buffer_desc); } @@ -142,75 +229,205 @@ static void ggml_webgpu_create_buffer(wgpu::Device &device, wgpu::Buffer &buffer /** WebGPU Actions */ -static void ggml_backend_webgpu_map_buffer(webgpu_context ctx, wgpu::Buffer buffer, wgpu::MapMode mode, size_t offset, size_t size) { - ctx->instance.WaitAny(buffer.MapAsync( - mode, offset, size, wgpu::CallbackMode::WaitAnyOnly, - [](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", message.data); - } - }), - UINT64_MAX - ); +// Wait for the queue to finish processing all submitted work +static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { + std::lock_guard lock(ctx->mutex); + if (ctx->callback_futures.empty()) { + // no existing callbacks, wait on queue submission + ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); + } + }), + UINT64_MAX); + } else { + // existing callbacks, wait on them + ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX); + ctx->callback_futures.clear(); + } } -static void ggml_backend_webgpu_buffer_memset(webgpu_context ctx, wgpu::Buffer buf, uint32_t value, size_t offset, size_t size) { - std::lock_guard lock(ctx->mutex); - wgpu::Device device = ctx->device; +static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { + std::lock_guard lock(ctx->mutex); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()"); + if (ctx->staged_command_bufs.empty()) { + // Nothing to submit + return; + } + ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data()); - // map the host parameters buffer - ggml_backend_webgpu_map_buffer(ctx, ctx->memset_params_host_buf, wgpu::MapMode::Write, 0, ctx->memset_params_host_buf.GetSize()); - uint32_t * params = (uint32_t *) ctx->memset_params_host_buf.GetMappedRange(); + // If there are SET_ROWS operations in this submission, copy their error buffers to the host. 
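// Submission is now deferred and batched, which is why wait_on_submission()
// above drains outstanding callback futures first: with staged command buffers,
// "the queue is idle" no longer implies "everything staged was ever submitted".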
+ if (ctx->staged_set_row_error_bufs.size() > 0) { + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + for (auto & error_bufs : ctx->staged_set_row_error_bufs) { + // Copy the error buffer to the host buffer + encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize()); + } + wgpu::CommandBuffer commands = encoder.Finish(); + ctx->queue.Submit(1, &commands); + } - params[0] = (uint32_t)offset; - params[1] = (uint32_t)size; - params[2] = value; - ctx->memset_params_host_buf.Unmap(); + ctx->staged_command_bufs.clear(); + std::vector staged_param_bufs = std::move(ctx->staged_param_bufs); + std::vector staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs); - wgpu::BindGroupEntry entries[2]; - entries[0].binding = 0; // binding for the buffer to memset - entries[0].buffer = buf; - entries[0].offset = 0; - entries[0].size = buf.GetSize(); - entries[1].binding = 1; // binding for the parameters - entries[1].buffer = ctx->memset_params_dev_buf; - entries[1].offset = 0; - entries[1].size = ctx->memset_params_dev_buf.GetSize(); + // Free the staged parameter buffers once the submission completes + wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); + } + // Free the staged buffers + ctx->param_buf_pool.free_bufs(staged_param_bufs); + }); + ctx->callback_futures.push_back({ p_f }); + + // Check for errrors in SET_ROWS operations + for (auto & error_bufs : staged_set_row_error_bufs) { + wgpu::Future f = error_bufs.host_buf.MapAsync( + wgpu::MapMode::Read, + 0, + error_bufs.host_buf.GetSize(), + wgpu::CallbackMode::AllowSpontaneous, + [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", message.data); + } else { + const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->set_rows_error_buf_pool.free_bufs({ error_bufs }); + } + }); + ctx->callback_futures.push_back({ f }); + } +} + +static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, + wgpu::Buffer & buffer, + wgpu::MapMode mode, + size_t offset, + size_t size) { + ctx->instance.WaitAny(buffer.MapAsync(mode, + offset, + size, + wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", + message.data); + } + }), + UINT64_MAX); +} + +#ifdef GGML_WEBGPU_DEBUG +// This function adds debugging information to shaders, as WebGPU does not support printing directly. +// To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and +// debug statements in the shader, and then call this function after encoding the commands and submitting them. 
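// Example debug flow, assuming a debug bind group entry was added to the shader
// being inspected (illustrative):
//   1. the shader writes counters into the storage buffer bound as debug_dev_buf
//   2. ggml_backend_webgpu_debug(ctx) copies it into debug_host_buf, maps it,
//      and prints all WEBGPU_DEBUG_BUF_ELEMS values to stdout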
+static void ggml_backend_webgpu_debug(webgpu_context & ctx) { + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); + wgpu::CommandBuffer commands = encoder.Finish(); + ctx->queue.Submit(1, &commands); + + ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); + const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange(); + std::cout << "debug data:"; + for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) { + std::cout << " " << i << ": " << debug_data[i]; + } + std::cout << "\n"; + ctx->debug_host_buf.Unmap(); +} +#endif + +static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx, + wgpu::ComputePipeline & pipeline, + std::vector<uint32_t> params, + std::vector<wgpu::BindGroupEntry> bind_group_entries, + uint32_t wg_x, + bool submit_and_wait = false) { + webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); + + ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); + uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); + for (size_t i = 0; i < params.size(); i++) { + _params[i] = params[i]; + } + + params_bufs.host_buf.Unmap(); + + uint32_t params_bufs_binding_num = bind_group_entries.size(); + bind_group_entries.push_back({ .binding = params_bufs_binding_num, + .buffer = params_bufs.dev_buf, + .offset = 0, + .size = params_bufs.dev_buf.GetSize() }); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = ctx->memset_pipeline.GetBindGroupLayout(0); - bind_group_desc.entryCount = 2; - bind_group_desc.label = "ggml_memset"; - bind_group_desc.entries = entries; - wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc); + bind_group_desc.layout = pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = bind_group_entries.size(); + bind_group_desc.entries = bind_group_entries.data(); + wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer( - ctx->memset_params_host_buf, 0, - ctx->memset_params_dev_buf, 0, - ctx->memset_params_dev_buf.GetSize() - ); + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(ctx->memset_pipeline); + pass.SetPipeline(pipeline); pass.SetBindGroup(0, bind_group); - size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread; - pass.DispatchWorkgroups(((size + 3) + bytes_per_wg - 1) / bytes_per_wg, 1, 1); + pass.DispatchWorkgroups(wg_x, 1, 1); pass.End(); wgpu::CommandBuffer commands = encoder.Finish(); - - ctx->queue.Submit(1, &commands); + if (submit_and_wait) { + // Submit and wait immediately + ctx->queue.Submit(1, &commands); + ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); + } + ctx->param_buf_pool.free_bufs({ params_bufs }); + }), + UINT64_MAX); + } else { + // Lock the context mutex when pushing to the staging vectors.
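+        // Commands are batched to amortize the cost of queue submissions: they accumulate in
+        // staged_command_bufs and are flushed in one submission once
+        // WEBGPU_COMMAND_SUBMIT_BATCH_SIZE is reached; ggml_backend_webgpu_graph_compute()
+        // submits whatever remains at the end of the graph.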
+ std::lock_guard lock(ctx->mutex); + // Enqueue commands and only submit if we have enough staged commands + ctx->staged_command_bufs.push_back(commands); + ctx->staged_param_bufs.push_back(params_bufs); + if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { + ggml_backend_webgpu_submit_queue(ctx); + } + } } -static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) { - // Wait for the queue to finish processing all commands - ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::WaitAnyOnly, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to wait on queue: %s\n", message.data); - } - }), - UINT64_MAX - ); +static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, + wgpu::Buffer & buf, + uint32_t value, + size_t offset, + size_t size) { + std::vector params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector entries = { + { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } + }; + size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread; + uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true); +} + +static size_t ggml_backend_webgpu_tensor_offset(const ggml_tensor * tensor) { + return webgpu_tensor_offset(tensor) + tensor->view_offs; +} + +static wgpu::Buffer ggml_backend_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; } /** End WebGPU Actions */ @@ -218,218 +435,221 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context ctx) { /** GGML Backend Interface */ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) { - ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context; + ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; return ctx->name.c_str(); } static void ggml_backend_webgpu_free(ggml_backend_t backend) { - ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *)backend->context; + ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); // TODO: cleanup GGML_UNUSED(ctx); } +static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + size_t src_offset = ggml_backend_webgpu_tensor_offset(src); + // assumes power of 2 offset alignment + size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + // align to minimum offset alignment + src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst); + size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + uint32_t ne = (uint32_t) ggml_nelements(dst); + std::vector params = { ne, + (uint32_t) (src_misalignment / ggml_type_size(src->type)), + (uint32_t) (dst_misalignment / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / 
ggml_type_size(src->type)), + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Logical shape — same for both tensors even if permuted + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3] }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_backend_webgpu_tensor_buf(src), + .offset = src_offset, + .size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & + ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }, + { .binding = 1, + .buffer = ggml_backend_webgpu_tensor_buf(dst), + .offset = dst_offset, + .size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & + ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) } + }; + + size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; + uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x); +} + +static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { + // For set rows specifically, we need to check if src and idx are empty tensors. + if (ggml_is_empty(src) || ggml_is_empty(idx)) { + return; + } + + webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); + if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + error_bufs.host_buf.Unmap(); + } + + size_t src_offset = ggml_backend_webgpu_tensor_offset(src); + // assumes power of 2 offset alignment + size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + // align to minimum offset alignment + src_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + size_t idx_offset = ggml_backend_webgpu_tensor_offset(idx); + size_t idx_misalignment = idx_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + idx_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst); + size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + + std::vector params = { (uint32_t) (src_misalignment / ggml_type_size(src->type)), + (uint32_t) (idx_misalignment / ggml_type_size(idx->type)), + (uint32_t) (dst_misalignment / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Shape of src + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + // Shape of idx + (uint32_t) (idx->ne[1]), + (uint32_t) (idx->ne[2]) }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_backend_webgpu_tensor_buf(src), + .offset = ggml_backend_webgpu_tensor_offset(src), + .size = ggml_nbytes(src) }, + { .binding = 1, + .buffer = ggml_backend_webgpu_tensor_buf(idx), + .offset = ggml_backend_webgpu_tensor_offset(idx), + .size = 
ggml_nbytes(idx) }, + { .binding = 2, + .buffer = ggml_backend_webgpu_tensor_buf(dst), + .offset = ggml_backend_webgpu_tensor_offset(dst), + .size = ggml_nbytes(dst) }, + { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } + }; + + size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; + uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; + + std::lock_guard lock(ctx->mutex); + ctx->staged_set_row_error_bufs.push_back(error_bufs); + + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x); +} + +static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + std::vector params = { + (uint32_t) dst->ne[1], // number of rows in result (M) + (uint32_t) dst->ne[0], // number of columns in result (N) + (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 1 + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 1 + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 2 + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 2 + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 3 + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 3 + (uint32_t) src0->ne[2], // batch size in dimension 2 + (uint32_t) src0->ne[3], // batch size in dimension 3 + (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2 + (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3 + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_backend_webgpu_tensor_buf(src0), + .offset = ggml_backend_webgpu_tensor_offset(src0), + .size = ggml_nbytes(src0) }, + { .binding = 1, + .buffer = ggml_backend_webgpu_tensor_buf(src1), + .offset = ggml_backend_webgpu_tensor_offset(src1), + .size = ggml_nbytes(src1) }, + { .binding = 2, + .buffer = ggml_backend_webgpu_tensor_buf(dst), + .offset = ggml_backend_webgpu_tensor_offset(dst), + .size = ggml_nbytes(dst) } + }; + + uint32_t wg_x = + (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x); +} + // Returns true if node has enqueued work into the queue, false otherwise -static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node){ +static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { return false; } - WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); + ggml_tensor * src0 = node->src[0]; + ggml_tensor * src1 = node->src[1]; switch (node->op) { - // no-ops + // no-ops case GGML_OP_NONE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: return false; - - case GGML_OP_CPY: { - std::lock_guard lock(ctx->mutex); - const ggml_tensor * src = node->src[0]; - ggml_backend_webgpu_buffer_context * src_ctx = (ggml_backend_webgpu_buffer_context *) src->buffer->context; - size_t src_offset = webgpu_tensor_offset(src) + src->view_offs; - // assumes power of 2 offset alignment - size_t src_misalignment = src_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - // align to minimum offset alignment - src_offset &= 
~(ctx->limits.minStorageBufferOffsetAlignment - 1); - ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context; - size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs; - size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); - dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1); - - wgpu::Device device = ctx->device; - ggml_backend_webgpu_map_buffer(ctx, ctx->cpy_params_host_buf, - wgpu::MapMode::Write, 0, ctx->cpy_params_host_buf.GetSize()); - uint32_t * params = (uint32_t *) ctx->cpy_params_host_buf.GetMappedRange(); - uint32_t ne = (uint32_t)ggml_nelements(node); - params[0] = ne; - params[1] = src_misalignment/ggml_type_size(src->type); - params[2] = dst_misalignment/ggml_type_size(node->type); - - // Convert byte-strides to element-strides - params[3] = (uint32_t)src->nb[0]/ggml_type_size(src->type); - params[4] = (uint32_t)src->nb[1]/ggml_type_size(src->type); - params[5] = (uint32_t)src->nb[2]/ggml_type_size(src->type); - params[6] = (uint32_t)src->nb[3]/ggml_type_size(src->type); - params[7] = (uint32_t)node->nb[0]/ggml_type_size(node->type); - params[8] = (uint32_t)node->nb[1]/ggml_type_size(node->type); - params[9] = (uint32_t)node->nb[2]/ggml_type_size(node->type); - params[10] = (uint32_t)node->nb[3]/ggml_type_size(node->type); - // Logical shape — same for both tensors even if permuted - params[11] = (uint32_t)(src->ne[0]); - params[12] = (uint32_t)(src->ne[1]); - params[13] = (uint32_t)(src->ne[2]); - params[14] = (uint32_t)(src->ne[3]); - - ctx->cpy_params_host_buf.Unmap(); - - wgpu::BindGroupEntry entries[3]; - entries[0].binding = 0; - entries[0].buffer = src_ctx->buffer; - entries[0].offset = src_offset; - entries[0].size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); - - entries[1].binding = 1; - entries[1].buffer = dst_ctx->buffer; - entries[1].offset = dst_offset; - entries[1].size = (ggml_nbytes(node) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); - - entries[2].binding = 2; - entries[2].buffer = ctx->cpy_params_dev_buf; - entries[2].offset = 0; - entries[2].size = ctx->cpy_params_dev_buf.GetSize(); - - wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = ctx->cpy_pipeline.GetBindGroupLayout(0); - bind_group_desc.label = "ggml_op_cpy"; - bind_group_desc.entryCount = 3; - bind_group_desc.entries = entries; - wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc); - - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer( - ctx->cpy_params_host_buf, 0, - ctx->cpy_params_dev_buf, 0, - ctx->cpy_params_dev_buf.GetSize() - ); - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(ctx->cpy_pipeline); - pass.SetBindGroup(0, bind_group); - size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; - pass.DispatchWorkgroups((ne + max_wg_size - 1) / max_wg_size); - pass.End(); - wgpu::CommandBuffer commands = encoder.Finish(); - - // TODO, don't submit here, batch submissions - ctx->queue.Submit(1, &commands); - // TODO, don't wait on submission here - ggml_backend_webgpu_wait_on_submission(ctx); - return true; - } - + case GGML_OP_CPY: + { + ggml_webgpu_cpy(ctx, src0, node); + break; + } + case GGML_OP_SET_ROWS: + { + ggml_webgpu_set_rows(ctx, src0, src1, node); + break; + } case GGML_OP_MUL_MAT: - { - const ggml_tensor * src0 = node->src[0]; - 
ggml_backend_webgpu_buffer_context * src0_ctx = (ggml_backend_webgpu_buffer_context *) src0->buffer->context; - size_t src0_offset = webgpu_tensor_offset(src0) + src0->view_offs; - const ggml_tensor * src1 = node->src[1]; - ggml_backend_webgpu_buffer_context * src1_ctx = (ggml_backend_webgpu_buffer_context *) src1->buffer->context; - size_t src1_offset = webgpu_tensor_offset(src1) + src1->view_offs; - ggml_backend_webgpu_buffer_context * dst_ctx = (ggml_backend_webgpu_buffer_context *) node->buffer->context; - - size_t dst_offset = webgpu_tensor_offset(node) + node->view_offs; - - wgpu::Device device = ctx->device; - - // map the host parameters buffer - ggml_backend_webgpu_map_buffer(ctx, ctx->mul_mat_params_host_buf, - wgpu::MapMode::Write, 0, ctx->mul_mat_params_host_buf.GetSize()); - uint32_t * params = (uint32_t *) ctx->mul_mat_params_host_buf.GetMappedRange(); - - params[0] = (uint32_t)node->ne[1]; // number of rows in result (M) - params[1] = (uint32_t)node->ne[0]; // number of columns in result (N) - params[2] = (uint32_t)src0->ne[0]; // number of columns in src0/src1 (K) - - params[3] = (uint32_t)src0->nb[1]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 1 - params[4] = (uint32_t)src1->nb[1]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 1 - params[5] = (uint32_t)src0->nb[2]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 2 - params[6] = (uint32_t)src1->nb[2]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 2 - params[7] = (uint32_t)src0->nb[3]/ggml_type_size(src0->type); // stride (elements) of src0 in dimension 3 - params[8] = (uint32_t)src1->nb[3]/ggml_type_size(src1->type); // stride (elements) of src1 in dimension 3 - - params[9] = (uint32_t)src0->ne[2]; // batch size in dimension 2 - params[10] = (uint32_t)src0->ne[3]; // batch size in dimension 3 - params[11] = (uint32_t)(src1->ne[2]/src0->ne[2]); // broadcast in dimension 2 - params[12] = (uint32_t)(src1->ne[3]/src0->ne[3]); // broadcast in dimension 3 - - ctx->mul_mat_params_host_buf.Unmap(); - - wgpu::BindGroupEntry entries[4]; - entries[0].binding = 0; - entries[0].buffer = src0_ctx->buffer; - entries[0].offset = src0_offset; - entries[0].size = ggml_nbytes(src0); - - entries[1].binding = 1; - entries[1].buffer = src1_ctx->buffer; - entries[1].offset = src1_offset; - entries[1].size = ggml_nbytes(src1); - - entries[2].binding = 2; - entries[2].buffer = dst_ctx->buffer; - entries[2].offset = dst_offset; - entries[2].size = ggml_nbytes(node); - - entries[3].binding = 3; - entries[3].buffer = ctx->mul_mat_params_dev_buf; - entries[3].offset = 0; - entries[3].size = ctx->mul_mat_params_dev_buf.GetSize(); - - wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = ctx->mul_mat_pipeline.GetBindGroupLayout(0); - bind_group_desc.entryCount = 4; - bind_group_desc.label = "ggml_op_mul_mat"; - bind_group_desc.entries = entries; - wgpu::BindGroup bind_group = device.CreateBindGroup(&bind_group_desc); - - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer( - ctx->mul_mat_params_host_buf, 0, - ctx->mul_mat_params_dev_buf, 0, - ctx->mul_mat_params_dev_buf.GetSize() - ); - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(ctx->mul_mat_pipeline); - pass.SetBindGroup(0, bind_group); - pass.DispatchWorkgroups((node->ne[0] * node->ne[1] * node->ne[2] * node->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE); - pass.End(); - wgpu::CommandBuffer commands = 
encoder.Finish(); - - // TODO, don't submit here, batch submissions - ctx->queue.Submit(1, &commands); - // TODO, don't wait on submission here - ggml_backend_webgpu_wait_on_submission(ctx); - return true; - } - + { + ggml_webgpu_mul_mat(ctx, src0, src1, node); + break; + } default: return false; } + return true; } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); - webgpu_context ctx = backend_ctx->webgpu_ctx; + webgpu_context ctx = backend_ctx->webgpu_ctx; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_webgpu_encode_node(ctx, cgraph->nodes[i]); } + ggml_backend_webgpu_submit_queue(ctx); + ggml_backend_webgpu_wait_on_submission(ctx); + return GGML_STATUS_SUCCESS; } @@ -465,49 +685,72 @@ static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) return webgpu_ptr_base; } -static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { +static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { if (size == 0) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do."); return; } - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " + << offset << ", " << size << ")"); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + // This is a trick to set all bytes of a u32 to the same 1 byte value. 
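+    // For example, value = 0xAB gives 0xAB * 0x01010101 = 0xABABABAB, i.e. all four bytes set to 0xAB.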
- uint32_t val32 = (uint32_t)value * 0x01010101; + uint32_t val32 = (uint32_t) value * 0x01010101; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); } -static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; +static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " + << offset << ", " << size << ")"); + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; - webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size/4)*4); + webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); if (size % 4 != 0) { // If size is not a multiple of 4, we need to memset the remaining bytes size_t remaining_size = size % 4; + // pack the remaining bytes into a uint32_t uint32_t val32 = 0; + for (size_t i = 0; i < remaining_size; i++) { - ((uint8_t *)&val32)[i] = ((const uint8_t *)data)[size - remaining_size + i]; + ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; } // memset the remaining bytes - ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); + ggml_backend_webgpu_buffer_memset( + webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); + } else { + // wait for WriteBuffer to complete + ggml_backend_webgpu_wait_on_submission(webgpu_ctx); } } -static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); +static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " + << offset << ", " << size << ")"); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; - wgpu::Device device = webgpu_ctx->device; + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; + webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + wgpu::Device device = webgpu_ctx->device; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; @@ -517,22 +760,25 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, final_size = size + (4 - (size % 4)); } - std::lock_guard lock(webgpu_ctx->mutex); + std::lock_guard lock(webgpu_ctx->mutex); - if (webgpu_ctx->get_tensor_staging_buf == nullptr || - 
webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) { + if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) { // Create a new staging buffer if it doesn't exist or is too small if (webgpu_ctx->get_tensor_staging_buf) { webgpu_ctx->get_tensor_staging_buf.Destroy(); } - ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); + ggml_webgpu_create_buffer(device, + webgpu_ctx->get_tensor_staging_buf, + final_size, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, + "get_tensor_staging_buf"); } // Copy the data from the buffer to the staging buffer wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size); wgpu::CommandBuffer commands = encoder.Finish(); + // Submit the command buffer to the queue webgpu_ctx->queue.Submit(1, &commands); @@ -548,7 +794,6 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); } @@ -556,13 +801,13 @@ static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8 static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer, /* .get_base = */ ggml_backend_webgpu_buffer_get_base, - /* .init_tensor = */ NULL, // TODO: optional, needed? + /* .init_tensor = */ NULL, // TODO: optional, needed? 
/* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, - /* .cpy_tensor = */ NULL, // TODO: optional, implement this + /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, - /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor + /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor }; /* End GGML Backend Buffer Interface */ @@ -574,13 +819,17 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer return ctx->device_name.c_str(); } -static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")"); ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); wgpu::Buffer buf; - ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, size, - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, "allocated_buffer"); + ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, + buf, + size, + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, + "allocated_buffer"); ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf); @@ -615,8 +864,8 @@ static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory. 
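    // Note: maxBufferSize only bounds a single allocation; it says nothing about how much memory
    // is actually free, and WebGPU currently exposes no query for total or free device memory,
    // so this is at best a rough stand-in.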
- *free = ctx->webgpu_ctx->limits.maxBufferSize; - *total = ctx->webgpu_ctx->limits.maxBufferSize; + *free = ctx->webgpu_ctx->limits.maxBufferSize; + *total = ctx->webgpu_ctx->limits.maxBufferSize; } static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) { @@ -639,98 +888,126 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct static ggml_guid_t ggml_backend_webgpu_guid(void) { static const char * guid_str = "__ggml_webgpu :)"; - return reinterpret_cast((void *)guid_str); + return reinterpret_cast((void *) guid_str); } -static void ggml_webgpu_init_memset_pipeline(webgpu_context webgpu_ctx) { +static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { // we use the maximum workgroup size for the memset pipeline size_t max_wg_size = webgpu_ctx->limits.maxComputeWorkgroupSizeX; size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled - webgpu_ctx->memset_bytes_per_thread = (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads; + webgpu_ctx->memset_bytes_per_thread = + (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads; std::vector constants(2); - constants[0].key = "wg_size"; + constants[0].key = "wg_size"; constants[0].value = max_wg_size; - constants[1].key = "bytes_per_thread"; + constants[1].key = "bytes_per_thread"; constants[1].value = webgpu_ctx->memset_bytes_per_thread; ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_dev_buf, - 3 * sizeof(uint32_t), // 3 parameters: buffer size, offset, value - wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "memset_params_dev_buf"); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->memset_params_host_buf, - 3 * sizeof(uint32_t), wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "memset_params_host_buf"); } -static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context webgpu_ctx) { +static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat"); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_dev_buf, WEBGPU_MUL_MAT_PARAMS_SIZE, - wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "mul_mat_params_dev_buf"); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->mul_mat_params_host_buf, WEBGPU_MUL_MAT_PARAMS_SIZE, - wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "mul_mat_params_host_buf"); } -static void ggml_webgpu_init_cpy_pipeline(webgpu_context webgpu_ctx) { +static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { std::vector constants(1); - constants[0].key = "wg_size"; + constants[0].key = "wg_size"; constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; - - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_dev_buf, WEBGPU_CPY_PARAMS_SIZE, - wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst, "cpy_params_dev_buf"); - ggml_webgpu_create_buffer(webgpu_ctx->device, webgpu_ctx->cpy_params_host_buf, WEBGPU_CPY_PARAMS_SIZE, - wgpu::BufferUsage::MapWrite | wgpu::BufferUsage::CopySrc, "cpy_params_host_buf"); + 
ggml_webgpu_create_pipeline( + webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows", constants); +} + +static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX; + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants); } -// TODO: Make thread safe if multiple devices are used static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); - ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); - webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; + ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); + webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; - std::lock_guard lock(webgpu_ctx->mutex); - - if (!webgpu_ctx->device_initialized) { + // Multiple threads may try to initialize the device + std::lock_guard lock(webgpu_ctx->mutex); + if (!webgpu_ctx->device_init) { // Initialize device - wgpu::DeviceDescriptor dev_desc; - dev_desc.requiredLimits = &webgpu_ctx->limits; - dev_desc.requiredFeatures = webgpu_ctx->features.features; - dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount; - dev_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device& device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + std::vector required_features = { wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::ImplicitDeviceSynchronization }; + wgpu::DeviceDescriptor dev_desc; + dev_desc.requiredLimits = &webgpu_ctx->limits; + dev_desc.requiredFeatures = required_features.data(); + dev_desc.requiredFeatureCount = required_features.size(); + dev_desc.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { GGML_UNUSED(device); - GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), message.data); - }); + GGML_LOG_ERROR( + "ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), message.data); + }); dev_desc.SetUncapturedErrorCallback( - [](const wgpu::Device& device, wgpu::ErrorType reason, wgpu::StringView message) { + [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { GGML_UNUSED(device); - GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), message.data); - }); - webgpu_ctx->instance.WaitAny(webgpu_ctx->adapter.RequestDevice(&dev_desc, wgpu::CallbackMode::WaitAnyOnly, - [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { - if (status != wgpu::RequestDeviceStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data); - return; - } - webgpu_ctx->device = device; - }), - UINT64_MAX - ); + GGML_LOG_ERROR( + "ggml_webgpu: Device error! 
Reason: %d, Message: %s\n", static_cast(reason), message.data); + }); + webgpu_ctx->instance.WaitAny( + webgpu_ctx->adapter.RequestDevice( + &dev_desc, + wgpu::CallbackMode::AllowSpontaneous, + [webgpu_ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { + if (status != wgpu::RequestDeviceStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", message.data); + return; + } + webgpu_ctx->device = std::move(device); + }), + UINT64_MAX); GGML_ASSERT(webgpu_ctx->device != nullptr); // Initialize (compute) queue webgpu_ctx->queue = webgpu_ctx->device.GetQueue(); + // Create buffer pool for shader parameters + webgpu_ctx->param_buf_pool.init(webgpu_ctx->device, + WEBGPU_NUM_PARAM_BUFS, + WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->device, + WEBGPU_NUM_SET_ROWS_ERROR_BUFS, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + ggml_webgpu_init_memset_pipeline(webgpu_ctx); ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); + ggml_webgpu_init_set_rows_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - webgpu_ctx->device_initialized = true; + +#ifdef GGML_WEBGPU_DEBUG + // Initialize debug buffers + ggml_webgpu_create_buffer(webgpu_ctx->device, + webgpu_ctx->debug_host_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, + "debug_host_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->device, + webgpu_ctx->debug_dev_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, + "debug_dev_buf"); +#endif + webgpu_ctx->device_init = true; } static ggml_backend_webgpu_context backend_ctx; - backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; + backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; backend_ctx.webgpu_ctx = webgpu_ctx; // See GGML Backend Interface section @@ -748,14 +1025,15 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm // See GGML Backend Buffer Type Interface section static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { - /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ NULL, // defaults to false + /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ NULL, // defaults to false }, - /* .device = */ dev, + /* .device = */ + dev, /* .context = */ NULL, }; @@ -764,7 +1042,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { GGML_UNUSED(dev); - return buft->iface.get_name == 
ggml_backend_webgpu_buffer_type_get_name; + return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name; } static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { @@ -775,7 +1053,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_VIEW: case GGML_OP_PERMUTE: return true; - case GGML_OP_CPY: + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_MUL_MAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; @@ -827,30 +1105,38 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; wgpu::RequestAdapterOptions options = {}; - auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char *message, void *userdata) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - *static_cast<wgpu::Adapter *>(userdata) = adapter; - }; - void *userdata = &ctx->adapter; - ctx->instance.WaitAny(ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::WaitAnyOnly, callback, userdata), UINT64_MAX); + auto callback = + [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message, void * userdata) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + *static_cast<wgpu::Adapter *>(userdata) = std::move(adapter); + }; + void * userdata = &ctx->adapter; + ctx->instance.WaitAny( + ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX); GGML_ASSERT(ctx->adapter != nullptr); ctx->adapter.GetLimits(&ctx->limits); - ctx->adapter.GetFeatures(&ctx->features); wgpu::AdapterInfo info{}; ctx->adapter.GetInfo(&info); static ggml_backend_webgpu_device_context device_ctx; - device_ctx.webgpu_ctx = ctx; + device_ctx.webgpu_ctx = ctx; device_ctx.device_name = GGML_WEBGPU_NAME; device_ctx.device_desc = std::string(info.description.data); - GGML_LOG_INFO("ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | device_desc: %s\n", - info.vendorID, info.vendor.data, info.architecture.data, info.deviceID, info.device.data, info.description.data); + GGML_LOG_INFO( + "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " + "device_desc: %s\n", + info.vendorID, + info.vendor.data, + info.architecture.data, + info.deviceID, + info.device.data, + info.description.data); // See GGML Backend Device Interface section static ggml_backend_device device = { @@ -861,7 +1147,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t return &device; } - static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { /* .get_name = */ ggml_backend_webgpu_reg_get_name, /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count, @@ -871,23 +1156,21 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { /* End GGML Backend Registration Interface */ -// TODO: Does this need to be thread safe? Is it only called once?
ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>(); - webgpu_ctx->device_initialized = false; static ggml_backend_webgpu_reg_context ctx; - ctx.webgpu_ctx = webgpu_ctx; - ctx.name = GGML_WEBGPU_NAME; + ctx.webgpu_ctx = webgpu_ctx; + ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; - wgpu::InstanceDescriptor instance_descriptor{}; - std::vector<wgpu::InstanceFeatureName> instance_features = {wgpu::InstanceFeatureName::TimedWaitAny}; - instance_descriptor.requiredFeatures = instance_features.data(); - instance_descriptor.requiredFeatureCount = instance_features.size(); - webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + wgpu::InstanceDescriptor instance_descriptor{}; + std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; + instance_descriptor.requiredFeatures = instance_features.data(); + instance_descriptor.requiredFeatureCount = instance_features.size(); + webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); GGML_ASSERT(webgpu_ctx->instance != nullptr); static ggml_backend_reg reg = { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl new file mode 100644 index 0000000000..4bd6f94a23 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -0,0 +1,82 @@ +enable f16; + +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> idx: array<u32>; + +@group(0) @binding(2) +var<storage, read_write> dst: array<f16>; + +@group(0) @binding(3) +var<storage, read_write> error: atomic<u32>; + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(4) +var<uniform> params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x >= params.n_rows * params.ne2 * params.ne3) { + return; + } + var i = gid.x; + let i_src3 = i / (params.ne2 * params.n_rows); + let i_dst3 = i / (params.ne2 * 3); + + i = i % (params.ne2 * params.n_rows); + let i_src2 = i / params.n_rows; + let i_src1 = i % params.n_rows; + + let i_idx2 = i_src3 % params.idx2; + let i_idx1 = i_src2 % params.idx1; + let i_idx0 = i_src1; + + let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2; + + let idx_high_val = idx[idx_high]; + let idx_low_val = idx[idx_high + 1]; + + if (idx_low_val != 0) { + // Upper bits of index are not zero, output will be incorrect + atomicStore(&error, 1); + return; + } + + let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; + + for (var i: u32 = 0; i < params.ne0; i++) { + dst[i_dst_row + i] = f16(src[i_src_row + i]); + } +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 124cf3e8b6..55a76f8248 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -582,9 +582,6 @@ FILE * ggml_fopen(const char * fname, const char * mode) { #endif } -static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const
float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); -static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); -static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc); static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_I8] = { @@ -690,6 +687,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, }, + [GGML_TYPE_MXFP4] = { + .type_name = "mxfp4", + .blck_size = QK_MXFP4, + .type_size = sizeof(block_mxfp4), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp4, + .from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref, + }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -917,6 +922,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "DUP", "ADD", + "ADD_ID", "ADD1", "ACC", "SUB", @@ -1010,13 +1016,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); +static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", "x", "x+y", + "x[i]+y", "x+y", "view(x,nb,offset)+=y->x", "x-y", @@ -1110,7 +1117,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); +static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1140,11 +1147,12 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", "GEGLU", "SWIGLU", + "SWIGLU_OAI", "GEGLU_ERF", "GEGLU_QUICK", }; -static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5"); +static_assert(GGML_GLU_OP_COUNT == 6, "GGML_GLU_OP_COUNT != 6"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -1312,6 +1320,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; @@ -1962,6 +1971,27 @@ struct ggml_tensor * ggml_add_cast( return ggml_add_cast_impl(ctx, a, b, type); } +struct ggml_tensor * ggml_add_id( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * ids) { + + GGML_ASSERT(a->ne[0] == b->ne[0]); + GGML_ASSERT(a->ne[1] == ids->ne[0]); + GGML_ASSERT(a->ne[2] == ids->ne[1]); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_ADD_ID; + result->src[0] = a; + result->src[1] = b; + result->src[2] = ids; + + return result; +} + // ggml_add1 static struct ggml_tensor * ggml_add1_impl( @@ -2812,6 +2842,19 @@ struct ggml_tensor * ggml_geglu_quick_split( return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false); } +struct ggml_tensor * ggml_swiglu_oai( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct 
ggml_tensor * b, + float alpha, + float limit) { + struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false); + ggml_set_op_params_f32(result, 2, alpha); + ggml_set_op_params_f32(result, 3, limit); + + return result; +} + // ggml_norm static struct ggml_tensor * ggml_norm_impl( @@ -3779,6 +3822,22 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } +void ggml_soft_max_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks) { + if (!sinks) { + a->src[2] = NULL; + return; + } + + GGML_ASSERT(a->op == GGML_OP_SOFT_MAX); + GGML_ASSERT(a->src[2] == NULL); + GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]); + GGML_ASSERT(sinks->type == GGML_TYPE_F32); + + a->src[2] = sinks; +} + // ggml_soft_max_ext_back static struct ggml_tensor * ggml_soft_max_ext_back_impl( @@ -4812,6 +4871,22 @@ enum ggml_prec ggml_flash_attn_ext_get_prec( return (enum ggml_prec) prec_i32; } +void ggml_flash_attn_ext_add_sinks( + struct ggml_tensor * a, + struct ggml_tensor * sinks) { + if (!sinks) { + a->src[4] = NULL; + return; + } + + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + GGML_ASSERT(a->src[4] == NULL); + GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]); + GGML_ASSERT(sinks->type == GGML_TYPE_F32); + + a->src[4] = sinks; +} + // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back( @@ -6872,6 +6947,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 5707085cb6..911eea504a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -105,6 +105,7 @@ class Keys: EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" EXPERT_GATING_FUNC = "{arch}.expert_gating_func" MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers" + NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" @@ -357,6 +358,7 @@ class MODEL_ARCH(IntEnum): DEEPSEEK2 = auto() CHATGLM = auto() GLM4 = auto() + GLM4_MOE = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -378,6 +380,7 @@ class MODEL_ARCH(IntEnum): HUNYUAN_MOE = auto() HUNYUAN_DENSE = auto() SMOLLM3 = auto() + GPT_OSS = auto() LFM2 = auto() DREAM = auto() SMALLTHINKER = auto() @@ -414,6 +417,7 @@ class MODEL_TENSOR(IntEnum): ATTN_OUT_NORM = auto() ATTN_POST_NORM = auto() ATTN_ROT_EMBD = auto() + ATTN_SINKS = auto() FFN_GATE_INP = auto() FFN_GATE_INP_SHEXP = auto() FFN_NORM = auto() @@ -614,6 +618,13 @@ class MODEL_TENSOR(IntEnum): A_MMPROJ_FC = auto() A_MM_NORM_PRE = auto() A_MM_NORM_MID = auto() + # nextn/mtp + 
NEXTN_EH_PROJ = auto() + NEXTN_EMBED_TOKENS = auto() + NEXTN_ENORM = auto() + NEXTN_HNORM = auto() + NEXTN_SHARED_HEAD_HEAD = auto() + NEXTN_SHARED_HEAD_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -678,6 +689,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.DEEPSEEK2: "deepseek2", MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.GLM4_MOE: "glm4moe", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -700,6 +712,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense", MODEL_ARCH.SMOLLM3: "smollm3", + MODEL_ARCH.GPT_OSS: "gpt-oss", MODEL_ARCH.LFM2: "lfm2", MODEL_ARCH.DREAM: "dream", MODEL_ARCH.SMALLTHINKER: "smallthinker", @@ -734,6 +747,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks", MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", @@ -936,6 +950,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", + # NextN/MTP + MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj", + MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens", + MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm", + MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm", + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head", + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -2124,6 +2145,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.FFN_POST_NORM, ], + MODEL_ARCH.GLM4_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, @@ -2505,6 +2557,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.GPT_OSS: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_SINKS, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.LFM2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD_NORM, @@ -2659,6 +2727,7 @@ class 
GGMLQuantizationType(IntEnum): BF16 = 30 TQ1_0 = 34 TQ2_0 = 35 + MXFP4 = 39 class ExpertGatingFuncType(IntEnum): @@ -2799,6 +2868,7 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.BF16: (1, 2), GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), GGMLQuantizationType.TQ2_0: (256, 2 + 64), + GGMLQuantizationType.MXFP4: (32, 1 + 16), } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index f4fd64ad82..a6cc8a931e 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -138,8 +138,9 @@ class GGUFWriter: size = prod(shape) if "_exps." in name: - expert_params += (size // shape[-3]) - expert_sum += shape[-3] + expert_count = shape[-2 if ".bias" in name else -3] + expert_params += (size // expert_count) + expert_sum += expert_count n_expert_tensors += 1 else: shared_params += size @@ -753,6 +754,9 @@ class GGUFWriter: def add_moe_every_n_layers(self, value: int) -> None: self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value) + def add_nextn_predict_layers(self, count: int) -> None: + self.add_uint32(Keys.LLM.NEXTN_PREDICT_LAYERS.format(arch=self.arch), count) + def add_swin_norm(self, value: bool) -> None: self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py index 63f2300348..2fa5800cf7 100755 --- a/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -111,6 +111,7 @@ def main() -> None: parser.add_argument("--general-description", type=str, help="The models general.description", metavar='"Description ..."') parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... 
%} ..."') parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json') + parser.add_argument("--chat-template-file", type=Path, help="Jinja file containing chat template", metavar='chat_template.jinja') parser.add_argument("--pre-tokenizer", type=str, help="The models tokenizer.ggml.pre", metavar='"pre tokenizer"') parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url') parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '""')) @@ -134,12 +135,17 @@ def main() -> None: new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template) if args.chat_template_config: - with open(args.chat_template_config, 'r') as fp: + with open(args.chat_template_config, 'r', encoding='utf-8') as fp: config = json.load(fp) template = config.get('chat_template') if template: new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template) + if args.chat_template_file: + with open(args.chat_template_file, 'r', encoding='utf-8') as fp: + template = fp.read() + new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template) + if args.pre_tokenizer: new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails(gguf.GGUFValueType.STRING, args.pre_tokenizer) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index df490fc80e..dc7c03b464 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -33,6 +33,7 @@ class TensorNameMap: "language_model.model.embed_tokens", # llama4 "encoder", # neobert "model.transformer.wte", # llada + "embed_tokens", # qwen3-embedding ), # Token type embeddings @@ -143,6 +144,7 @@ class TensorNameMap: "transformer_encoder.{bid}.attention_norm", # neobert "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada + "layers.{bid}.input_layernorm", # qwen3-embedding ), # Attention norm 2 @@ -188,6 +190,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.attention.q_proj", # exaone "model.layers.{bid}.self_attn.q_proj", # llama4 "model.transformer.blocks.{bid}.q_proj", # llada + "layers.{bid}.self_attn.q_proj", # qwen3-embedding ), # Attention key @@ -205,6 +208,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.attention.k_proj", # exaone "model.layers.{bid}.self_attn.k_proj", # llama4 "model.transformer.blocks.{bid}.k_proj", # llada + "layers.{bid}.self_attn.k_proj", # qwen3-embedding ), # Attention value @@ -221,6 +225,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.attention.v_proj", # exaone "model.layers.{bid}.self_attn.v_proj", # llama4 "model.transformer.blocks.{bid}.v_proj", # llada + "layers.{bid}.self_attn.v_proj", # qwen3-embedding ), # Attention output @@ -254,6 +259,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.o_proj", # llama4 "transformer_encoder.{bid}.wo", # neobert "model.transformer.blocks.{bid}.attn_out", # llada + "layers.{bid}.self_attn.o_proj", # qwen3-embedding ), # Attention output norm @@ -279,6 +285,10 @@ class TensorNameMap: "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell ), + MODEL_TENSOR.ATTN_SINKS: ( + "model.layers.{bid}.self_attn.sinks", # openai-moe + ), + # Feed-forward norm MODEL_TENSOR.FFN_NORM: 
( "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox @@ -300,6 +310,7 @@ class TensorNameMap: "transformer_encoder.{bid}.ffn_norm", # neobert "model.layers.layers.{bid}.pre_mlp_norm", # plamo2 "model.transformer.blocks.{bid}.ff_norm", # llada + "layers.{bid}.post_attention_layernorm", # qwen3-embedding ), # Post feed-forward norm @@ -325,6 +336,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe "model.layers.{bid}.feed_forward.router", # llama4 jamba "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe + "model.layers.{bid}.mlp.router", # openai-moe "model.layers.{bid}.mlp.gate.wg", # hunyuan "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker ), @@ -373,7 +385,8 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.up_proj", # llama4 jamba granite-hybrid "transformer_encoder.{bid}.ffn.w12", # neobert "model.layers.{bid}.block_sparse_moe.up", # smallthinker - "model.transformer.blocks.{bid}.up_proj", # llada + "model.transformer.blocks.{bid}.up_proj", # llada + "layers.{bid}.mlp.up_proj", # qwen3-embedding ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -416,6 +429,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.gate_proj", # llama4 jamba granite-hybrid "model.layers.{bid}.block_sparse_moe.gate", # smallthinker "model.transformer.blocks.{bid}.ff_proj", # llada + "layers.{bid}.mlp.gate_proj", # qwen3-embedding ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -465,7 +479,8 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.down_proj", # llama4 jamba granite-hybrid "transformer_encoder.{bid}.ffn.w3", # neobert "model.layers.{bid}.block_sparse_moe.down", # smallthinker - "model.transformer.blocks.{bid}.ff_out", # llada + "model.transformer.blocks.{bid}.ff_out", # llada + "layers.{bid}.mlp.down_proj", # qwen3-embedding ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -497,6 +512,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm "model.layers.layers.{bid}.mixer.q", # plamo2 + "layers.{bid}.self_attn.q_norm", # qwen3-embedding ), MODEL_TENSOR.ATTN_K_NORM: ( @@ -508,6 +524,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm "model.layers.layers.{bid}.mixer.k", # plamo2 + "layers.{bid}.self_attn.k_norm", # qwen3-embedding ), MODEL_TENSOR.ROPE_FREQS: ( @@ -1093,11 +1110,13 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_EMBD_CLS: ( "vision_tower.vision_model.embeddings.class_embedding", + "model.vision_tower.embeddings.cls_token", # Intern-S1 "vision_model.class_embedding", # llama 4 ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( "vision_tower.vision_model.embeddings.patch_embedding", + "model.vision_tower.embeddings.patch_embeddings.projection", # Intern-S1 "vpm.embeddings.patch_embedding", "model.vision_model.embeddings.patch_embedding", # SmolVLM "vision_tower.patch_conv", # pixtral @@ -1107,6 +1126,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_EMBD_POS: ( "vision_tower.vision_model.embeddings.position_embedding", + "model.vision_tower.embeddings.position_embeddings", # Intern-S1 "vpm.embeddings.position_embedding", "model.vision_model.embeddings.position_embedding", # SmolVLM "vision_model.positional_embedding_vlm", # llama 4 @@ -1114,6 +1134,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_Q: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", + "model.vision_tower.encoder.layer.{bid}.attention.q_proj", # Intern-S1 "vpm.encoder.layers.{bid}.self_attn.q_proj", 
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4 @@ -1123,10 +1144,12 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL + "model.vision_tower.encoder.layer.{bid}.attention.q_norm", # Intern-S1 ), MODEL_TENSOR.V_ENC_ATTN_K: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", + "model.vision_tower.encoder.layer.{bid}.attention.k_proj", # Intern-S1 "vpm.encoder.layers.{bid}.self_attn.k_proj", "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4 @@ -1136,10 +1159,12 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL + "model.vision_tower.encoder.layer.{bid}.attention.k_norm", # Intern-S1 ), MODEL_TENSOR.V_ENC_ATTN_V: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", + "model.vision_tower.encoder.layer.{bid}.attention.v_proj", # Intern-S1 "vpm.encoder.layers.{bid}.self_attn.v_proj", "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4 @@ -1150,6 +1175,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_INPUT_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL + "model.vision_tower.encoder.layer.{bid}.layernorm_before", # Intern-S1 "vpm.encoder.layers.{bid}.layer_norm1", "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral @@ -1160,6 +1186,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_O: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL + "model.vision_tower.encoder.layer.{bid}.attention.projection_layer", # Intern-S1 "vpm.encoder.layers.{bid}.self_attn.out_proj", "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4 @@ -1170,6 +1197,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL + "model.vision_tower.encoder.layer.{bid}.layernorm_after", # Intern-S1 "vpm.encoder.layers.{bid}.layer_norm2", "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4 @@ -1179,6 +1207,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_FFN_UP: ( "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", + "model.vision_tower.encoder.layer.{bid}.mlp.fc1", # Intern-S1 "vpm.encoder.layers.{bid}.mlp.fc1", "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral @@ -1194,6 +1223,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_FFN_DOWN: ( "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", + "model.vision_tower.encoder.layer.{bid}.mlp.fc2", # Intern-S1 "vpm.encoder.layers.{bid}.mlp.fc2", "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral @@ -1204,10 +1234,12 @@ class TensorNameMap: MODEL_TENSOR.V_LAYER_SCALE_1: ( "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL + 
"model.vision_tower.encoder.layer.{bid}.lambda_1", # Intern-S1 ), MODEL_TENSOR.V_LAYER_SCALE_2: ( "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL + "model.vision_tower.encoder.layer.{bid}.lambda_2", # Intern-S1 ), MODEL_TENSOR.V_PRE_NORM: ( @@ -1357,6 +1389,31 @@ class TensorNameMap: MODEL_TENSOR.A_MM_NORM_MID: ( "audio.multi_modal_projector.ln_mid", # ultravox ), + + # NextN/MTP tensors for GLM4_MOE + MODEL_TENSOR.NEXTN_EH_PROJ: ( + "model.layers.{bid}.eh_proj", + ), + + MODEL_TENSOR.NEXTN_EMBED_TOKENS: ( + "model.layers.{bid}.embed_tokens", + ), + + MODEL_TENSOR.NEXTN_ENORM: ( + "model.layers.{bid}.enorm", + ), + + MODEL_TENSOR.NEXTN_HNORM: ( + "model.layers.{bid}.hnorm", + ), + + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: ( + "model.layers.{bid}.shared_head.head", + ), + + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: ( + "model.layers.{bid}.shared_head.norm", + ), } # architecture-specific block mappings diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index e1d5aaf47a..7111557bfd 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -312,7 +312,11 @@ class SpecialVocab: with open(config_file, encoding = 'utf-8') as f: config = json.load(f) for typ in self.special_token_types: - self._set_special_token(typ, config.get(f'{typ}_token_id')) + token_id = config.get(f'{typ}_token_id') + # If not found at root, check in text_config (for multimodal models like Kimi-VL) + if token_id is None and 'text_config' in config: + token_id = config['text_config'].get(f'{typ}_token_id') + self._set_special_token(typ, token_id) return True diff --git a/include/llama.h b/include/llama.h index 2cbe18d8cf..545e957e5f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -152,6 +152,7 @@ extern "C" { //LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors + LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/models/templates/README.md b/models/templates/README.md index 35b6386dd0..2e8eaa5953 100644 --- a/models/templates/README.md +++ b/models/templates/README.md @@ -21,4 +21,5 @@ These templates can be updated with the following commands: ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja ./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja ./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja -``` \ No newline at end of file +./scripts/get_chat_template.py zai-org/GLM-4.5 > models/templates/zai-org-GLM-4.5.jinja +``` diff --git a/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja b/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja new file mode 100644 index 0000000000..f506536096 --- /dev/null +++ b/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja @@ -0,0 +1,59 @@ +{# Alias tools -> available_tools #} +{%- if tools and not available_tools -%} + {%- set available_tools = tools -%} +{%- endif -%} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content'] %} + {%- set loop_messages = messages[1:] %} + {%- else %} + {%- set system_message = "Knowledge Cutoff Date: April 2024. Today's Date: " + strftime_now('%B %d, %Y') + ". You are Granite, developed by IBM." 
%}
+ {%- if available_tools and documents %}
+ {%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request. Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
+ {%- elif available_tools %}
+ {%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
+ {%- elif documents %}
+ {%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
+ {%- elif thinking %}
+ {%- set system_message = system_message + " You are a helpful AI assistant.
Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between <think></think> and write your response between <response></response> for each user query." %}
+ {%- else %}
+ {%- set system_message = system_message + " You are a helpful AI assistant." %}
+ {%- endif %}
+ {%- if 'citations' in controls and documents %}
+ {%- set system_message = system_message + '
Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
+ {%- endif %}
+ {%- if 'hallucinations' in controls and documents %}
+ {%- set system_message = system_message + '
Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.'
%} + {%- endif %} + {%- set loop_messages = messages %} + {%- endif %} + {{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|> +' }} + {%- if available_tools %} + {{- '<|start_of_role|>available_tools<|end_of_role|>' }} + {{- available_tools | tojson(indent=4) }} + {{- '<|end_of_text|> +' }} + {%- endif %} + {%- if documents %} + {%- for document in documents %} + {{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|> +' }} + {{- document['text'] }} + {{- '<|end_of_text|> +' }} + {%- endfor %} + {%- endif %} + {%- for message in loop_messages %} + {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- if loop.last and add_generation_prompt %} + {{- '<|start_of_role|>assistant' }} + {%- if controls %} + {{- ' ' + controls | tojson()}} + {%- endif %} + {{- '<|end_of_role|>' }} + {%- endif %} + {%- endfor %} diff --git a/requirements/requirements-convert_hf_to_gguf.txt b/requirements/requirements-convert_hf_to_gguf.txt index fd21ec4795..766745f42f 100644 --- a/requirements/requirements-convert_hf_to_gguf.txt +++ b/requirements/requirements-convert_hf_to_gguf.txt @@ -2,7 +2,7 @@ mistral-common>=1.8.3 -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.2.1; platform_machine != "s390x" +torch~=2.4.0; platform_machine != "s390x" # torch s390x packages can only be found from nightly builds --extra-index-url https://download.pytorch.org/whl/nightly diff --git a/requirements/requirements-convert_hf_to_gguf_update.txt b/requirements/requirements-convert_hf_to_gguf_update.txt index 431c596c12..afe2747d44 100644 --- a/requirements/requirements-convert_hf_to_gguf_update.txt +++ b/requirements/requirements-convert_hf_to_gguf_update.txt @@ -1,7 +1 @@ -r ./requirements-convert_legacy_llama.txt ---extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.2.1; platform_machine != "s390x" - -# torch s390x packages can only be found from nightly builds ---extra-index-url https://download.pytorch.org/whl/nightly -torch>=0.0.0.dev0; platform_machine == "s390x" diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 406930fb0a..8366f89a08 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -315,28 +315,29 @@ class LlamaBenchData: class LlamaBenchDataSQLite3(LlamaBenchData): - connection: sqlite3.Connection + connection: Optional[sqlite3.Connection] = None cursor: sqlite3.Cursor table_name: str def __init__(self, tool: str = "llama-bench"): super().__init__(tool) - self.connection = sqlite3.connect(":memory:") - self.cursor = self.connection.cursor() + if self.connection is None: + self.connection = sqlite3.connect(":memory:") + self.cursor = self.connection.cursor() - # Set table name and schema based on tool - if self.tool == "llama-bench": - self.table_name = "test" - db_fields = LLAMA_BENCH_DB_FIELDS - db_types = LLAMA_BENCH_DB_TYPES - elif self.tool == "test-backend-ops": - self.table_name = "test_backend_ops" - db_fields = TEST_BACKEND_OPS_DB_FIELDS - db_types = TEST_BACKEND_OPS_DB_TYPES - else: - assert False + # Set table name and schema based on tool + if self.tool == "llama-bench": + self.table_name = "llama_bench" + db_fields = LLAMA_BENCH_DB_FIELDS + db_types = LLAMA_BENCH_DB_TYPES + elif self.tool == "test-backend-ops": + self.table_name = "test_backend_ops" + db_fields = TEST_BACKEND_OPS_DB_FIELDS + db_types = TEST_BACKEND_OPS_DB_TYPES + else: + assert 
False - self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});") + self.cursor.execute(f"CREATE TABLE {self.table_name}({', '.join(' '.join(x) for x in zip(db_fields, db_types))});") def _builds_init(self): if self.connection: @@ -397,9 +398,6 @@ class LlamaBenchDataSQLite3(LlamaBenchData): class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3): def __init__(self, data_file: str, tool: Any): - super().__init__(tool) - - self.connection.close() self.connection = sqlite3.connect(data_file) self.cursor = self.connection.cursor() @@ -409,29 +407,30 @@ class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3): # Tool selection logic if tool is None: - if "test" in table_names: - self.table_name = "test" - self.tool = "llama-bench" + if "llama_bench" in table_names: + self.table_name = "llama_bench" + tool = "llama-bench" elif "test_backend_ops" in table_names: self.table_name = "test_backend_ops" - self.tool = "test-backend-ops" + tool = "test-backend-ops" else: raise RuntimeError(f"No suitable table found in database. Available tables: {table_names}") elif tool == "llama-bench": - if "test" in table_names: - self.table_name = "test" - self.tool = "llama-bench" + if "llama_bench" in table_names: + self.table_name = "llama_bench" + tool = "llama-bench" else: raise RuntimeError(f"Table 'test' not found for tool 'llama-bench'. Available tables: {table_names}") elif tool == "test-backend-ops": if "test_backend_ops" in table_names: self.table_name = "test_backend_ops" - self.tool = "test-backend-ops" + tool = "test-backend-ops" else: raise RuntimeError(f"Table 'test_backend_ops' not found for tool 'test-backend-ops'. Available tables: {table_names}") else: raise RuntimeError(f"Unknown tool: {tool}") + super().__init__(tool) self._builds_init() @staticmethod @@ -653,6 +652,8 @@ if not bench_data: if not bench_data.builds: raise RuntimeError(f"{input_file} does not contain any builds.") +tool = bench_data.tool # May have chosen a default if tool was None. 
+ hexsha8_baseline = name_baseline = None diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index ba7bf95986..18dcc6ddfe 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -62,6 +62,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, + { LLM_ARCH_GLM4_MOE, "glm4moe" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -87,6 +88,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, { LLM_ARCH_SMOLLM3, "smollm3" }, + { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, @@ -127,6 +129,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, + { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -1391,6 +1394,40 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GLM4_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + // NextN/MTP tensors - preserved but unused (in final layer, dynamic layer number) + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, + }, + }, { LLM_ARCH_BITNET, { @@ -1935,6 +1972,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_OPENAI_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" }, + { 
LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_LFM2, { @@ -2050,6 +2106,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2181,6 +2238,14 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + // NextN/MTP tensors are currently ignored (reserved for future MTP support) + // These tensors only exist in the last layer(s) and are treated as output tensors + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 9b8bd65b23..7af587e795 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -66,6 +66,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, + LLM_ARCH_GLM4_MOE, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -91,6 +92,7 @@ enum llm_arch { LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_HUNYUAN_DENSE, LLM_ARCH_SMOLLM3, + LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, @@ -131,6 +133,7 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_GATING_FUNC, LLM_KV_MOE_EVERY_N_LAYERS, + LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -263,6 +266,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_OUT_NORM, LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_ATTN_SINKS, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, @@ -409,6 +413,12 @@ enum llm_tensor { LLM_TENSOR_SHORTCONV_CONV, LLM_TENSOR_SHORTCONV_INPROJ, LLM_TENSOR_SHORTCONV_OUTPROJ, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; enum llm_tensor_layer { diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index c4576e2427..0a96a9a579 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -66,6 +66,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, + { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; @@ -192,9 +193,11 
@@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_LLAMA4; } else if (tmpl_contains("<|endofuserprompt|>")) { return LLM_CHAT_TEMPLATE_DOTS1; - } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { + } else if (tmpl_contains("<|extra_0|>") && tmpl_contains("<|extra_4|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; - } else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { + } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { + return LLM_CHAT_TEMPLATE_OPENAI_MOE; + } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { return LLM_CHAT_TEMPLATE_KIMI_K2; @@ -622,8 +625,6 @@ int32_t llm_chat_apply_template( } else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) { // Yandex template ("\n\n" is defined as EOT token) - ss << ""; - for (size_t i = 0; i < chat.size(); i++) { std::string role(chat[i]->role); if (role == "user") { @@ -706,6 +707,16 @@ int32_t llm_chat_apply_template( ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) { + // OpenAI MoE (based on Harmony chat template) + for (auto message : chat) { + std::string role(message->role); + ss << "<|start|>" << role << "<|message|>" << message->content; + ss << (role == "assistant" ? "<|return|>" : "<|end|>"); + } + if (add_ass) { + ss << "<|start|>assistant"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) { // tencent/Hunyuan-4B-Instruct for (size_t i = 0; i < chat.size(); i++) { diff --git a/src/llama-chat.h b/src/llama-chat.h index 4cf77fd286..35a943856f 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -46,6 +46,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_HUNYUAN_MOE, + LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index bd637f3dff..26a5cf9c3f 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -105,7 +105,7 @@ llama_context::llama_context( { const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false; + supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows; if (!supports_set_rows && !cparams.kv_unified) { LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__); @@ -786,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; const int64_t n_embd = hparams.n_embd; - const int32_t n_vocab = model.vocab.n_tokens(); + const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? 
LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { @@ -959,7 +959,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto & vocab = model.vocab; const auto & hparams = model.hparams; - const int32_t n_vocab = vocab.n_tokens(); + const int64_t n_vocab = vocab.n_tokens(); const int64_t n_embd = hparams.n_embd; // when computing embeddings, all tokens are output @@ -1328,21 +1328,21 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { } void llama_context::output_reorder() { - const uint32_t n_vocab = model.vocab.n_tokens(); + const uint64_t n_vocab = model.vocab.n_tokens(); const uint64_t n_embd = model.hparams.n_embd; - for (uint32_t s = 0; s < output_swaps.size(); ++s) { - const uint32_t i0 = output_swaps[s].i0; - const uint32_t i1 = output_swaps[s].i1; + for (size_t s = 0; s < output_swaps.size(); ++s) { + const uint64_t i0 = output_swaps[s].i0; + const uint64_t i1 = output_swaps[s].i1; if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { + for (uint64_t k = 0; k < n_vocab; k++) { std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); } } if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { + for (uint64_t k = 0; k < n_embd; k++) { std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); } } diff --git a/src/llama-context.h b/src/llama-context.h index 7cfdc6a517..25c143d56d 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -289,7 +289,7 @@ private: // env: LLAMA_SET_ROWS (temporary) // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - bool supports_set_rows = false; + bool supports_set_rows = true; // env: LLAMA_GRAPH_REUSE_DISABLE bool graph_reuse_disable = false; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 491a26b634..053c72d6dc 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -740,6 +740,8 @@ ggml_tensor * llm_graph_context::build_ffn( cur = ggml_reglu(ctx0, cur); cb(cur, "ffn_reglu", il); } break; + default: + GGML_ABORT("fatal error"); } if (gate && type_gate == LLM_FFN_PAR) { @@ -749,8 +751,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -787,6 +789,45 @@ ggml_tensor * llm_graph_context::build_moe_ffn( llama_expert_gating_func_type gating_op, int il, ggml_tensor * probs_in) const { + return build_moe_ffn( + cur, + gate_inp, /* gate_inp_b */ nullptr, + up_exps, /* up_exps_b */ nullptr, + gate_exps, /* gate_exps_b */ nullptr, + down_exps, /* down_exps_b */ nullptr, + exp_probs_b, + n_expert, + n_expert_used, + type_op, + norm_w, + scale_w, + w_scale, + gating_op, + il, + probs_in + ); +} + +ggml_tensor * llm_graph_context::build_moe_ffn( + ggml_tensor * cur, + ggml_tensor * gate_inp, + ggml_tensor * gate_inp_b, + ggml_tensor * up_exps, + ggml_tensor * up_exps_b, + ggml_tensor * gate_exps, + ggml_tensor * gate_exps_b, + ggml_tensor * down_exps, + ggml_tensor * down_exps_b, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + llama_expert_gating_func_type gating_op, + int il, + ggml_tensor * probs_in) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply 
the sigmoid-ed weights before the FFN @@ -800,6 +841,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( logits = probs_in; } + if (gate_inp_b) { + logits = ggml_add(ctx0, logits, gate_inp_b); + cb(logits, "ffn_moe_logits_biased", il); + } + ggml_tensor * probs = nullptr; switch (gating_op) { case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: @@ -810,6 +856,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn( { probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens] } break; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT: + { + probs = logits; // [n_expert, n_tokens] + } break; default: GGML_ABORT("fatal error"); } @@ -838,6 +888,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); + if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) { + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); + weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens] + weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); + cb(weights, "ffn_moe_weights_softmax", il); + } + if (norm_w) { weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); @@ -866,6 +923,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); + if (up_exps_b) { + up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); + cb(up, "ffn_moe_up_biased", il); + } + ggml_tensor * experts = nullptr; if (gate_exps) { cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] @@ -874,6 +936,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = up; } + if (gate_exps_b) { + cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); + cb(cur, "ffn_moe_gate_biased", il); + } + switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { @@ -891,6 +958,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_moe_gelu", il); } break; + case LLM_FFN_SWIGLU_OAI_MOE: + { + // TODO: move to hparams? + constexpr float alpha = 1.702f; + constexpr float limit = 7.0f; + cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit); + cb(cur, "ffn_moe_swiglu_oai", il); + } break; case LLM_FFN_RELU: if (gate_exps) { cur = ggml_reglu_split(ctx0, cur, up); @@ -906,6 +981,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); + if (down_exps_b) { + experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts); + cb(experts, "ffn_moe_down_biased", il); + } + if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); cb(cur, "ffn_moe_weighted", il); @@ -1144,6 +1224,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * kq_b, ggml_tensor * kq_mask, ggml_tensor * v_mla, + ggml_tensor * sinks, float kq_scale) const { const bool v_trans = v->nb[1] > v->nb[2]; @@ -1180,7 +1261,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, hparams.attn_soft_cap ? 
hparams.f_attn_logit_softcapping : 0.0f); - ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + ggml_flash_attn_ext_add_sinks(cur, sinks); + ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32); if (v_mla) { #if 0 @@ -1228,6 +1310,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( } kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + ggml_soft_max_add_sinks(kq, sinks); if (!v_trans) { // note: avoid this branch @@ -1298,7 +1381,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1386,13 +1469,13 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); cb(cur, "kqv_out", il); if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4) { - // GLM4 seems to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1415,6 +1498,32 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_mla, float kq_scale, int il) const { + return build_attn_with_sinks( + inp, + wo, + wo_b, + q_cur, + k_cur, + v_cur, + kq_b, + v_mla, + nullptr, + kq_scale, + il); +} + +ggml_tensor * llm_graph_context::build_attn_with_sinks( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + ggml_tensor * sinks, + float kq_scale, + int il) const { // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -1452,7 +1561,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1506,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale); cb(cur, "kqv_out", il); if (wo) { diff --git a/src/llama-graph.h b/src/llama-graph.h index 8614d49674..6ff49de3a1 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -39,6 +39,7 @@ enum llm_ffn_op_type { LLM_FFN_SWIGLU, LLM_FFN_GEGLU, LLM_FFN_REGLU, + LLM_FFN_SWIGLU_OAI_MOE, }; enum llm_ffn_gate_type { @@ -619,6 +620,7 @@ struct llm_graph_context { llm_ffn_gate_type type_gate, int il) const; + // build MoE FFN without bias tensors ggml_tensor * build_moe_ffn( ggml_tensor * cur, ggml_tensor * gate_inp, @@ -636,6 +638,27 @@ struct llm_graph_context { int il, ggml_tensor * probs_in = nullptr) const; + ggml_tensor * build_moe_ffn( + ggml_tensor * cur, + ggml_tensor * gate_inp, + ggml_tensor * gate_inp_b, 
+ ggml_tensor * up_exps, + ggml_tensor * up_exps_b, + ggml_tensor * gate_exps, + ggml_tensor * gate_exps_b, + ggml_tensor * down_exps, + ggml_tensor * down_exps_b, + ggml_tensor * exp_probs_b, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + bool norm_w, + bool scale_w, + float w_scale, + llama_expert_gating_func_type gating_op, + int il, + ggml_tensor * probs_in = nullptr) const; + // // inputs // @@ -662,6 +685,7 @@ struct llm_graph_context { ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) ggml_tensor * kq_b, ggml_tensor * kq_mask, + ggml_tensor * sinks, ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale) const; @@ -708,6 +732,20 @@ struct llm_graph_context { float kq_scale, int il) const; + // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this + ggml_tensor * build_attn_with_sinks( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + ggml_tensor * sinks, // [n_head_q] + float kq_scale, + int il) const; + llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 8b7e2a1130..bd23122443 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -9,9 +9,10 @@ #define LLAMA_MAX_EXPERTS 384 // Kimi-K2 enum llama_expert_gating_func_type { - LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1, - LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, + LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits }; enum llama_swa_type { @@ -73,6 +74,7 @@ struct llama_hparams { bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; + uint32_t nextn_predict_layers = 0; float f_norm_eps; float f_norm_rms_eps; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 321dc79fc3..e539142e6b 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -39,6 +39,10 @@ llama_kv_cache_unified::llama_kv_cache_unified( if (model.arch == LLM_ARCH_GEMMA3N) { n_layer_cache = 20; } + if (model.arch == LLM_ARCH_GLM4_MOE) { + // GLM-4.5: Only process up to last layer, skip final NextN layer + n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers; + } // create a context for each buffer type std::map ctx_map; @@ -183,7 +187,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( const size_t memory_size_k = size_k_bytes(); const size_t memory_size_v = size_v_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream, ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / 
(1024.0f * 1024.0f)); @@ -193,7 +197,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0; + supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : supports_set_rows; if (!supports_set_rows) { // ref: https://github.com/ggml-org/llama.cpp/pull/14363 diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 3e28e346c3..342a675962 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -230,7 +230,7 @@ private: // env: LLAMA_SET_ROWS (temporary) // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - bool supports_set_rows = false; + bool supports_set_rows = true; const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index d8e2086c87..e98b4e3546 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -25,6 +25,7 @@ llama_memory_hybrid::llama_memory_hybrid( /* common */ uint32_t n_seq_max, bool offload, + bool unified, /* layer filters */ layer_filter_cb && filter_attn, layer_filter_cb && filter_recr) : @@ -38,7 +39,7 @@ llama_memory_hybrid::llama_memory_hybrid( type_v, v_trans, offload, - 1, + unified, kv_size, n_seq_max, n_pad, diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index 4ac3181757..c2d56cd541 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -39,6 +39,7 @@ public: /* common */ uint32_t n_seq_max, bool offload, + bool unified, /* layer filters */ layer_filter_cb && filter_attn = nullptr, layer_filter_cb && filter_recr = nullptr); diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index bd9e6da883..f71c40f8e3 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -35,6 +35,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h index 0f52b011b6..c9189f6cb4 100644 --- a/src/llama-model-loader.h +++ b/src/llama-model-loader.h @@ -58,8 +58,9 @@ struct llama_model_loader { } }; - static const int TENSOR_NOT_REQUIRED = 1; - static const int TENSOR_DUPLICATED = 2; + static const int TENSOR_NOT_REQUIRED = 1 << 0; + static const int TENSOR_DUPLICATED = 1 << 1; + static const int TENSOR_SKIP = 1 << 2; int n_kv = 0; int n_tensors = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e3f12edd9b..58ca7df707 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -109,8 +109,10 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; + case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -190,6 +192,13 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * a = 
ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); op_tensor = ggml_add(ctx, a, w); } break; + case GGML_OP_ADD_ID: + { + int n_expert_used = hparams.n_expert_used; + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_add_id(ctx, a, w, c); + } break; case GGML_OP_MUL: { ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); @@ -258,6 +267,10 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); } break; + case GGML_OP_SCALE: + { + op_tensor = ggml_scale(ctx, w, 1.0f); + } break; default: GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); } @@ -899,6 +912,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3: { + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; @@ -1433,6 +1447,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM4_MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + switch (hparams.n_layer) { + case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1782,6 +1824,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_OPENAI_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(2); + + // TODO: switch (hparams.n_layer) + } break; case LLM_ARCH_LFM2: { ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); @@ -1948,6 +2001,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; const auto 
TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; + const auto TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; // create tensors for the weights { @@ -2003,7 +2057,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // skip unused tensors - if (info.op == GGML_OP_NONE) { + if (info.op == GGML_OP_NONE || flags & TENSOR_SKIP) { const size_t nbytes = ggml_nbytes(t_meta); LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); @@ -2013,11 +2067,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { return nullptr; } - // tensors with "bias" suffix are always used with GGML_OP_ADD + // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID ggml_op op; bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; if (bias) { - op = GGML_OP_ADD; + if (info.op == GGML_OP_MUL_MAT_ID) { + op = GGML_OP_ADD_ID; + } else { + op = GGML_OP_ADD; + } } else { op = info.op; } @@ -4426,6 +4484,105 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GLM4_MOE: + { + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_expert_shared = hparams.n_expert_shared; + + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Load ALL tensors including NextN layer to satisfy total tensor count + // but only PROCESS up to last layer (skipping final NextN layer) in forward pass + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); + + // GLM-style attention with bias terms + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); + + // K/Q norm tensors (optional for GLM-4.5 355B variant) + layer.attn_q_norm = create_tensor( + tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + layer.attn_k_norm = create_tensor( + tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 
TENSOR_NOT_REQUIRED | flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags); + + // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE + const bool use_moe = (static_cast(i) >= hparams.n_layer_dense_lead); + + if (use_moe) { + // MoE layers + layer.ffn_gate_inp = + create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor( + tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + layer.ffn_down_exps = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); + layer.ffn_up_exps = create_tensor( + tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + + // Shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor( + tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); + layer.ffn_up_shexp = create_tensor( + tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + } + } else { + // Dense layers (first k layers) - GLM uses separate gate/up projections + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + } + } + } + break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5269,6 +5426,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_OPENAI_MOE: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = 
create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // bias + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); + layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0); + layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + } + } break; case LLM_ARCH_LFM2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5641,7 +5838,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE) { + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -13563,6 +13760,165 @@ struct llm_build_glm4 : public llm_graph_context { } }; +struct llm_build_glm4_moe : public llm_graph_context { + llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = 
build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); + + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); + + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + 
ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -17250,6 +17606,136 @@ struct llm_build_smollm3 : public llm_graph_context { } }; +struct llm_build_openai_moe_iswa : public llm_graph_context { + llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn_with_sinks(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il); + + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = ffn_inp; + cur = build_norm(cur, + model.layers[il].attn_post_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b, + model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b, + model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SWIGLU_OAI_MOE, false, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + 
model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_lfm2 : public llm_graph_context { const llama_model & model; @@ -17597,6 +18083,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, /* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr, /* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr); } else { @@ -17875,6 +18362,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GLM4_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BITNET: { llm = std::make_unique(*this, params); @@ -17988,6 +18479,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_OPENAI_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_FALCON_H1: { llm = std::make_unique(*this, params); @@ -18203,9 +18698,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: case LLM_ARCH_SMALLTHINKER: + case LLM_ARCH_GLM4_MOE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index 094e23808a..6fcd74d57f 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -101,8 +101,10 @@ enum llm_type { LLM_TYPE_A13B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big + LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; @@ -166,6 +168,15 @@ struct llama_layer_shortconv { struct ggml_tensor * out_proj = nullptr; }; +struct llama_layer_nextn { + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; +}; + struct llama_layer { // normalization struct ggml_tensor * attn_norm = nullptr; @@ -241,10 +252,14 @@ struct llama_layer { struct ggml_tensor * ffn_up_enc = nullptr; // ff MoE - struct ggml_tensor * ffn_gate_inp = nullptr; - struct ggml_tensor * ffn_gate_exps = nullptr; - struct ggml_tensor * ffn_down_exps = nullptr; - struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_exps = nullptr; + struct ggml_tensor * ffn_down_exps = nullptr; + struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_gate_inp_b = nullptr; + struct ggml_tensor * ffn_gate_exps_b = nullptr; + struct ggml_tensor * ffn_down_exps_b = nullptr; + struct ggml_tensor * ffn_up_exps_b = nullptr; // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; @@ -349,11 +364,16 @@ struct llama_layer { struct ggml_tensor * laurel_r = nullptr; 
struct ggml_tensor * laurel_post_norm = nullptr; + // openai-moe + struct ggml_tensor * attn_sinks = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; struct llama_layer_shortconv shortconv; + + struct llama_layer_nextn nextn; }; struct llama_model { diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 0756bf09b8..1d0361cc16 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -211,7 +211,10 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t const int64_t nx = tensor->ne[0]; const int64_t qk_k = ggml_blck_size(new_type); - if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) { + if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) { + new_type = GGML_TYPE_Q8_0; + } + else if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) { new_type = GGML_TYPE_Q8_0; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || @@ -223,6 +226,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t new_type = GGML_TYPE_Q6_K; } } + } else if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) { + // MoE tensors -> MXFP4 + // other tensors -> Q8_0 + if (tensor->ne[2] > 1) { + new_type = GGML_TYPE_MXFP4; + } else { + new_type = GGML_TYPE_Q8_0; + } } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") { if (qs.params->token_embedding_type < GGML_TYPE_COUNT) { new_type = qs.params->token_embedding_type; @@ -533,6 +544,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; + case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break; + // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break; @@ -984,6 +997,29 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const float * imatrix_03 = imatrix ? 
imatrix + i03 * n_per_row : nullptr; new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + + // TODO: temporary sanity check that the F16 -> MXFP4 is lossless +#if 0 + if (new_type == GGML_TYPE_MXFP4) { + auto * x = f32_data_03; + + //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); + std::vector deq(nrows*n_per_row); + const ggml_type_traits * qtype = ggml_get_type_traits(new_type); + qtype->to_float(new_data_03, deq.data(), deq.size()); + + double err = 0.0f; + for (int i = 0; i < (int) deq.size(); ++i) { + err += fabsf(deq[i] - x[i]); + //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { + if (deq[i] != x[i]) { + LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); + } + } + //LLAMA_LOG_INFO("err = %f\n", err); + GGML_ASSERT(err == 0.00000); + } +#endif } LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 7b7a935660..f7e03e702e 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1856,7 +1856,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "gigachat" || tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || - tokenizer_pre == "a.x-4.0") { + tokenizer_pre == "a.x-4.0" || + tokenizer_pre == "mellum") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "jina-v1-en" || @@ -2190,6 +2191,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|fim▁begin|>" // DeepSeek || t.first == "
"
                         || t.first == "▁
"          // CodeLlama
+                        || t.first == "<|code_prefix|>" // GLM-4.5
                         ) {
                     special_fim_pre_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2209,6 +2211,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
+                        || t.first == "<|code_suffix|>" // GLM-4.5
                         ) {
                     special_fim_suf_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2228,6 +2231,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
+                        || t.first == "<|code_middle|>" // GLM-4.5
                         ) {
                     special_fim_mid_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2310,6 +2314,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|eot_id|>"
                     || t.first == "<|im_end|>"
                     || t.first == "<|end|>"
+                    || t.first == "<|return|>" // o200k_harmony
+                    || t.first == "<|call|>"   // o200k_harmony
                     || t.first == ""
                     || t.first == "<|endoftext|>"
                     || t.first == "<|eom_id|>"
@@ -2333,6 +2339,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             }
         }
 
+        // @ngxson : quick hack for gpt-oss, always render these tokens
+        for (const auto & t : token_to_id) {
+            if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>") {
+                id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+            }
+        }
+
         // sanity checks
         if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
             special_eog_ids.insert(special_eos_id);
@@ -2348,6 +2361,36 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_eog_ids.insert(special_eom_id);
             LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
         }
+
+        // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
+        //       we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
+        //       we remove the "<|end|>" token from the EOG list
+        {
+            bool has_return = false;
+            bool has_call   = false;
+            bool has_end    = false;
+
+            llama_token end_id = LLAMA_TOKEN_NULL;
+
+            LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
+            for (auto tid : special_eog_ids) {
+                LLAMA_LOG_INFO("%s:   - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
+
+                if (id_to_token[tid].text == "<|return|>") {
+                    has_return = true;
+                } else if (id_to_token[tid].text == "<|call|>") {
+                    has_call = true;
+                } else if (id_to_token[tid].text == "<|end|>") {
+                    has_end = true;
+                    end_id = tid;
+                }
+            }
+
+            if (has_return && has_call && has_end) {
+                special_eog_ids.erase(end_id);
+                LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
+            }
+        }
     }
 
     // build special tokens cache
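Editor's note on the vocab changes above: the o200k_harmony workaround removes "<|end|>" from the end-of-generation set only when both "<|return|>" and "<|call|>" are present, since in the harmony format those two tokens end a turn while "<|end|>" merely closes a message block. A minimal standalone sketch of that rule follows (hypothetical names, not part of the patch):

// sketch: mirror of the EOG adjustment above, over a plain set of token strings
#include <set>
#include <string>

static void adjust_harmony_eog(std::set<std::string> & eog_tokens) {
    const bool has_return = eog_tokens.count("<|return|>") > 0;
    const bool has_call   = eog_tokens.count("<|call|>")   > 0;
    if (has_return && has_call) {
        // "<|end|>" stays renderable, it just no longer stops generation
        eog_tokens.erase("<|end|>");
    }
}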
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 479b3fad48..d29779cd12 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -296,6 +296,7 @@ static std::string var_to_str(ggml_scale_mode mode) {
 #define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
 #define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
 #define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
+#define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m)
 
 #ifdef GGML_USE_SYCL
 static bool inline _isinf(float f) {
@@ -1886,6 +1887,66 @@ struct test_glu_split : public test_case {
     }
 };
 
+struct test_swiglu_oai : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    int v; // view (1 : non-contiguous a)
+    float alpha;
+    float limit;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne_a, v, alpha, limit);
+    }
+
+    test_swiglu_oai(ggml_type type = GGML_TYPE_F32,
+                    std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
+                    int v = 0,
+                    float alpha = 1.702f,
+                    float limit = 7.0f)
+        : type(type), ne_a(ne_a), v(v), alpha(alpha), limit(limit) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        ggml_tensor * b;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(a);
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+
+            b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(b);
+            ggml_set_name(b, "b");
+
+            b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
+            ggml_set_name(b, "view_of_b");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(a);
+            ggml_set_name(a, "a");
+
+            b = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(b);
+            ggml_set_name(b, "b");
+        }
+
+        ggml_tensor * out = ggml_swiglu_oai(ctx, a, b, alpha, limit);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // test extended range of values to check for NaNs in GELU
+            init_tensor_uniform(t, -150.f, 150.f);
+        }
+    }
+};
+
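Editor's note: test_swiglu_oai above exercises ggml_swiglu_oai with the gpt-oss defaults (alpha = 1.702, limit = 7.0). Below is a scalar reference of the activation as I understand the gpt-oss formulation; which operand is the gate versus the linear branch, and the exact clamping, are assumptions rather than a statement about the ggml kernel:

// sketch: assumed per-element SWIGLU_OAI reference
#include <algorithm>
#include <cmath>

static float swiglu_oai_ref(float gate, float lin, float alpha, float limit) {
    gate = std::min(gate, limit);           // one-sided clamp on the gated branch
    lin  = std::clamp(lin, -limit, limit);  // symmetric clamp on the linear branch
    const float swish = gate / (1.0f + std::exp(-alpha * gate));
    return swish * (lin + 1.0f);
}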
 // GGML_OP_GET_ROWS
 struct test_get_rows : public test_case {
     const ggml_type type;
@@ -2483,6 +2544,68 @@ struct test_bin_bcast : public test_case {
     }
 };
 
+// GGML_OP_ADD_ID
+struct test_add_id : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t n_embd;
+    const int64_t n_experts;
+    const int64_t n_experts_used;
+    const int64_t n_token;
+
+    std::string vars() override {
+        return VARS_TO_STR6(type_a, type_b, n_embd, n_experts, n_experts_used, n_token);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) + ggml_nbytes(t->src[0]) + ggml_nbytes(t->src[2]);
+    }
+
+    test_add_id(ggml_type type_a = GGML_TYPE_F32,
+            ggml_type type_b = GGML_TYPE_F32,
+            int64_t n_embd = 128,
+            int64_t n_experts = 16,
+            int64_t n_experts_used = 8,
+            int64_t n_token = 10)
+        : type_a(type_a), type_b(type_b), n_embd(n_embd),
+          n_experts(n_experts), n_experts_used(n_experts_used), n_token(n_token) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_3d(ctx, type_a, n_embd, n_experts_used, n_token);
+        ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, n_embd, n_experts);
+        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_experts, n_token);
+        if (n_experts_used != n_experts) {
+            ids = ggml_view_2d(ctx, ids, n_experts_used, n_token, ids->nb[1], 0);
+            ggml_set_name(ids, "view_of_ids");
+        }
+
+        ggml_tensor * out = ggml_add_id(ctx, a, b, ids);
+        ggml_set_name(out, "out");
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                std::random_device rd;
+                std::default_random_engine rng(rd());
+                // ids
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<int32_t> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i % n_experts;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
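Editor's note: as implied by test_add_id above, GGML_OP_ADD_ID adds, for every (token, used-expert) slot, the row of b selected by ids to the corresponding row of a -- effectively a per-expert bias applied after expert routing. A naive reference loop with illustrative layouts and names (not the ggml implementation):

// sketch: assumed reference semantics of ggml_add_id
#include <cstdint>
#include <vector>

static void add_id_ref(std::vector<float> & out,          // [n_token][n_used][n_embd], preloaded with a
                       const std::vector<float> & bias,   // [n_experts][n_embd]
                       const std::vector<int32_t> & ids,  // [n_token][n_used]
                       int64_t n_embd, int64_t n_used, int64_t n_token) {
    for (int64_t t = 0; t < n_token; ++t) {
        for (int64_t e = 0; e < n_used; ++e) {
            const int32_t expert = ids[t*n_used + e];
            for (int64_t i = 0; i < n_embd; ++i) {
                out[(t*n_used + e)*n_embd + i] += bias[expert*n_embd + i];
            }
        }
    }
}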
+
 // GGML_OP_ADD1
 struct test_add1 : public test_case {
     const ggml_type type;
@@ -3122,11 +3245,11 @@ struct test_mul_mat_id : public test_case {
     }
 
     void initialize_tensors(ggml_context * ctx) override {
-        std::random_device rd;
-        std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
                 if (ggml_is_view_op(t->op)) { continue; }
+                std::random_device rd;
+                std::default_random_engine rng(rd());
                 // ids
                 for (int64_t r = 0; r < ggml_nrows(t); r++) {
                     std::vector<int32_t> data(t->ne[0]);
@@ -3447,13 +3570,14 @@ struct test_soft_max : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const bool mask;
+    const bool sinks;
     const ggml_type m_prec;
     const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
     const float scale;
     const float max_bias;
 
     std::string vars() override {
-        return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias);
+        return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias);
     }
 
     // the 1024 test with bias occasionally fails:
@@ -3465,11 +3589,12 @@ struct test_soft_max : public test_case {
     test_soft_max(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 5, 4, 3},
             bool mask = false,
+            bool sinks = false,
             ggml_type m_prec = GGML_TYPE_F32,
             std::array<int64_t, 2> nr23 = {1, 1},
             float scale = 1.0f,
             float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
+        : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
@@ -3482,7 +3607,14 @@ struct test_soft_max : public test_case {
             ggml_set_name(mask, "mask");
         }
 
+        ggml_tensor * sinks = nullptr;
+        if (this->sinks) {
+            sinks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]*nr23[0]);
+            ggml_set_name(sinks, "sinks");
+        }
+
         ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        ggml_soft_max_add_sinks(out, sinks);
         ggml_set_name(out, "out");
 
         return out;
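Editor's note: the sinks tensor wired in via ggml_soft_max_add_sinks above holds one extra per-head logit. My reading (an assumption about the semantics, not the ggml kernel) is that the sink joins the softmax normalization without receiving a probability slot of its own, so the regular entries no longer sum to 1:

// sketch: assumed softmax with a single sink logit
#include <algorithm>
#include <cmath>
#include <vector>

static std::vector<float> soft_max_with_sink(const std::vector<float> & logits, float sink) {
    float max_val = sink;
    for (float v : logits) max_val = std::max(max_val, v);

    float denom = std::exp(sink - max_val);  // sink contributes to the sum only
    std::vector<float> out(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        out[i] = std::exp(logits[i] - max_val);
        denom += out[i];
    }
    for (float & v : out) v /= denom;
    return out;
}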
@@ -4433,6 +4565,7 @@ struct test_flash_attn_ext : public test_case {
     const int64_t nb; // batch size
 
     const bool mask; // use mask
+    const bool sinks; // use sinks
 
     const float max_bias; // ALiBi
     const float logit_softcap; // Gemma 2
@@ -4442,7 +4575,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+        return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -4457,9 +4590,9 @@ struct test_flash_attn_ext : public test_case {
     }
 
     test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
-                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+                        bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
                         ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
@@ -4499,13 +4632,31 @@ struct test_flash_attn_ext : public test_case {
             ggml_set_name(m, "m");
         }
 
+        ggml_tensor * s = nullptr;
+        if (sinks) {
+            s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, q->ne[2]);
+            ggml_set_name(s, "s");
+        }
+
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
-        ggml_flash_attn_ext_set_prec(out, prec);
+        ggml_flash_attn_ext_add_sinks(out, s);
+        ggml_flash_attn_ext_set_prec (out, prec);
         ggml_set_name(out, "out");
 
         return out;
     }
 
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (strcmp(t->name, "s") == 0) {
+                // make the sink values more noticeable in order to trigger a test failure when the implementation is wrong
+                init_tensor_uniform(t, -10.0f, 10.0f);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+
     bool grad_precise() override {
         return true;
     }
@@ -5038,6 +5189,7 @@ static const ggml_type all_types[] = {
     GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
     GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
     GGML_TYPE_Q8_0,
+    GGML_TYPE_MXFP4,
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
     GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
@@ -5053,6 +5205,7 @@ static const ggml_type base_types[] = {
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
+    GGML_TYPE_MXFP4, // TODO: or "other"
     GGML_TYPE_IQ2_XXS
 };
 
@@ -5089,6 +5242,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (int v : {0, 1}) {
             for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
+                if (op == GGML_GLU_OP_SWIGLU_OAI) {
+                    // SWIGLU_OAI is handled separately
+                    continue;
+                }
+
                 for (bool swapped : {false, true}) {
                     test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
                     test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
@@ -5100,6 +5258,14 @@ static std::vector> make_test_cases_eval() {
         }
     }
 
+    for (int v : {0, 1}) {
+        for (float alpha : {.5f, 1.702f}) {
+            for (float limit : {2.0f, 7.0f}) {
+                test_cases.emplace_back(new test_swiglu_oai(GGML_TYPE_F32, { 128, 2, 2, 2 }, v, alpha, limit));
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
     for (ggml_type type : all_types) {
         for (int b : {1, 7}) {
@@ -5592,13 +5758,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67,  {1,  1}, {4, 1}, {0, 2, 1, 3}));
 
-    for (auto bs : {1,2,4,8}) {
-        for (auto nr : {1,4}) {
-            for (uint32_t m = 0; m < 2; ++m) {
-                for (uint32_t k = 0; k < 2; ++k) {
-                    for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
-                        test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  1}, {nr, 1}, {0, 2, 1, 3}));
-                        test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  1}, {nr, 1}, {0, 1, 2, 3}, true));
+    for (auto bs2 : {1,3}) {
+        for (auto bs : {1,2,4,8}) {
+            for (auto nr : {1,4}) {
+                for (uint32_t m = 0; m < 2; ++m) {
+                    for (uint32_t k = 0; k < 2; ++k) {
+                        for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
+                            test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k,  {bs,  bs2}, {nr, 1}, {0, 2, 1, 3}));
+                            test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m,  1, 1056 + k, {bs,  bs2}, {nr, 1}, {0, 1, 2, 3}, true));
+                        }
                     }
                 }
             }
@@ -5666,6 +5834,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    // add_id
+    for (ggml_type type_a : {GGML_TYPE_F32}) {
+        for (ggml_type type_b : {GGML_TYPE_F32}) {
+            for (int n_mats : {4, 8}) {
+                for (int n_used : {1, 2, 4}) {
+                    for (int n_embd : {32, 129}) {
+                        for (int n_token : {1, 32, 129}) {
+                            test_cases.emplace_back(new test_add_id(type_a, type_b, n_embd, n_mats, n_used, n_token));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         test_cases.emplace_back(new test_sqr(type));
         test_cases.emplace_back(new test_sqrt(type));
@@ -5695,38 +5878,40 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 #endif
     for (bool mask : {false, true}) {
-        for (float max_bias : {0.0f, 8.0f}) {
-            if (!mask && max_bias > 0.0f) continue;
-            for (float scale : {1.0f, 0.1f}) {
-                for (int64_t ne0 : {16, 1024}) {
-                    for (int64_t ne1 : {16, 1024}) {
-                        if (mask) {
-                            for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
-                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
+        for (bool sinks : {false, true}) {
+            for (float max_bias : {0.0f, 8.0f}) {
+                if (!mask && max_bias > 0.0f) continue;
+                for (float scale : {1.0f, 0.1f}) {
+                    for (int64_t ne0 : {16, 1024}) {
+                        for (int64_t ne1 : {16, 1024}) {
+                            if (mask) {
+                                for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
+                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
 
-                                if (ne0 <= 32 && ne1 <= 32) {
-                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
-                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
+                                    if (ne0 <= 32 && ne1 <= 32) {
+                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 3}, mask, sinks, m_prec, {3, 1}, scale, max_bias));
+                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {2, 3}, scale, max_bias));
+                                    }
                                 }
+                            } else {
+                                /* The precision of mask here doesn't matter as boolean mask is false */
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
                             }
-                        } else {
-                            /* The precision of mask here doesn't matter as boolean mask is false */
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
                         }
                     }
                 }
             }
         }
     }
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
 
     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
@@ -5830,27 +6015,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
 
             for (bool mask : { true, false } ) {
-                for (float max_bias : { 0.0f, 8.0f }) {
-                    if (!mask && max_bias > 0.0f) continue;
-                    for (float logit_softcap : {0.0f, 10.0f}) {
-                        if (hsk != 128 && logit_softcap != 0.0f) continue;
-                        for (int nh : { 4, }) {
-                            for (int nr3 : { 1, 3, }) {
-                                if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
-                                for (int nr2 : { 1, 4, 16 }) {
-                                    if (nr2 == 16 && hsk != 128) continue;
-                                    for (int kv : { 512, 1024, }) {
-                                        if (nr2 != 1 && kv != 512) continue;
-                                        for (int nb : { 1, 3, 32, 35, }) {
-                                            for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
-                                                if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
-                                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                                    test_cases.emplace_back(new test_flash_attn_ext(
-                                                                hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
-                                                    // run fewer test cases permuted
-                                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                for (bool sinks : { true, false } ) {
+                    for (float max_bias : { 0.0f, 8.0f }) {
+                        if (!mask && max_bias > 0.0f) continue;
+                        for (float logit_softcap : {0.0f, 10.0f}) {
+                            if (hsk != 128 && logit_softcap != 0.0f) continue;
+                            for (int nh : { 4, }) {
+                                for (int nr3 : { 1, 3, }) {
+                                    if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
+                                    for (int nr2 : { 1, 4, 16 }) {
+                                        if (nr2 == 16 && hsk != 128) continue;
+                                        for (int kv : { 512, 1024, }) {
+                                            if (nr2 != 1 && kv != 512) continue;
+                                            for (int nb : { 1, 3, 32, 35, }) {
+                                                for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                                    if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                                         test_cases.emplace_back(new test_flash_attn_ext(
-                                                                    hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                                    hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
+                                                        // run fewer test cases permuted
+                                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                            test_cases.emplace_back(new test_flash_attn_ext(
+                                                                        hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                        }
                                                     }
                                                 }
                                             }
@@ -5935,14 +6122,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
 
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
 
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
@@ -5974,7 +6161,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
             for (int nr : { 1, 4, }) {
-                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
             }
         }
     }
@@ -5986,6 +6173,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
 
+
+    for (int n_token : {1, 512}) {
+        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 128, 4, n_token));
+        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
+    }
+
     return test_cases;
 }
 
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index a0a50f9881..edfac3b08b 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -61,7 +61,7 @@ int main(void) {
             /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)",
             /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
             /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST]   I am an assistant   [INST] Another question [/INST]",
-            /* .expected_output_jinja= */ "",
+            /* .expected_output_jinja= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST]   I am an assistant   [INST] Another question [/INST]",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
@@ -85,7 +85,7 @@ int main(void) {
             /* .name= */ "mlabonne/AlphaMonarch-7B",
             /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
             /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n   I am an assistant   \nuser\nAnother question\nassistant\n",
-            /* .expected_output_jinja= */ "",
+            /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n   I am an assistant   \nuser\nAnother question\nassistant\n",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
@@ -99,7 +99,7 @@ int main(void) {
             /* .name= */ "OrionStarAI/Orion-14B-Chat",
             /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
             /* .expected_output= */       "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>   I am an assistant   </s>Human: Another question\n\nAssistant: </s>",
-            /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>   I am an assistant   </s>Human: Another question\n\nAssistant: </s>",
+            /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s>   I am an assistant   </s>Human: Another question\n\nAssistant: </s>",
             /* .bos_token= */ "<s>",
             /* .eos_token= */ "</s>",
         },
@@ -277,9 +277,9 @@ int main(void) {
         {
             /* .name= */ "yandex/YandexGPT-5-Lite-8B-instruct",
             /* .template_str= */ "{%- set names = {'assistant': ' Ассистент:', 'user': ' Пользователь:'} %}\n{%- set tools_prefix = 'Тебе доступны следующие функции:' %}\n{%- macro __render_tool(tool) %}\n    {%- set name = tool.function.name %}\n    {%- set description = tool.function.description|default('') %}\n    {%- set parameters = tool.function.parameters|tojson %}\n    {{- '\\n' }}function {{ '{' }}'name':'{{ name }}',\n    {%- if tool.function.description %}'description':'{{ description }}',{% endif %}\n'parameters':{{ parameters }}\n    {{- '}' }}\n{%- endmacro %}\n{%- macro __render_tools(tools) %}\n    {{- tools_prefix }}\n    {%- for tool in tools %}\n        {{- __render_tool(tool) }}\n    {%- endfor %}\n    {{- '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_tool_message(message) %}\n    {{- '\\n\\nРезультат вызова' }} {{ message.name }}: {{ message.content }} {{ '\\n\\n' }}\n{%- endmacro %}\n{%- if tools -%}\n    {{- __render_tools(tools) }}\n{%- endif -%}\n{%- macro __render_user_message(message) %}\n{{ names.user }} {{ message.content + '\\n\\n' }}\n{%- endmacro %}\n{%- macro __render_assistant_message(message) %}\n    {{- names.assistant }}\n    {%- set call = message['function_call'] %}\n    {%- if call %}\n        {{- '\\n[TOOL_CALL_START]' }}{{ call.name }}{{ '\\n' }}{{ call.arguments|tojson }}\n    {%- else %}\n        {{- ' ' + message.content + '\\n\\n' }}\n    {%- endif %}\n{%- endmacro %}\n{%- if not add_generation_prompt is defined %}\n{%- set add_generation_prompt = false %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'user' %}\n        {{- __render_user_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'assistant' and not loop.last %}\n        {{- __render_assistant_message(message) }}\n    {%- endif %}\n    {%- if message.role == 'tool' %}\n        {{- __render_tool_message(message) }}\n    {%- endif %}\n    {%- if loop.last %}\n        {{- ' Ассистент:[SEP]' }}\n    {%- endif %}\n{%- endfor %}\n",
-            /* .expected_output= */ " Пользователь: Hello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент:    I am an assistant   \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
+            /* .expected_output= */ " Пользователь: Hello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент:    I am an assistant   \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
             /* .expected_output_jinja= */ " Пользователь: You are a helpful assistant\nHello\n\n Ассистент: Hi there\n\n Пользователь: Who are you\n\n Ассистент:    I am an assistant   \n\n Пользователь: Another question\n\n Ассистент:[SEP]",
-            /* .bos_token= */ "",
+            /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
         {
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 6ebf1464d9..99b4b4d5ba 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -953,6 +953,34 @@ static void test_template_output_parsers() {
                 /* is_partial= */ false,
                 {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
 
+        // Test multiple tool calls
+        common_chat_msg message_assist_multiple_calls;
+        message_assist_multiple_calls.role = "assistant";
+        message_assist_multiple_calls.content = "";
+        message_assist_multiple_calls.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""});
+        message_assist_multiple_calls.tool_calls.push_back({"python", "{\"code\":\"print('hello')\"}", ""});
+
+        assert_msg_equals(
+            message_assist_multiple_calls,
+            common_chat_parse(
+                "\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "\n"
+                "\n"
+                "{\"name\": \"python\", \"arguments\": {\"code\":\"print('hello')\"}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+
+        assert_msg_equals(
+            message_assist_multiple_calls,
+            common_chat_parse(
+                "{\"arg1\": 1}\n"
+                "{\"code\":\"print('hello')\"}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+
         assert_msg_equals(
             simple_assist_msg(
                 "This is not a tool call:",
@@ -1039,6 +1066,22 @@ static void test_template_output_parsers() {
                       "\n"
                       "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                       "");
+
+        // Test multiple tool calls with template
+        common_chat_msg message_assist_multiple_calls_template;
+        message_assist_multiple_calls_template.role = "assistant";
+        message_assist_multiple_calls_template.content = "";
+        message_assist_multiple_calls_template.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""});
+        message_assist_multiple_calls_template.tool_calls.push_back({"python", "{\"code\":\"print('test')\"}", ""});
+
+        test_templates(tmpls.get(), end_tokens, message_assist_multiple_calls_template, tools,
+                      "\n"
+                      "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                      "\n"
+                      "\n"
+                      "{\"name\": \"python\", \"arguments\": {\"code\":\"print('test')\"}}\n"
+                      "");
+
         test_templates(tmpls.get(), end_tokens, message_assist_call_python_lines, tools,
                       "\n"
                       "{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n"
@@ -1343,6 +1386,59 @@ static void test_template_output_parsers() {
                 "{\"arg1\": 1}\n"
                 "```<|tool▁call▁end|><|tool▁calls▁end|>");
     }
+    {
+        auto tmpls = read_templates("models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja");
+        std::vector<std::string> end_tokens{ "<|end_of_text|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+
+        assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+
+        // Test parsing regular content
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_GRANITE}));
+
+        // Test parsing content with thinking
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_GRANITE,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_GRANITE,
+                }));
+
+        // Test parsing tool calls
+        assert_msg_equals(message_assist_call,
+            common_chat_parse(
+                "<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_GRANITE}));
+
+        // Test template generation for regular content
+        test_templates(tmpls.get(), end_tokens, message_assist, tools,
+                      "Hello, world!\nWhat's up?",
+                      /* expect_grammar_triggered= */ false);
+
+        // Test template generation for tool calls
+        test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
+                      "{\n"
+                      "  \"tool_calls\": [\n"
+                      "    {\n"
+                      "      \"name\": \"special_function\",\n"
+                      "      \"arguments\": {\n"
+                      "        \"arg1\": 1\n"
+                      "      },\n"
+                      "      \"id\": \"123456789\"\n"
+                      "    }\n"
+                      "  ]\n"
+                      "}",
+                      /* expect_grammar_triggered= */ false
+        );
+    }
 }
 
 static void test_msg_diffs_compute() {
diff --git a/tools/imatrix/README.md b/tools/imatrix/README.md
index 7417a2dec9..4505cb4ce8 100644
--- a/tools/imatrix/README.md
+++ b/tools/imatrix/README.md
@@ -7,7 +7,8 @@ More information is available in https://github.com/ggml-org/llama.cpp/pull/4861
 
 * `-lv | --verbosity` specifies the verbosity level. If set to `0`, no output other than the perplexity of the processed chunks will be generated. If set to `1`, each time the results are saved a message is written to `stderr`. If `>=2`, a message is output each time data is collected for any tensor. Default verbosity level is `1`.
 * `-o | --output-file` specifies the name of the file where the computed data will be stored. If missing `imatrix.gguf` is used.
 * `-ofreq | --output-frequency` specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks)
+* `--output-format` specifies the output format of the generated imatrix file. Either "gguf" or "dat" (the legacy format). Defaults to "gguf".
 * `--save-frequency` specifies how often to save a copy of the imatrix in a separate file. Default is 0 (i.e., never)
 * `--process-output` specifies if data will be collected for the `output.weight` tensor. Typically, it is better not to utilize the importance matrix when quantizing `output.weight`, so this is set to `false` by default.
 * `--in-file` one or more existing imatrix files to load and combine. Useful for merging files from multiple runs/datasets.
@@ -45,14 +46,24 @@ Recent versions of `llama-imatrix` store data in GGUF format by default. For the
 
 ```bash
 # generate and save the imatrix using legacy format
-./llama-imatrix -m ggml-model-f16.gguf -f calibration-data.txt -o imatrix-legacy-format.dat -ngl 99
+./llama-imatrix -m ggml-model-f16.gguf -f calibration-data.txt --output-format dat -o imatrix-legacy-format.dat -ngl 99
 ```
 
 ```bash
-# covert legacy (binary) imatrix format to new (GGUF) format
+# convert legacy (binary) imatrix format to new (GGUF) format
 ./llama-imatrix --in-file imatrix-legacy-format.dat -o imatrix-new-format.gguf
 ```
 
+```bash
+# convert new (GGUF) imatrix format to legacy (binary) format
+./llama-imatrix --in-file imatrix-new-format.gguf --output-format dat -o imatrix-legacy-format.dat
+```
+
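+```bash
+# show statistics of an existing imatrix (sketch; --show-statistics appears in the usage summary above)
+./llama-imatrix --in-file imatrix-new-format.gguf --show-statistics
+```
+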
 ```bash
 # combine existing imatrices
 ./llama-imatrix --in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf -o imatrix-combined.gguf
diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp
index 9aad3711ba..f28a036dee 100644
--- a/tools/imatrix/imatrix.cpp
+++ b/tools/imatrix/imatrix.cpp
@@ -26,7 +26,7 @@
 static void print_usage(int, char ** argv) {
     LOG("\nexample usage:\n");
     LOG("\n    %s \\\n"
-            "       -m model.gguf -f some-text.txt [-o imatrix.gguf] [--no-ppl] \\\n"
+            "       -m model.gguf -f some-text.txt [-o imatrix.gguf] [--output-format {gguf,dat}] [--no-ppl] \\\n"
             "       [--process-output] [--chunk 123] [--save-frequency 0] [--output-frequency 10] \\\n"
             "       [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...] [--parse-special] \\\n"
             "       [--show-statistics] [...]\n" , argv[0]);
@@ -250,13 +250,6 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
     const char * data = is_host ? (const char *) src1->data : m_src1_data.data();
     GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
 
-    // TODO: 4d? (is that even used in practice?)
-    // the extra dimension would need to be stored somewhere to be reflected in the imatrix file
-    if (ggml_nrows(src1) != src1->ne[1] * src1->ne[2]) {
-        LOG_ERR("%s: tensor has more than 3 dimensions: %s", __func__, wname.c_str());
-        GGML_ASSERT(false);
-    }
-
     // this has been adapted to the new format of storing merged experts in a single 3d tensor
     // ref: https://github.com/ggml-org/llama.cpp/pull/6387
     if (t->op == GGML_OP_MUL_MAT_ID) {
@@ -272,6 +265,12 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
 
         GGML_ASSERT(ids->ne[1] == src1->ne[2]);
 
+        // the extra dimension would need to be stored somewhere to be reflected in the imatrix file
+        if (ggml_nrows(src1) != src1->ne[1] * src1->ne[2]) {
+            LOG_ERR("%s: tensor has more than 3 dimensions: %s", __func__, wname.c_str());
+            GGML_ASSERT(false);
+        }
+
         m_ids.resize(ggml_nbytes(ids));
         ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));
 
@@ -335,29 +334,40 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         }
     } else {
         auto & e = m_stats[wname];
-        const int64_t n_mat = src1->ne[2] * src1->ne[3];
+        const int64_t n_mat = src0->ne[2] * src0->ne[3];
 
+        // use a single count per dense tensor
+        // (necessary when merging older GGUF-imatrix files with 3d tensors)
+        if (e.counts.size() > 1) {
+            bool all_equal = true;
+            for (size_t i = 1; i < e.counts.size(); ++i) {
+                if (e.counts[0] != e.counts[i]) {
+                    all_equal = false;
+                    break;
+                }
+            }
+            if (all_equal) {
+                e.counts.resize(1);
+            }
+        }
         if (e.values.empty()) {
             e.values.resize(src1->ne[0] * n_mat, 0);
-            e.counts.resize(n_mat, 0);
+            e.counts.resize(1, 0);
         }
         else if (e.values.size() != (size_t)(src1->ne[0] * n_mat)) {
             LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)(src1->ne[0] * n_mat));
             exit(1); //GGML_ABORT("fatal error");
         }
-        else if (e.counts.size() != (size_t)n_mat) {
-            LOG_ERR("%s: inconsistent expert count for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.counts.size(), (int)n_mat);
-            exit(1); //GGML_ABORT("fatal error");
-        }
         LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->ne[2], (int)src1->type);
+
         for (int64_t i3 = 0; i3 < src1->ne[3]; ++i3) {
             for (int64_t i2 = 0; i2 < src1->ne[2]; ++i2) {
-                const int64_t mat_id = i3 * src1->ne[2] + i2;
+                // handle 3D+ tensors, but flatten 3D+ activations when model tensor is 2D
+                const int64_t mat_id = (i3 % src0->ne[3]) * src0->ne[2] + (i2 % src0->ne[2]);
                 const int64_t mat_start = mat_id * src1->ne[0];
 
                 for (int64_t row = 0; row < src1->ne[1]; ++row) {
-                    const float * x = (const float *) (data + row * src1->nb[1] + i2 * src1->nb[2] + i3 * src1->ne[3]);
-                    e.counts[mat_id]++;
+                    const float * x = (const float *) (data + row * src1->nb[1] + i2 * src1->nb[2] + i3 * src1->nb[3]);
                     for (int64_t j = 0; j < src1->ne[0]; ++j) {
                         e.values[mat_start + j] += x[j] * x[j];
                         if (!std::isfinite((float)e.values[j])) {
@@ -366,16 +376,22 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
                         }
                     }
                 }
-                const int32_t n_chunk = e.counts[mat_id] / chunk_size;
-                if (n_chunk > m_last_chunk) {
-                    const int32_t chunk_step = n_chunk - m_last_chunk;
-                    m_last_chunk = n_chunk;
-                    if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) {
-                        save_imatrix();
-                    }
-                    if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) {
-                        save_imatrix(m_last_chunk);
-                    }
+            }
+        }
+        // only 1 count in practice, except when a tensor is used for both MUL_MAT_ID and MUL_MAT
+        for (size_t i = 0; i < e.counts.size(); ++i) {
+            e.counts[i] += ggml_nrows(src1) / n_mat;
+            const int32_t n_chunk = e.counts[i] / chunk_size;
+            if (n_chunk > m_last_chunk) {
+                const int32_t chunk_step = n_chunk - m_last_chunk;
+                m_last_chunk = n_chunk;
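+                // chunk_step may be > 1 when one batch completes several chunks at once;
+                // save whenever a multiple of the output/save frequency has just been crossed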
+                if ((m_last_chunk % m_params.n_out_freq) / chunk_step == 0) {
+                    save_imatrix();
+                }
+                if (m_params.n_save_freq > 0 && (m_last_chunk % m_params.n_save_freq) / chunk_step == 0) {
+                    save_imatrix(m_last_chunk);
                 }
             }
         }
@@ -492,13 +506,17 @@ void IMatrixCollector::save_imatrix_legacy(int32_t ncall) const {
 
 void IMatrixCollector::save_imatrix(int32_t n_chunk) const {
     auto fname = m_params.out_file;
+    int8_t use_legacy_format = m_params.imat_dat;
 
-    // TODO: use the new format in more cases
-    if (!string_ends_with(fname, ".gguf")) {
-        LOG_WRN("\n%s: saving to legacy imatrix format because output suffix is not .gguf\n", __func__);
+    if (use_legacy_format > 0) {
         this->save_imatrix_legacy(n_chunk);
         return;
     }
+    // only warn when `--output-format gguf` is not specified
+    if (use_legacy_format == 0 && !string_ends_with(fname, ".gguf")) {
+        LOG_WRN("\n%s: saving imatrix using GGUF format with a different suffix than .gguf\n", __func__);
+        LOG_WRN("%s: if you want the previous imatrix format, use --output-format dat\n", __func__);
+    }
 
     if (n_chunk > 0) {
         fname += ".at_";
diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index c56834a2a6..10b48c5568 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -374,7 +374,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -embd, --embeddings <0|1>                 (default: %s)\n",
            join(cmd_params_defaults.embeddings, ",").c_str());
     printf("  -ts, --tensor-split           (default: 0)\n");
-    printf("  -ot --override-tensors =;...\n");
+    printf("  -ot --override-tensor =;...\n");
     printf("                                            (default: disabled)\n");
     printf("  -nopo, --no-op-offload <0|1>              (default: 0)\n");
     printf("\n");
@@ -1738,7 +1738,7 @@ struct sql_printer : public printer {
 
     void print_header(const cmd_params & params) override {
        std::vector<std::string> fields = test::get_fields();
-        fprintf(fout, "CREATE TABLE IF NOT EXISTS test (\n");
+        fprintf(fout, "CREATE TABLE IF NOT EXISTS llama_bench (\n");
         for (size_t i = 0; i < fields.size(); i++) {
             fprintf(fout, "  %s %s%s\n", fields.at(i).c_str(), get_sql_field_type(fields.at(i)).c_str(),
                     i < fields.size() - 1 ? "," : "");
@@ -1749,7 +1749,7 @@ struct sql_printer : public printer {
     }
 
     void print_test(const test & t) override {
-        fprintf(fout, "INSERT INTO test (%s) ", join(test::get_fields(), ", ").c_str());
+        fprintf(fout, "INSERT INTO llama_bench (%s) ", join(test::get_fields(), ", ").c_str());
         fprintf(fout, "VALUES (");
        std::vector<std::string> values = t.get_values();
         for (size_t i = 0; i < values.size(); i++) {
diff --git a/tools/mtmd/requirements.txt b/tools/mtmd/requirements.txt
index ad069f7745..a9d788f265 100644
--- a/tools/mtmd/requirements.txt
+++ b/tools/mtmd/requirements.txt
@@ -1,5 +1,5 @@
 -r ../../requirements/requirements-convert_legacy_llama.txt
 --extra-index-url https://download.pytorch.org/whl/cpu
 pillow~=11.3.0
-torch~=2.2.1
-torchvision~=0.17.1
+torch~=2.4.0
+torchvision~=0.19.1
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index 0e89a2b81b..470dc3d916 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -22,6 +22,7 @@ struct quant_option {
 static const std::vector<quant_option> QUANT_OPTIONS = {
     { "Q4_0",     LLAMA_FTYPE_MOSTLY_Q4_0,     " 4.34G, +0.4685 ppl @ Llama-3-8B",  },
     { "Q4_1",     LLAMA_FTYPE_MOSTLY_Q4_1,     " 4.78G, +0.4511 ppl @ Llama-3-8B",  },
+    { "MXFP4_MOE",LLAMA_FTYPE_MOSTLY_MXFP4_MOE," MXFP4 MoE",  },
     { "Q5_0",     LLAMA_FTYPE_MOSTLY_Q5_0,     " 5.21G, +0.1316 ppl @ Llama-3-8B",  },
     { "Q5_1",     LLAMA_FTYPE_MOSTLY_Q5_1,     " 5.65G, +0.1062 ppl @ Llama-3-8B",  },
     { "IQ2_XXS",  LLAMA_FTYPE_MOSTLY_IQ2_XXS,  " 2.06 bpw quantization",            },
@@ -611,7 +612,7 @@ int main(int argc, char ** argv) {
             return 1;
         }
         if (!try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
-            fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
+            fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[arg_idx]);
             return 1;
         }
         if (ftype_str == "COPY") {
diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz
index 53b71079c1..4e25eef176 100644
Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 35d6610428..a255d481a4 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -4249,9 +4249,6 @@ int main(int argc, char ** argv) {
 
             // process prompt
            std::vector<server_tokens> inputs;
-            if (oaicompat && !prompt.is_string()) {
-                throw std::runtime_error("prompt must be a string");
-            }
 
             if (oaicompat && has_mtmd) {
                 // multimodal
diff --git a/tools/server/webui/src/components/ChatMessage.tsx b/tools/server/webui/src/components/ChatMessage.tsx
index ee59de450d..1c579f2659 100644
--- a/tools/server/webui/src/components/ChatMessage.tsx
+++ b/tools/server/webui/src/components/ChatMessage.tsx
@@ -61,20 +61,25 @@ export default function ChatMessage({
     if (msg.content === null || msg.role !== 'assistant') {
       return { content: msg.content };
     }
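+    // match <think>-style reasoning blocks as well as the gpt-oss Harmony channel
+    // markers (<|channel|>analysis ... <|start|>assistant<|channel|>final)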
+    const REGEX_THINK_OPEN = /<think>|<\|channel\|>analysis<\|message\|>/;
+    const REGEX_THINK_CLOSE =
+      /<\/think>|<\|start\|>assistant<\|channel\|>final<\|message\|>/;
     let actualContent = '';
     let thought = '';
     let isThinking = false;
-    let thinkSplit = msg.content.split('<think>', 2);
+    let thinkSplit = msg.content.split(REGEX_THINK_OPEN, 2);
     actualContent += thinkSplit[0];
     while (thinkSplit[1] !== undefined) {
      // <think> tag found
-      thinkSplit = thinkSplit[1].split('</think>', 2);
+      thinkSplit = thinkSplit[1].split(REGEX_THINK_CLOSE, 2);
       thought += thinkSplit[0];
       isThinking = true;
       if (thinkSplit[1] !== undefined) {
        // </think> closing tag found
         isThinking = false;
-        thinkSplit = thinkSplit[1].split('<think>', 2);
+        thinkSplit = thinkSplit[1].split(REGEX_THINK_OPEN, 2);
         actualContent += thinkSplit[0];
       }
     }
diff --git a/tools/server/webui/src/index.scss b/tools/server/webui/src/index.scss
index 362db6e17d..879cd25855 100644
--- a/tools/server/webui/src/index.scss
+++ b/tools/server/webui/src/index.scss
@@ -31,7 +31,24 @@ html {
   hr {
     @apply my-4 border-base-content/20 border-1;
   }
-  /* TODO: fix markdown table */
+  table {
+    @apply w-full border-collapse text-sm font-sans my-4 text-base-content;
+  }
+  thead {
+    @apply bg-base-200 text-base-content;
+  }
+  th {
+    @apply border border-base-300 px-4 py-2 text-left font-semibold;
+  }
+  td {
+    @apply border border-base-300 px-4 py-2 align-top;
+  }
+  tbody tr:nth-child(even) {
+    @apply bg-base-100;
+  }
+  tbody tr:hover {
+    @apply bg-base-200;
+  }
 }
 
 .btn-mini {
diff --git a/vendor/minja/chat-template.hpp b/vendor/minja/chat-template.hpp
index cf113bf222..d5295b335b 100644
--- a/vendor/minja/chat-template.hpp
+++ b/vendor/minja/chat-template.hpp
@@ -162,8 +162,15 @@ class chat_template {
         }), false);
         caps_.supports_tools = contains(out, "some_tool");
 
-        auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
-        auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
+        const auto render_with_content = [&](const json & content) {
+            const json assistant_msg {{"role", "assistant"}, {"content", content}};
+            // Render two assistant messages as some templates like QwQ-32B are handling
+            // the content differently depending on whether it's the last message or not
+            // (to remove the <think> tag in all but the last message).
+            return try_raw_render(json::array({dummy_user_msg, assistant_msg, dummy_user_msg, assistant_msg}), {}, false);
+        };
+        auto out_empty = render_with_content("");
+        auto out_null = render_with_content(json());
         caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
 
         json j_null;
@@ -191,12 +198,12 @@ class chat_template {
             dummy_user_msg,
             make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
         }), {}, false);
-        auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
+        auto tool_call_renders_str_arguments = contains(out, "<tool_call>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
         out = try_raw_render(json::array({
             dummy_user_msg,
             make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
         }), {}, false);
-        auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
+        auto tool_call_renders_obj_arguments = contains(out, "<tool_call>") || contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
 
         caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
         caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
diff --git a/vendor/minja/minja.hpp b/vendor/minja/minja.hpp
index dd107dccda..dad75efbba 100644
--- a/vendor/minja/minja.hpp
+++ b/vendor/minja/minja.hpp
@@ -1291,6 +1291,13 @@ public:
     }
 };
 
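+// Jinja's "in" operator: membership test for arrays and objects, substring test for strings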
+static bool in(const Value & value, const Value & container) {
+  return (((container.is_array() || container.is_object()) && container.contains(value)) ||
+      (value.is_string() && container.is_string() &&
+        container.to_str().find(value.to_str()) != std::string::npos));
+}
+
 class BinaryOpExpr : public Expression {
 public:
     enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
@@ -1355,13 +1361,8 @@ public:
               case Op::Gt:        return l > r;
               case Op::Le:        return l <= r;
               case Op::Ge:        return l >= r;
-              case Op::In:        return (((r.is_array() || r.is_object()) && r.contains(l)) ||
-                                          (l.is_string() && r.is_string() &&
-                                            r.to_str().find(l.to_str()) != std::string::npos));
-              case Op::NotIn:
-                                  return !(((r.is_array() || r.is_object()) && r.contains(l)) ||
-                                            (l.is_string() && r.is_string() &&
-                                              r.to_str().find(l.to_str()) != std::string::npos));
+              case Op::In:        return in(l, r);
+              case Op::NotIn:     return !in(l, r);
               default:            break;
           }
           throw std::runtime_error("Unknown binary operator");
@@ -1500,6 +1501,13 @@ public:
           } else if (method->get_name() == "pop") {
             vargs.expectArgs("pop method", {1, 1}, {0, 0});
             return obj.pop(vargs.args[0]);
+          } else if (method->get_name() == "keys") {
+            vargs.expectArgs("keys method", {0, 0}, {0, 0});
+            auto result = Value::array();
+            for (const auto& key : obj.keys()) {
+              result.push_back(Value(key));
+            }
+            return result;
           } else if (method->get_name() == "get") {
             vargs.expectArgs("get method", {1, 2}, {0, 0});
             auto key = vargs.args[0];
@@ -1541,6 +1549,16 @@ public:
           } else if (method->get_name() == "capitalize") {
             vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
             return Value(capitalize(str));
+          } else if (method->get_name() == "upper") {
+            vargs.expectArgs("upper method", {0, 0}, {0, 0});
+            auto result = str;
+            std::transform(result.begin(), result.end(), result.begin(), ::toupper);
+            return Value(result);
+          } else if (method->get_name() == "lower") {
+            vargs.expectArgs("lower method", {0, 0}, {0, 0});
+            auto result = str;
+            std::transform(result.begin(), result.end(), result.begin(), ::tolower);
+            return Value(result);
           } else if (method->get_name() == "endswith") {
             vargs.expectArgs("endswith method", {1, 1}, {0, 0});
            auto suffix = vargs.args[0].get<std::string>();
@@ -2646,15 +2664,11 @@ inline std::shared_ptr<Context> Context::builtins() {
     auto items = Value::array();
     if (args.contains("object")) {
       auto & obj = args.at("object");
-      if (obj.is_string()) {
-        auto json_obj = json::parse(obj.get<std::string>());
-        for (const auto & kv : json_obj.items()) {
-          items.push_back(Value::array({kv.key(), kv.value()}));
-        }
-      } else if (!obj.is_null()) {
-        for (auto & key : obj.keys()) {
-          items.push_back(Value::array({key, obj.at(key)}));
-        }
+      if (!obj.is_object()) {
+        throw std::runtime_error("Can only get item pairs from a mapping");
+      }
+      for (auto & key : obj.keys()) {
+        items.push_back(Value::array({key, obj.at(key)}));
       }
     }
     return items;
@@ -2782,6 +2796,9 @@ inline std::shared_ptr<Context> Context::builtins() {
       if (!items.is_array()) throw std::runtime_error("object is not iterable");
       return items;
   }));
+  globals.set("in", simple_function("in", { "item", "items" }, [](const std::shared_ptr &, Value & args) -> Value {
+      return in(args.at("item"), args.at("items"));
+  }));
   globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value {
       auto & items = args.at("items");
       if (!items.is_array()) throw std::runtime_error("object is not iterable");