diff --git a/.devops/tools.sh b/.devops/tools.sh index 41a6b1e55c..8a3a693400 100755 --- a/.devops/tools.sh +++ b/.devops/tools.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e # Read the first argument into a variable diff --git a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml index b85bf5741e..95a0b5cc75 100644 --- a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml +++ b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml @@ -40,7 +40,7 @@ body: attributes: label: GGML backends description: Which GGML backends do you know to be affected? - options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan] + options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL] multiple: true validations: required: true diff --git a/.github/ISSUE_TEMPLATE/011-bug-results.yml b/.github/ISSUE_TEMPLATE/011-bug-results.yml index 1ccef0793d..d1034bbb69 100644 --- a/.github/ISSUE_TEMPLATE/011-bug-results.yml +++ b/.github/ISSUE_TEMPLATE/011-bug-results.yml @@ -42,7 +42,7 @@ body: attributes: label: GGML backends description: Which GGML backends do you know to be affected? - options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan] + options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL] multiple: true validations: required: true diff --git a/.github/labeler.yml b/.github/labeler.yml index 3c2f67707b..df6a7a40ed 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,10 +1,4 @@ # https://github.com/actions/labeler -Kompute: - - changed-files: - - any-glob-to-any-file: - - ggml/include/ggml-kompute.h - - ggml/src/ggml-kompute/** - - README-kompute.md Apple Metal: - changed-files: - any-glob-to-any-file: @@ -93,3 +87,8 @@ Ascend NPU: - ggml/include/ggml-cann.h - ggml/src/ggml-cann/** - docs/backend/CANN.md +OpenCL: + - changed-files: + - any-glob-to-any-file: + - ggml/include/ggml-opencl.h + - ggml/src/ggml-opencl/** diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4feccf21e9..788d7a1d10 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -84,7 +84,8 @@ jobs: -DCMAKE_BUILD_RPATH="@loader_path" \ -DLLAMA_FATAL_WARNINGS=ON \ -DGGML_METAL_USE_BF16=ON \ - -DGGML_METAL_EMBED_LIBRARY=ON \ + -DGGML_METAL_EMBED_LIBRARY=OFF \ + -DGGML_METAL_SHADER_DEBUG=ON \ -DGGML_RPC=ON cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) @@ -341,7 +342,7 @@ jobs: cd build export GGML_VK_VISIBLE_DEVICES=0 # This is using llvmpipe and runs slower than other backends - ctest -L main --verbose --timeout 3600 + ctest -L main --verbose --timeout 4200 ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 @@ -664,7 +665,7 @@ jobs: ./build-xcframework.sh windows-msys2: - runs-on: windows-latest + runs-on: windows-2025 strategy: fail-fast: false @@ -714,7 +715,7 @@ jobs: cmake --build build --config ${{ matrix.build }} -j $(nproc) windows-latest-cmake: - runs-on: windows-latest + runs-on: windows-2025 env: OPENBLAS_VERSION: 0.3.23 @@ -725,17 +726,20 @@ jobs: matrix: include: - build: 'cpu-x64 (static)' + arch: 'x64' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF' - build: 'openblas-x64' + arch: 'x64' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' - build: 'vulkan-x64' + arch: 'x64' defines: '-DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON' - build: 'llvm-arm64' + arch: 'arm64' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON' - build: 'llvm-arm64-opencl-adreno' + arch: 'arm64' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' - # - build: 'kompute-x64' - # defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON' steps: - name: Clone @@ -749,12 +753,6 @@ jobs: variant: ccache evict-old-files: 1d - - name: Clone Kompute submodule - id: clone_kompute - if: ${{ matrix.build == 'kompute-x64' }} - run: | - git submodule update --init ggml/src/ggml-kompute/kompute - - name: Download OpenBLAS id: get_openblas if: ${{ matrix.build == 'openblas-x64' }} @@ -770,7 +768,7 @@ jobs: - name: Install Vulkan SDK id: get_vulkan - if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }} + if: ${{ matrix.build == 'vulkan-x64' }} run: | curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe" & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install @@ -805,6 +803,8 @@ jobs: - name: libCURL id: get_libcurl uses: ./.github/actions/windows-setup-curl + with: + architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }} - name: Build id: cmake_build @@ -825,7 +825,7 @@ jobs: - name: Test id: cmake_test - if: ${{ matrix.build != 'llvm-arm64' && matrix.build != 'llvm-arm64-opencl-adreno' }} + if: ${{ matrix.arch == 'x64' }} run: | cd build ctest -L main -C Release --verbose --timeout 900 @@ -930,7 +930,7 @@ jobs: cmake --build build --config Release windows-latest-cmake-sycl: - runs-on: windows-latest + runs-on: windows-2022 defaults: run: @@ -964,7 +964,7 @@ jobs: windows-latest-cmake-hip: if: ${{ github.event.inputs.create_release != 'true' }} - runs-on: windows-latest + runs-on: windows-2022 steps: - name: Clone diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 64fff175e2..4ed6126f48 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -49,7 +49,8 @@ jobs: run: | sysctl -a cmake -B build \ - -DCMAKE_BUILD_RPATH="@loader_path" \ + -DCMAKE_INSTALL_RPATH='@loader_path' \ + -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ -DLLAMA_FATAL_WARNINGS=ON \ -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=ON \ @@ -103,7 +104,8 @@ jobs: # Metal is disabled due to intermittent failures with Github runners not having a GPU: # https://github.com/ggml-org/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 cmake -B build \ - -DCMAKE_BUILD_RPATH="@loader_path" \ + -DCMAKE_INSTALL_RPATH='@loader_path' \ + -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ -DLLAMA_FATAL_WARNINGS=ON \ -DGGML_METAL=OFF \ -DGGML_RPC=ON @@ -160,6 +162,8 @@ jobs: id: cmake_build run: | cmake -B build \ + -DCMAKE_INSTALL_RPATH='$ORIGIN' \ + -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ -DGGML_BACKEND_DL=ON \ -DGGML_NATIVE=OFF \ -DGGML_CPU_ALL_VARIANTS=ON \ @@ -211,6 +215,8 @@ jobs: id: cmake_build run: | cmake -B build \ + -DCMAKE_INSTALL_RPATH='$ORIGIN' \ + -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ -DGGML_BACKEND_DL=ON \ -DGGML_NATIVE=OFF \ -DGGML_CPU_ALL_VARIANTS=ON \ @@ -235,7 +241,7 @@ jobs: name: llama-bin-ubuntu-vulkan-x64.zip windows-cpu: - runs-on: windows-latest + runs-on: windows-2025 strategy: matrix: @@ -271,7 +277,7 @@ jobs: env: CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | - call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch }} + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch == 'x64' && 'x64' || 'amd64_arm64' }} cmake -S . -B build -G "Ninja Multi-Config" ^ -D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^ -DGGML_NATIVE=OFF ^ @@ -288,7 +294,7 @@ jobs: CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\ - Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.42.34433\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\ + Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.44.35112\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\ 7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* - name: Upload artifacts @@ -298,7 +304,7 @@ jobs: name: llama-bin-win-cpu-${{ matrix.arch }}.zip windows: - runs-on: windows-latest + runs-on: windows-2025 env: OPENBLAS_VERSION: 0.3.23 @@ -448,7 +454,7 @@ jobs: name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip windows-sycl: - runs-on: windows-latest + runs-on: windows-2022 defaults: run: @@ -520,7 +526,7 @@ jobs: name: llama-bin-win-sycl-x64.zip windows-hip: - runs-on: windows-latest + runs-on: windows-2022 strategy: matrix: diff --git a/.gitmodules b/.gitmodules index 23ce5ff059..e69de29bb2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "kompute"] - path = ggml/src/ggml-kompute/kompute - url = https://github.com/nomic-ai/kompute.git diff --git a/CMakeLists.txt b/CMakeLists.txt index d2becb04c6..c79ccd09e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -120,7 +120,6 @@ endfunction() llama_option_depr(FATAL_ERROR LLAMA_CUBLAS GGML_CUDA) llama_option_depr(WARNING LLAMA_CUDA GGML_CUDA) -llama_option_depr(WARNING LLAMA_KOMPUTE GGML_KOMPUTE) llama_option_depr(WARNING LLAMA_METAL GGML_METAL) llama_option_depr(WARNING LLAMA_METAL_EMBED_LIBRARY GGML_METAL_EMBED_LIBRARY) llama_option_depr(WARNING LLAMA_NATIVE GGML_NATIVE) diff --git a/build-xcframework.sh b/build-xcframework.sh index a08419a801..f813984db9 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Options IOS_MIN_OS_VERSION=16.4 diff --git a/ci/run.sh b/ci/run.sh index e1b777c304..1146f86b64 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # sample usage: # diff --git a/common/arg.cpp b/common/arg.cpp index c4ad85c47b..56827a6590 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2734,6 +2734,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.public_path = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH")); + add_opt(common_arg( + {"--api-prefix"}, "PREFIX", + string_format("prefix path the server serves from, without the trailing slash (default: %s)", params.api_prefix.c_str()), + [](common_params & params, const std::string & value) { + params.api_prefix = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX")); add_opt(common_arg( {"--no-webui"}, string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"), @@ -2794,6 +2801,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.ssl_file_cert = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE")); + add_opt(common_arg( + {"--chat-template-kwargs"}, "STRING", + string_format("sets additional params for the json template parser"), + [](common_params & params, const std::string & value) { + auto parsed = json::parse(value); + for (const auto & item : parsed.items()) { + params.default_template_kwargs[item.key()] = item.value().dump(); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_CHAT_TEMPLATE_KWARGS")); add_opt(common_arg( {"-to", "--timeout"}, "N", string_format("server read/write timeout in seconds (default: %d)", params.timeout_read), diff --git a/common/chat.cpp b/common/chat.cpp index 7d9aaeb12a..114dbfccdb 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -17,6 +17,8 @@ #include #include +using json = nlohmann::ordered_json; + static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) { auto time = std::chrono::system_clock::to_time_t(now); auto local_time = *std::localtime(&time); @@ -140,6 +142,7 @@ struct templates_params { bool add_generation_prompt = true; bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + json extra_context; }; common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { @@ -720,16 +723,23 @@ static void foreach_function(const json & tools, const std::function & messages_override = std::nullopt, + const std::optional & tools_override = std::nullopt, + const std::optional & additional_context = std::nullopt) { minja::chat_template_inputs tmpl_inputs; - tmpl_inputs.messages = messages; - tmpl_inputs.tools = tools; - tmpl_inputs.add_generation_prompt = add_generation_prompt; - tmpl_inputs.extra_context = extra_context; + tmpl_inputs.messages = messages_override ? *messages_override : inputs.messages; + if (tools_override) { + tmpl_inputs.tools = *tools_override; + } else { + tmpl_inputs.tools = inputs.tools.empty() ? json() : inputs.tools; + } + tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt; + tmpl_inputs.extra_context = inputs.extra_context; + if (additional_context) { + tmpl_inputs.extra_context.merge_patch(*additional_context); + } // TODO: add flag to control date/time, if only for testing purposes. // tmpl_inputs.now = std::chrono::system_clock::now(); @@ -828,7 +838,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp inputs.messages, "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request"); - data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); data.format = COMMON_CHAT_FORMAT_GENERIC; return data; } @@ -904,7 +914,7 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat data.preserved_tokens = { "[TOOL_CALLS]", }; - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; return data; } @@ -934,7 +944,7 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ adjusted_messages.push_back(msg); } } - data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); + data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; if (string_ends_with(data.prompt, "<|START_THINKING|>")) { if (!inputs.enable_thinking) { @@ -1122,7 +1132,7 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te } else { data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; } - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { + data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json { {"date_string", format_time(inputs.now, "%d %b %Y")}, {"tools_in_user_message", false}, {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, @@ -1187,7 +1197,7 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + auto prompt = apply(tmpl, inputs); // Hacks to fix the official (broken) prompt. // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead, @@ -1282,7 +1292,7 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; - data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { + data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ json(), json { {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")}, {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, }); @@ -1338,7 +1348,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code. common_chat_params data; - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; if (inputs.tools.is_array() && !inputs.tools.empty()) { data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1465,7 +1475,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; } - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs); // TODO: if (has_raw_python) return data; } @@ -1498,14 +1508,15 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - json additional_context = { + json extra_context = json { {"enable_thinking", inputs.enable_thinking}, }; + extra_context.update(inputs.extra_context); - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context); + data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, extra_context); data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; if (string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { + if (!extra_context["enable_thinking"]) { data.prompt += ""; } else { data.thinking_forced_open = true; @@ -1691,7 +1702,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.prompt = apply(tmpl, inputs); data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; data.grammar_lazy = false; if (!inputs.json_schema.is_null()) { @@ -1722,6 +1733,12 @@ static common_chat_params common_chat_templates_apply_jinja( params.enable_thinking = inputs.enable_thinking; params.grammar = inputs.grammar; params.now = inputs.now; + + params.extra_context = json::object(); + for (auto el : inputs.chat_template_kwargs) { + params.extra_context[el.first] = json::parse(el.second); + } + if (!inputs.json_schema.empty()) { params.json_schema = json::parse(inputs.json_schema); } diff --git a/common/chat.h b/common/chat.h index 9f59e6b087..ca807c145e 100644 --- a/common/chat.h +++ b/common/chat.h @@ -7,6 +7,7 @@ #include #include #include +#include struct common_chat_templates; @@ -125,6 +126,7 @@ struct common_chat_templates_inputs { common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + std::map chat_template_kwargs; }; struct common_chat_params { diff --git a/common/common.h b/common/common.h index e08a59eae7..a5abe32859 100644 --- a/common/common.h +++ b/common/common.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #ifdef _WIN32 @@ -369,6 +370,7 @@ struct common_params { std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT + std::string api_prefix = ""; // NOLINT std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; @@ -381,6 +383,8 @@ struct common_params { std::string ssl_file_key = ""; // NOLINT std::string ssl_file_cert = ""; // NOLINT + std::map default_template_kwargs; + // "advanced" endpoints are disabled by default for better security bool webui = true; bool endpoint_slots = false; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index aed595e259..3f3dfb416c 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -815,6 +815,9 @@ class TextModel(ModelBase): if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 res = "minerva-7b" + if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664": + # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct + res = "hunyuan" if res is None: logger.warning("\n") @@ -2743,6 +2746,52 @@ class Qwen2Model(TextModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Ernie4_5_ForCausalLM") +class Ernie4_5Model(TextModel): + model_arch = gguf.MODEL_ARCH.ERNIE4_5 + + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + num_heads = self.hparams["num_attention_heads"] + num_kv_heads = self.hparams["num_key_value_heads"] + head_dim = self.hparams["head_dim"] + + if "ernie." in name: + name = name.replace("ernie.", "model.") + # split the qkv weights + # qkv_proj shape: [(num_heads + 2 * num_kv_heads) * head_dim, hidden_size] + if "qkv_proj" in name: + name_q = name.replace("qkv_proj.weight", "q_proj.weight") + name_k = name.replace("qkv_proj.weight", "k_proj.weight") + name_v = name.replace("qkv_proj.weight", "v_proj.weight") + total_q_dim = num_heads * head_dim + total_k_dim = num_kv_heads * head_dim + total_v_dim = num_kv_heads * head_dim + q_proj_weight, k_proj_weight, v_proj_weight = data_torch.split([total_q_dim, total_k_dim, total_v_dim], dim=0) + return [ + (self.map_tensor_name(name_q), q_proj_weight), + (self.map_tensor_name(name_k), k_proj_weight), + (self.map_tensor_name(name_v), v_proj_weight) + ] + # split the up_gate_proj into gate and up + # up_gate_proj shape: [2 * intermediate_size, hidden_size] + if "up_gate_proj" in name: + name_up = name.replace("up_gate_proj.weight", "up_proj.weight") + name_gate = name.replace("up_gate_proj.weight", "gate_proj.weight") + dim_half = data_torch.shape[0] // 2 + gate_proj_weight, up_proj_weight = data_torch.split(dim_half, dim=0) + return [ + (self.map_tensor_name(name_gate), gate_proj_weight), + (self.map_tensor_name(name_up), up_proj_weight) + ] + return [(self.map_tensor_name(name), data_torch)] + + @ModelBase.register( "Qwen2VLModel", "Qwen2VLForConditionalGeneration", @@ -4362,9 +4411,6 @@ class Gemma3NModel(Gemma3Model): ] def set_vocab(self): - with open(self.dir_model / "chat_template.jinja") as f: - # quick hack to make sure chat template is added - self.gguf_writer.add_chat_template(f.read()) super().set_vocab() def set_gguf_parameters(self): @@ -4735,6 +4781,14 @@ class ARwkv7Model(Rwkv7Model): class MambaModel(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + def set_vocab(self): vocab_size = self.hparams["vocab_size"] # Round vocab size to next multiple of 8 @@ -4809,6 +4863,100 @@ class MambaModel(TextModel): return [(new_name, data_torch)] +@ModelBase.register("Mamba2ForCausalLM") +class Mamba2Model(TextModel): + model_arch = gguf.MODEL_ARCH.MAMBA2 + + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1 + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 16 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + elif (self.dir_model / "tokenizer.model.v3").is_file(): + # mamba-codestral + raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + elif (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + self._set_vocab_builtin("gpt-neox", vocab_size) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 + head_dim = self.find_hparam(["head_dim"], optional=True) or 64 + n_group = self.find_hparam(["n_groups"], optional=True) or 1 + + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + # Fail early for models which don't have a block expansion factor of 2 + # TODO: does this really matter? + assert d_inner == 2 * d_model + assert d_inner % head_dim == 0 + + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim) + self.gguf_writer.add_ssm_group_count(n_group) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2 + name = name.removeprefix("model.") + + if name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + + new_name = self.map_tensor_name(name) + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [ + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ]): + # unsqueeze A to use similar shape semantics as Mamba-1 + # (D is also unsqueezed, but for more straightforward broadcast internally) + data_torch = data_torch.reshape((*data_torch.shape, 1)) + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + n_group = self.hparams.get("n_groups", 1) + data_torch = data_torch.reshape((n_group, d_inner // n_group)) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + @ModelBase.register("CohereForCausalLM") class CommandR2Model(TextModel): model_arch = gguf.MODEL_ARCH.COMMAND_R @@ -6390,6 +6538,160 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel): super().set_gguf_parameters() self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"]) + +@ModelBase.register("HunYuanMoEV1ForCausalLM") +class HunYuanMoEModel(TextModel): + model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # For handling tied embeddings + self._tok_embd = None + + def set_vocab(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + + # 1. Get the pre-tokenizer identifier hash + tokpre = self.get_vocab_base_pre(tokenizer) + + # 2. Reverse-engineer the merges list from mergeable_ranks + merges = [] + vocab = {} + mergeable_ranks = tokenizer.mergeable_ranks + for token, rank in mergeable_ranks.items(): + vocab[QwenModel.token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) + if len(merged) == 2: # todo this is an assert in Qwen, why? + merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + + # 3. Generate the tokens and toktypes lists + vocab_size = self.hparams["vocab_size"] + assert tokenizer.vocab_size == vocab_size + special_tokens = tokenizer.special_tokens + reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()} + tokens: list[str] = [] + toktypes: list[int] = [] + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token = reverse_vocab[i] + tokens.append(token) + if i in special_tokens.values(): + toktypes.append(gguf.TokenType.CONTROL) + else: + toktypes.append(gguf.TokenType.NORMAL) + + # 4. Write all vocab-related fields to the GGUF writer + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_token_merges(merges) + + # 5. Add special tokens and chat templates + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.add_to_gguf(self.gguf_writer) + # FIX for BOS token: Overwrite incorrect id read from config.json + self.gguf_writer.add_bos_token_id(127959) # <|bos|> + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) + + moe_intermediate_size = hparams["moe_intermediate_size"] + assert all(n == moe_intermediate_size[0] for n in moe_intermediate_size) + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0]) + + moe_topk = hparams["moe_topk"] + assert all(topk == moe_topk[0] for topk in moe_topk) + self.gguf_writer.add_expert_used_count(moe_topk[0]) + + moe_shared_expert = hparams["num_shared_expert"] + assert all(n == moe_shared_expert[0] for n in moe_shared_expert) + self.gguf_writer.add_expert_shared_count(moe_shared_expert[0]) + + # Rope + rope_scaling = hparams.get("rope_scaling", {}) + if rope_scaling.get("type") == "dynamic": + # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf) + alpha = rope_scaling.get("alpha", 1000) + base = hparams.get("rope_theta", 10000.0) + dim = (hparams["hidden_size"] // hparams["num_attention_heads"]) # 128 + scaled_base = base * (alpha ** (dim / (dim - 2))) # 10000 * (1000 ** (128 / 126)) = 11158839.9251 + self.gguf_writer.add_rope_freq_base(scaled_base) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_rope_scaling_factor(1) + # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k + self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length + self.gguf_writer.add_context_length(256 * 1024) # 256k context length + + # if any of our assumptions about the values are wrong, something has changed and this may need to be updated + assert alpha == 1000 and base == 10000.0 and dim == 128 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024] , \ + "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually" + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "model.embed_tokens.weight": + self._tok_embd = data_torch.clone() + + if name == "lm_head.weight": + if self.hparams.get("tie_word_embeddings", False): + logger.info("Skipping tied output layer 'lm_head.weight'") + return [] + + if name.find("mlp.experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + tensors: list[tuple[str, Tensor]] = [] + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + if self._experts is not None: + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + +@ModelBase.register("SmolLM3ForCausalLM") +class SmolLM3Model(LlamaModel): + model_arch = gguf.MODEL_ARCH.SMOLLM3 + ###### CONVERSION LOGIC ###### @@ -6569,12 +6871,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st # maybe we should fallback to text model's arch in that case, since not many models have both text_config = hparams.get("text_config", {}) vision_config = hparams.get("vision_config", {}) - arch = hparams["architectures"][0] + arch = None + if (arches := hparams.get("architectures")) is not None and len(arches) > 0: + arch = arches[0] + elif "ssm_cfg" in hparams: + # For non-hf Mamba and Mamba2 models + arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM" + # if "architectures" is found in the sub-config, use that instead if model_type == ModelType.TEXT and text_config.get("architectures") is not None: arch = text_config["architectures"][0] elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None: arch = vision_config["architectures"][0] + if arch is None: + raise ValueError("Failed to detect model architecture") return arch diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 2f733f0973..96a2b692a8 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -137,6 +137,7 @@ pre_computed_hashes = [ {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, + {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, ] diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 7f71e0247d..51e0b0b20f 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv ### 2. Define the model architecture in `llama.cpp` -The model params and tensors layout must be defined in `llama.cpp`: -1. Define a new `llm_arch` -2. Define the tensors layout in `LLM_TENSOR_NAMES` -3. Add any non-standard metadata in `llm_load_hparams` -4. Create the tensors for inference in `llm_load_tensors` -5. If the model has a RoPE operation, add the rope type in `llama_rope_type` +The model params and tensors layout must be defined in `llama.cpp` source files: +1. Define a new `llm_arch` enum value in `src/llama-arch.h`. +2. In `src/llama-arch.cpp`: + - Add the architecture name to the `LLM_ARCH_NAMES` map. + - Add the tensor mappings to the `LLM_TENSOR_NAMES` map. +3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`. +4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`. NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions. ### 3. Build the GGML graph implementation -This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`. - -Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`. +This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`. +Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor. +Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`. +Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct. Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR. diff --git a/docs/docker.md b/docs/docker.md index f8f0573c17..cbb333ee32 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -25,6 +25,9 @@ Additionally, there the following images, similar to the above: - `ghcr.io/ggml-org/llama.cpp:full-intel`: Same as `full` but compiled with SYCL support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:light-intel`: Same as `light` but compiled with SYCL support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:server-intel`: Same as `server` but compiled with SYCL support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:full-vulkan`: Same as `full` but compiled with Vulkan support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:light-vulkan`: Same as `light` but compiled with Vulkan support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:server-vulkan`: Same as `server` but compiled with Vulkan support. (platforms: `linux/amd64`) The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now). diff --git a/examples/Miku.sh b/examples/Miku.sh index 0f6c8c8787..9492bfedc0 100755 --- a/examples/Miku.sh +++ b/examples/Miku.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e AI_NAME="${AI_NAME:-Miku}" diff --git a/examples/chat-13B.sh b/examples/chat-13B.sh index 1828903c31..f025a47cbf 100755 --- a/examples/chat-13B.sh +++ b/examples/chat-13B.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e diff --git a/examples/chat-persistent.sh b/examples/chat-persistent.sh index 9d761ebb84..d6b6cb9518 100755 --- a/examples/chat-persistent.sh +++ b/examples/chat-persistent.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -euo pipefail diff --git a/examples/chat-vicuna.sh b/examples/chat-vicuna.sh index ffdd200849..c930962fd3 100755 --- a/examples/chat-vicuna.sh +++ b/examples/chat-vicuna.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e diff --git a/examples/chat.sh b/examples/chat.sh index 9f85d1e265..5fec46d17b 100755 --- a/examples/chat.sh +++ b/examples/chat.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Temporary script - will be removed in the future diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index fb188f5a9e..4afd80eb45 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -55,6 +55,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); } else if (type == GGML_TYPE_F32) { v = *(float *) &data[i]; + } else if (type == GGML_TYPE_I64) { + v = (float) *(int64_t *) &data[i]; } else if (type == GGML_TYPE_I32) { v = (float) *(int32_t *) &data[i]; } else if (type == GGML_TYPE_I16) { @@ -134,6 +136,11 @@ static bool run(llama_context * ctx, const common_params & params) { std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); + if (tokens.empty()) { + LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__); + return false; + } + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/examples/jeopardy/jeopardy.sh b/examples/jeopardy/jeopardy.sh index 07bcb3b8d7..800df2c6ae 100755 --- a/examples/jeopardy/jeopardy.sh +++ b/examples/jeopardy/jeopardy.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e MODEL=./models/ggml-vicuna-13b-1.1-q4_0.bin diff --git a/examples/reason-act.sh b/examples/reason-act.sh index 06d592799c..3c801920d0 100755 --- a/examples/reason-act.sh +++ b/examples/reason-act.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash cd `dirname $0` cd .. diff --git a/examples/server-llama2-13B.sh b/examples/server-llama2-13B.sh index 4ce79b7fac..fd5a575886 100755 --- a/examples/server-llama2-13B.sh +++ b/examples/server-llama2-13B.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index cf1178043d..57195df331 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -113,15 +113,16 @@ int main(int argc, char ** argv) { while (true) { // check if we have enough space in the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); - int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0); + int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) + 1; if (n_ctx_used + batch.n_tokens > n_ctx) { printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); exit(0); } - if (llama_decode(ctx, batch)) { - GGML_ABORT("failed to decode\n"); + int ret = llama_decode(ctx, batch); + if (ret != 0) { + GGML_ABORT("failed to decode, ret = %d\n", ret); } // sample the next token diff --git a/examples/sycl/build.sh b/examples/sycl/build.sh index e72b2e2612..1993520ebd 100755 --- a/examples/sycl/build.sh +++ b/examples/sycl/build.sh @@ -1,4 +1,4 @@ - +#!/usr/bin/env bash # MIT license # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: MIT diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index 40ce8f5b2b..37195008de 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # MIT license # Copyright (C) 2024 Intel Corporation diff --git a/examples/sycl/run-llama3.sh b/examples/sycl/run-llama3.sh index 933d1b98bc..8e21b017f4 100755 --- a/examples/sycl/run-llama3.sh +++ b/examples/sycl/run-llama3.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # MIT license # Copyright (C) 2025 Intel Corporation diff --git a/examples/ts-type-to-grammar.sh b/examples/ts-type-to-grammar.sh index 9abba2a3da..9660504078 100755 --- a/examples/ts-type-to-grammar.sh +++ b/examples/ts-type-to-grammar.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # ./examples/ts-type-to-grammar.sh "{a:string,b:string,c?:string}" # python examples/json_schema_to_grammar.py https://json.schemastore.org/tsconfig.json diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 215eb23486..eaba9c7046 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -181,7 +181,6 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF) option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF) option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) -option(GGML_KOMPUTE "ggml: use Kompute" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) @@ -266,7 +265,6 @@ set(GGML_PUBLIC_HEADERS include/ggml-cann.h include/ggml-cpp.h include/ggml-cuda.h - include/ggml-kompute.h include/ggml-opt.h include/ggml-metal.h include/ggml-rpc.h @@ -360,6 +358,13 @@ write_basic_package_version_file( VERSION ${GGML_INSTALL_VERSION} COMPATIBILITY SameMajorVersion) +target_compile_definitions(ggml-base PRIVATE + GGML_VERSION="${GGML_INSTALL_VERSION}" + GGML_COMMIT="${GGML_BUILD_COMMIT}" +) +message(STATUS "ggml version: ${GGML_INSTALL_VERSION}") +message(STATUS "ggml commit: ${GGML_BUILD_COMMIT}") + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 778927f682..a2977ea2e5 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -339,7 +339,7 @@ extern "C" { typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); // Compare the output of two backends - GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data); + GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); // Tensor initialization GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index e3b79d09bb..be40b10097 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -134,6 +134,7 @@ extern "C" { GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); + GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t); diff --git a/ggml/include/ggml-kompute.h b/ggml/include/ggml-kompute.h deleted file mode 100644 index 154aa56a74..0000000000 --- a/ggml/include/ggml-kompute.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include "ggml.h" -#include "ggml-backend.h" - -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -#define GGML_KOMPUTE_MAX_DEVICES 16 - -struct ggml_vk_device { - int index; - int type; // same as VkPhysicalDeviceType - size_t heapSize; - const char * name; - const char * vendor; - int subgroupSize; - uint64_t bufferAlignment; - uint64_t maxAlloc; -}; - -struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count); -bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name); -bool ggml_vk_has_vulkan(void); -bool ggml_vk_has_device(void); -struct ggml_vk_device ggml_vk_current_device(void); - -// -// backend API -// - -// forward declaration -typedef struct ggml_backend * ggml_backend_t; - -GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device); - -GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend); - -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device); - -GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void); - -#ifdef __cplusplus -} -#endif diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 236ac52eb3..8a8775be36 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -314,6 +314,13 @@ extern "C" { #endif + // Function type used in fatal error callbacks + typedef void (*ggml_abort_callback_t)(const char * error_message); + + // Set the abort callback (passing null will restore original abort functionality: printing a message to stdout) + // Returns the old callback for chaining + GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback); + GGML_NORETURN GGML_ATTRIBUTE_FORMAT(3, 4) GGML_API void ggml_abort(const char * file, int line, const char * fmt, ...); @@ -470,6 +477,7 @@ extern "C" { GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_GET_ROWS_BACK, + GGML_OP_SET_ROWS, GGML_OP_DIAG, GGML_OP_DIAG_MASK_INF, GGML_OP_DIAG_MASK_ZERO, @@ -481,12 +489,13 @@ extern "C" { GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, + GGML_OP_CONV_2D, GGML_OP_CONV_2D_DW, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, GGML_OP_POOL_2D_BACK, - GGML_OP_UPSCALE, // nearest interpolate + GGML_OP_UPSCALE, GGML_OP_PAD, GGML_OP_PAD_REFLECT_1D, GGML_OP_ROLL, @@ -519,6 +528,8 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS_BACK, GGML_OP_OPT_STEP_ADAMW, + GGML_OP_GLU, + GGML_OP_COUNT, }; @@ -542,6 +553,16 @@ extern "C" { GGML_UNARY_OP_COUNT, }; + enum ggml_glu_op { + GGML_GLU_OP_REGLU, + GGML_GLU_OP_GEGLU, + GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_GEGLU_ERF, + GGML_GLU_OP_GEGLU_QUICK, + + GGML_GLU_OP_COUNT, + }; + enum ggml_object_type { GGML_OBJECT_TYPE_TENSOR, GGML_OBJECT_TYPE_GRAPH, @@ -627,6 +648,9 @@ extern "C" { // misc + GGML_API const char * ggml_version(void); + GGML_API const char * ggml_commit(void); + GGML_API void ggml_time_init(void); // call this once at the beginning of the program GGML_API int64_t ggml_time_ms(void); GGML_API int64_t ggml_time_us(void); @@ -657,6 +681,7 @@ extern "C" { GGML_API const char * ggml_op_symbol(enum ggml_op op); GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op); + GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op); GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); @@ -687,6 +712,9 @@ extern "C" { // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor); + // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements + GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor); + GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); @@ -758,6 +786,7 @@ extern "C" { GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); + GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor); GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); @@ -1086,6 +1115,89 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // gated linear unit ops + // A: n columns, r rows, + // result is n / 2 columns, r rows, + // expects gate in second half of row, unless swapped is true + GGML_API struct ggml_tensor * ggml_glu( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_glu_op op, + bool swapped); + + GGML_API struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_reglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_swiglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_erf_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_geglu_quick_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // A: n columns, r rows, + // B: n columns, r rows, + GGML_API struct ggml_tensor * ggml_glu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op); + + GGML_API struct ggml_tensor * ggml_reglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_swiglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_erf_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_geglu_quick_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // normalize along rows GGML_API struct ggml_tensor * ggml_norm( struct ggml_context * ctx, @@ -1388,6 +1500,23 @@ extern "C" { struct ggml_tensor * b, // row indices struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape + // a TD [n_embd, ne1, ne2, ne3] + // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3 + // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1) + // + // undefined behavior if destination rows overlap + // + // broadcast: + // ne2 % ne11 == 0 + // ne3 % ne12 == 0 + // + // return view(a) + GGML_API struct ggml_tensor * ggml_set_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, // destination + struct ggml_tensor * b, // source + struct ggml_tensor * c); // row indices + GGML_API struct ggml_tensor * ggml_diag( struct ggml_context * ctx, struct ggml_tensor * a); @@ -1425,8 +1554,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // a [ne0, ne01, ne02, ne03] + // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional + // + // broadcast: + // ne02 % ne12 == 0 + // ne03 % ne13 == 0 + // // fused soft_max(a*scale + mask*(ALiBi slope)) - // mask is optional // max_bias = 0.0f for no ALiBi GGML_API struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, @@ -1736,6 +1871,17 @@ extern "C" { struct ggml_tensor * b, int stride); + GGML_API struct ggml_tensor * ggml_conv_2d_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] + struct ggml_tensor * b, // input data [W, H, C, N] + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 + enum ggml_op_pool { GGML_OP_POOL_MAX, GGML_OP_POOL_AVG, @@ -1778,6 +1924,12 @@ extern "C" { enum ggml_scale_mode { GGML_SCALE_MODE_NEAREST = 0, GGML_SCALE_MODE_BILINEAR = 1, + + GGML_SCALE_MODE_COUNT + }; + + enum ggml_scale_flag { + GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8) }; // interpolate @@ -1790,14 +1942,26 @@ extern "C" { // interpolate // interpolate scale to specified dimensions - GGML_API struct ggml_tensor * ggml_upscale_ext( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_upscale_ext( struct ggml_context * ctx, struct ggml_tensor * a, int ne0, int ne1, int ne2, int ne3, - enum ggml_scale_mode mode); + enum ggml_scale_mode mode), + "use ggml_interpolate instead"); + + // Up- or downsamples the input to the specified size. + // 2D scale modes (eg. bilinear) are applied to the first two dimensions. + GGML_API struct ggml_tensor * ggml_interpolate( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + uint32_t mode); // ggml_scale_mode [ | ggml_scale_flag...] // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0] GGML_API struct ggml_tensor * ggml_pad( @@ -1860,11 +2024,17 @@ extern "C" { #define GGML_KQ_MASK_PAD 64 - // q: [n_embd_k, n_batch, n_head, 1] - // k: [n_embd_k, n_kv, n_head_kv, 1] - // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! - // res: [n_embd_v, n_head, n_batch, 1] !! permuted !! + // q: [n_embd_k, n_batch, n_head, ne3 ] + // k: [n_embd_k, n_kv, n_head_kv, ne3 ] + // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !! + // mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !! + // + // broadcast: + // n_head % n_head_kv == 0 + // n_head % ne32 == 0 + // ne3 % ne33 == 0 + // GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, @@ -1903,7 +2073,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 9cb2c228dc..8760c2d35e 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -365,7 +365,6 @@ ggml_add_backend(BLAS) ggml_add_backend(CANN) ggml_add_backend(CUDA) ggml_add_backend(HIP) -ggml_add_backend(Kompute) ggml_add_backend(METAL) ggml_add_backend(MUSA) ggml_add_backend(RPC) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 2d93771fd1..042ea77aca 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -61,10 +61,6 @@ #include "ggml-cann.h" #endif -#ifdef GGML_USE_KOMPUTE -#include "ggml-kompute.h" -#endif - // disable C++17 deprecation warning for std::codecvt_utf8 #if defined(__clang__) # pragma clang diagnostic push @@ -189,9 +185,6 @@ struct ggml_backend_registry { #ifdef GGML_USE_RPC register_backend(ggml_backend_rpc_reg()); #endif -#ifdef GGML_USE_KOMPUTE - register_backend(ggml_backend_kompute_reg()); -#endif #ifdef GGML_USE_CPU register_backend(ggml_backend_cpu_reg()); #endif @@ -575,7 +568,6 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("cann", silent, dir_path); ggml_backend_load_best("cuda", silent, dir_path); ggml_backend_load_best("hip", silent, dir_path); - ggml_backend_load_best("kompute", silent, dir_path); ggml_backend_load_best("metal", silent, dir_path); ggml_backend_load_best("rpc", silent, dir_path); ggml_backend_load_best("sycl", silent, dir_path); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index b1050ad59c..788861a365 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -817,8 +817,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str } if (sched->debug > 1) { ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); - GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name, - fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node)); + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name, + fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), + graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]); for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { @@ -1826,7 +1827,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) { ggml_free(copy.ctx_unallocated); } -bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) { +bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) { struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); if (copy.buffer == NULL) { return false; @@ -1837,28 +1838,45 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t assert(g1->n_nodes == g2->n_nodes); - for (int i = 0; i < g1->n_nodes; i++) { - struct ggml_tensor * t1 = g1->nodes[i]; - struct ggml_tensor * t2 = g2->nodes[i]; + if (test_node != nullptr) { + // Compute the whole graph and only test the output for a specific tensor + ggml_backend_graph_compute(backend1, g1); + ggml_backend_graph_compute(backend2, g2); - assert(t1->op == t2->op && ggml_are_same_layout(t1, t2)); - - struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); - struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); - - ggml_backend_graph_compute(backend1, &g1v); - ggml_backend_graph_compute(backend2, &g2v); - - if (ggml_is_view_op(t1->op)) { - continue; + int test_node_idx = -1; + for (int i = 0; i < g1->n_nodes; i++) { + struct ggml_tensor * t1 = g1->nodes[i]; + if (t1 == test_node) { + test_node_idx = i; + break; + } } + GGML_ASSERT(test_node_idx != -1); - // compare results, calculate rms etc - if (!callback(i, t1, t2, user_data)) { - break; + callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data); + } else { + for (int i = 0; i < g1->n_nodes; i++) { + struct ggml_tensor * t1 = g1->nodes[i]; + struct ggml_tensor * t2 = g2->nodes[i]; + + assert(t1->op == t2->op && ggml_are_same_layout(t1, t2)); + + struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1); + struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1); + + ggml_backend_graph_compute(backend1, &g1v); + ggml_backend_graph_compute(backend2, &g2v); + + if (ggml_is_view_op(t1->op)) { + continue; + } + + // compare results, calculate rms etc + if (!callback(i, t1, t2, user_data)) { + break; + } } } - ggml_backend_graph_copy_free(copy); return true; diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 437ece2d4a..4d5c2c1825 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -65,8 +65,9 @@ #include #include #include -#include +#include #include +#include #include #include @@ -804,10 +805,11 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer, nb[i] = nb[i - 1] * ne[i - 1]; } - ggml_cann_async_memset(ctx, buffer, n_bytes, 0); aclTensor* zero = ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, zero); return zero; + GGML_UNUSED(n_bytes); } /** @@ -2654,6 +2656,67 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb)); } +#ifdef ASCEND_310P + ggml_tensor src0_row = *src0; + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + + if (src0->type == GGML_TYPE_F16) { + src0_row.type = GGML_TYPE_F32; + } + + // src0_row [D, M, 1, 1] weight without permute + src0_row.ne[2] = 1; + src0_row.ne[3] = 1; + src0_row.nb[0] = ori_src0_nb[0]; + src0_row.nb[1] = ori_src0_nb[1]; + src0_row.nb[2] = ori_src0_nb[1]; + src0_row.nb[3] = ori_src0_nb[1]; + + // src1_row [D, 1, 1, 1] -> input + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; + + // dst_row [M, 1, 1, 1] -> out + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + + //create weight for one row + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + // expert index + int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + // If B = 1 (broadcast), always use 0; otherwise, use id. + int64_t i11 = (ne11 == 1 ? 0 : id); + int64_t i12 = iid1; + + int64_t i1 = id; + int64_t i2 = i12; + + void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2]; + void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; + void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; + + src0_row.data = src0_tmp_ptr; + src1_row.data = src1_tmp_ptr; + dst_row.data = dst_tmp_ptr; + dst_row.src[0] = &src0_row; + dst_row.src[1] = &src1_row; + + ggml_cann_mul_mat(ctx, &dst_row); + } + } + return; +#endif + std::vector src0_tensor_vec; std::vector src1_tensor_vec; std::vector dst_tensor_vec; @@ -2701,9 +2764,9 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* } size_t GROUP_SIZE = 128; - // GroupedMatmulV2 required tensor_list.size < 128 + // GroupedMatmulV3 required tensor_list.size < 128 for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) { - // split and call GroupedMatmulV2 + // split and call GroupedMatmulV3 size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size()); std::vector src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end); std::vector src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end); @@ -2713,7 +2776,7 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size()); aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size()); - GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list, + GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV3, src1_tensor_list, src0_tensor_list, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list); ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list); diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index ba2cef0c25..8dfe3b061c 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -359,7 +359,7 @@ struct ggml_backend_cann_context { ggml_cann_set_device(device); description = aclrtGetSocName(); - bool async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or("")); + async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or("")); GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF"); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index d1a0ad374d..eae575cc04 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return false; } } break; + case GGML_OP_SET_ROWS: + { + // TODO: add support + // ref: https://github.com/ggml-org/llama.cpp/pull/14274 + return false; + } break; case GGML_OP_CPY: { ggml_tensor *src = op->src[0]; if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) || @@ -2187,7 +2193,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_SQRT: case GGML_OP_CLAMP: case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: @@ -2205,6 +2210,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_SOFT_MAX: + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 + return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); case GGML_OP_FLASH_ATTN_EXT:{ // derived from [ggml-cuda.cu] if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ @@ -2227,6 +2236,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // DeepSeek MLA return false; } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 671fad4d22..66a5ad8d2e 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -5,7 +5,7 @@ function(ggml_add_cpu_backend_features cpu_name arch) # build, using set_source_files_properties() to set the arch flags is not possible set(GGML_CPU_FEATS_NAME ${cpu_name}-feats) add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/arch/${arch}/cpu-feats.cpp) - target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. ../include) + target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . ../include) target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN}) target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED) set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -589,4 +589,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (EMSCRIPTEN) set_target_properties(${GGML_CPU_NAME} PROPERTIES COMPILE_FLAGS "-msimd128") endif() + + if (CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM") + # The compiler automatically enables "-ffast-math" which can cause NaNs in tests due to "-fassociative-math" + target_compile_options(${GGML_CPU_NAME} PRIVATE "-fno-associative-math") + endif() endfunction() diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7cae96f4b4..c5271b7757 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -195,6 +195,7 @@ typedef pthread_t ggml_thread_t; static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = { + .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp32, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, .vec_dot_type = GGML_TYPE_F32, .nrows = 1, @@ -1192,7 +1193,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( } } -static void ggml_compute_forward_mul_mat( +void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -1817,6 +1818,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_get_rows_back(params, tensor); } break; + case GGML_OP_SET_ROWS: + { + ggml_compute_forward_set_rows(params, tensor); + } break; case GGML_OP_DIAG: { ggml_compute_forward_diag(params, tensor); @@ -1861,6 +1866,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_back_f32(params, tensor); } break; + case GGML_OP_CONV_2D: + { + ggml_compute_forward_conv_2d(params, tensor); + } break; case GGML_OP_CONV_2D_DW: { ggml_compute_forward_conv_2d_dw(params, tensor); @@ -1944,6 +1953,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_unary(params, tensor); } break; + case GGML_OP_GLU: + { + ggml_compute_forward_glu(params, tensor); + } break; case GGML_OP_GET_REL_POS: { ggml_compute_forward_get_rel_pos(params, tensor); @@ -2154,6 +2167,20 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { GGML_ABORT("fatal error"); } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + { + n_tasks = n_threads; + } break; + default: + GGML_ABORT("fatal error"); + } + break; case GGML_OP_SILU_BACK: case GGML_OP_MUL: case GGML_OP_DIV: @@ -2170,6 +2197,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; case GGML_OP_GET_ROWS: + case GGML_OP_SET_ROWS: { // FIXME: get_rows can use additional threads, but the cost of launching additional threads // decreases performance with GPU offloading @@ -2206,6 +2234,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: @@ -2724,6 +2753,10 @@ struct ggml_cplan ggml_graph_plan( GGML_ABORT("fatal error"); } } break; + case GGML_OP_CONV_2D: + { + cur = GGML_IM2COL_WORK_SIZE; + } break; case GGML_OP_CONV_TRANSPOSE_2D: { const int64_t ne00 = node->src[0]->ne[0]; // W @@ -3124,6 +3157,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g return ggml_graph_compute(cgraph, &cplan); } +void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) { + memcpy(y, x, n * sizeof(float)); +} + void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) { int64_t i = 0; #if defined(__F16C__) diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index a98866a2d8..c9daa4c39e 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -416,6 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st switch (op->op) { case GGML_OP_CPY: + case GGML_OP_SET_ROWS: return op->type != GGML_TYPE_IQ3_XXS && op->type != GGML_TYPE_IQ3_S && diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index bc61080797..a427de404e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3,6 +3,7 @@ #include "ggml-cpu.h" #include "ggml-impl.h" #include "binary-ops.h" +#include "ggml.h" #include "unary-ops.h" #include "vec.h" @@ -696,24 +697,8 @@ static void ggml_compute_forward_dup_f32( if (ggml_is_contiguous(dst)) { // TODO: simplify if (nb00 == sizeof(float)) { - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (ggml_get_type_traits_cpu(dst->type)->from_float) { - ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; + if (ggml_get_type_traits_cpu(dst->type)->from_float) { + ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float; size_t id = 0; size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); @@ -724,7 +709,7 @@ static void ggml_compute_forward_dup_f32( id += rs * ir0; for (int i01 = ir0; i01 < ir1; i01++) { const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - quantize_row_q(src0_ptr, dst_ptr + id, ne00); + from_float(src0_ptr, dst_ptr + id, ne00); id += rs; } id += rs * (ne01 - ir1); @@ -2300,6 +2285,12 @@ void ggml_compute_forward_repeat( { ggml_compute_forward_repeat_f32(params, dst); } break; + // TODO: templateify the implemenation and support for I64 + // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225 + //case GGML_TYPE_I64: + // { + // ggml_compute_forward_repeat_i64(params, dst); + // } break; default: { GGML_ABORT("fatal error"); @@ -3194,6 +3185,721 @@ void ggml_compute_forward_silu_back( } } +// ggml_compute_forward_reglu + +static void ggml_compute_forward_reglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_reglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_reglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_reglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_reglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_geglu + +static void ggml_compute_forward_geglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_swiglu + +static void ggml_compute_forward_swiglu_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_swiglu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_swiglu_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_swiglu_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_geglu_erf + +static void ggml_compute_forward_geglu_erf_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_erf_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_erf( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_erf_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_erf_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_geglu_quick + +static void ggml_compute_forward_geglu_quick_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * src0_p = (float *) (src0_d + i1*src0_o); + float * src1_p = (float *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_quick_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + char * src0_d = (char *) src0->data; + char * src1_d = (char *) (src1 ? src1->data : src0->data); + const size_t src0_o = src0->nb[1]; + const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + const int nr = ggml_nrows(src0); + + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == nr); + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o); + ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o); + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_geglu_quick( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_geglu_quick_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_geglu_quick_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_norm static void ggml_compute_forward_norm_f32( @@ -4475,6 +5181,74 @@ void ggml_compute_forward_get_rows( //} } +static void ggml_compute_forward_set_rows_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ne01; + + assert(ne0 == nc); + assert(ne2 == ne02); + assert(ne3 == ne03); + assert(src0->type == GGML_TYPE_F32); + assert(ne02 % ne11 == 0); + assert(ne03 % ne12 == 0); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = std::min(ir0 + dr, nr); + + ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float; + + for (int64_t i03 = 0; i03 < ne03; ++i03) { + for (int64_t i02 = 0; i02 < ne02; ++i02) { + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i03%ne12; + const int64_t i11 = i02%ne11; + const int64_t i10 = i; + + const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + GGML_ASSERT(i1 >= 0 && i1 < ne1); + + from_float( + (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03), + ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc); + } + } + } +} + +void ggml_compute_forward_set_rows( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_set_rows_f32(params, dst); + } break; + default: + { + GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type)); + } + } +} + // ggml_compute_forward_get_rows_back static void ggml_compute_forward_get_rows_back_f32_f16( @@ -4749,14 +5523,17 @@ static void ggml_compute_forward_soft_max_f32( memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - // TODO: handle transposed/permuted matrices - const int ith = params->ith; const int nth = params->nth; GGML_TENSOR_UNARY_OP_LOCALS - //const int64_t ne11 = src1 ? src1->ne[1] : 1; + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 @@ -4766,68 +5543,66 @@ static void ggml_compute_forward_soft_max_f32( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - for (int i1 = ir0; i1 < ir1; i1++) { - // ALiBi - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const int64_t i11 = i01; + const int64_t i12 = i02%ne12; + const int64_t i13 = i03%ne13; - float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); + // ALiBi + const uint32_t h = i02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]); + // broadcast the mask across rows + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + + ggml_vec_cpy_f32 (ne00, wp, sp); + ggml_vec_scale_f32(ne00, wp, scale); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*mp_f32[i]; + } + } } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; + +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(ne00, &max, wp); + + ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max); + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(ne00, dp, sum); + +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif } } - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } -#endif - - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); - - ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(nc, dp, sum); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dp[i])); - assert(!isinf(dp[i])); - } -#endif } } @@ -6063,6 +6838,186 @@ void ggml_compute_forward_im2col_back_f32( } } +static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k, + void * a, void * b, float * c) { + const ggml_type_traits * traits = ggml_get_type_traits(type); + struct ggml_tensor src1 = {}; + src1.type = type; + src1.ne[0] = k; + src1.ne[1] = m; + src1.ne[2] = 1; + src1.ne[3] = 1; + src1.nb[0] = traits->type_size; + src1.nb[1] = k * traits->type_size; + src1.nb[2] = src1.nb[1]; + src1.nb[3] = src1.nb[2]; + src1.data = a; + + struct ggml_tensor src0 = {}; + src0.type = type; + src0.ne[0] = k; + src0.ne[1] = n; + src0.ne[2] = 1; + src0.ne[3] = 1; + src0.nb[0] = traits->type_size; + src0.nb[1] = k * traits->type_size; + src0.nb[2] = src0.nb[1]; + src0.nb[3] = src0.nb[2]; + src0.data = b; + + struct ggml_tensor dst = {}; + dst.ne[0] = n; + dst.ne[1] = m; + dst.ne[2] = 1; + dst.ne[3] = 1; + dst.nb[0] = sizeof(float); + dst.nb[1] = n * sizeof(float); + dst.nb[2] = dst.nb[1]; + dst.nb[3] = dst.nb[2]; + dst.data = c; + dst.src[0] = &src0; + dst.src[1] = &src1; + + ggml_compute_forward_mul_mat(params, &dst); +} + +// ggml_compute_forward_conv_2d + +static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params, + const ggml_tensor * kernel, // [KW, KH, IC, OC] + const ggml_tensor * src, // [W, H, C, N] + ggml_tensor * dst, // [OW, OH, OC, N] + ggml_type kernel_type) { + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32); + GGML_ASSERT(kernel->type == kernel_type); + + const ggml_type_traits * traits = ggml_get_type_traits(kernel_type); + + const int32_t stride_x = dst->op_params[0]; + const int32_t stride_y = dst->op_params[1]; + const int32_t pad_x = dst->op_params[2]; + const int32_t pad_y = dst->op_params[3]; + const int32_t dilation_x = dst->op_params[4]; + const int32_t dilation_y = dst->op_params[5]; + + const int64_t c_in = src->ne[2]; + const int64_t c_out = kernel->ne[3]; + GGML_ASSERT(c_in == kernel->ne[2]); + + const int64_t src_w = src->ne[0]; + const int64_t src_h = src->ne[1]; + const int64_t knl_w = kernel->ne[0]; + const int64_t knl_h = kernel->ne[1]; + const int64_t dst_w = dst->ne[0]; + const int64_t dst_h = dst->ne[1]; + + const float * src_data = (float *) src->data; + void * knl_data = kernel->data; + float * dst_data = (float *) dst->data; + + const int64_t knl_n = knl_w * knl_h * c_in; + const int64_t patch_total = dst->ne[3] * dst_w * dst_h; + + const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float); + const int64_t batch_size = params->wsize / space_per_patch; + const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size; + const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch; + + GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1); + + void * tmp = params->wdata; + + for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) { + + const int64_t patch_start_batch = batch_i * patches_per_batch; + const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, + patch_total); + const int64_t patch_n = patch_end_batch - patch_start_batch; + + const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth; + const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread; + const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch); + + //im2col for a patch + for (int64_t p = patch_start; p < patch_end; ++p) { + const int64_t batch_n = p / (dst_w * dst_h); + const int64_t src_x = (p / dst_w) % dst_h; + const int64_t src_y = p % dst_w; + + const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]); + char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size; + + for (int64_t ic = 0; ic < c_in; ++ic) { + for (int64_t ky = 0; ky < knl_h; ++ky) { + for (int64_t kx = 0; kx < knl_w; ++kx) { + const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y; + const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x; + + int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx; + + float src_val; + if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) { + src_val = 0.0f; + } else { + const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]); + src_val = *src_ptr; + } + + char * element_ptr = dst_row + dst_idx * traits->type_size; + if (kernel_type == GGML_TYPE_F32) { + *(float *) element_ptr = src_val; + } else if (kernel_type == GGML_TYPE_F16) { + *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val); + } + } + } + } + } // patches handled by this thread + + ggml_barrier(params->threadpool); + + float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size); + + GGML_ASSERT(gemm_output + patch_n * c_out <= (float*)tmp + params->wsize); + + // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out] + ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output); + + ggml_barrier(params->threadpool); + + + //permute back [OC, N, OH, OW] to [N, OC, OH, OW] + const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth; + const int64_t permute_start = params->ith * permute_per_thread; + const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n); + + for (int64_t i = permute_start; i < permute_end; ++i) { + const int64_t p = patch_start_batch + i; + const int64_t batch_n = p / (dst_w * dst_h); + const int64_t dst_y = (p / dst_w) % dst_h; + const int64_t dst_x = p % dst_w; + + for (int64_t oc = 0; oc < c_out; ++oc) { + const float value = gemm_output[i * c_out + oc]; + float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]); + *dst_ptr = value; + } + } + } +} + +void ggml_compute_forward_conv_2d( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type); +} + // ggml_compute_forward_conv_transpose_2d void ggml_compute_forward_conv_transpose_2d( @@ -6613,12 +7568,13 @@ static void ggml_compute_forward_upscale_f32( GGML_TENSOR_UNARY_OP_LOCALS - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; + float sf0 = (float)ne0/src0->ne[0]; + float sf1 = (float)ne1/src0->ne[1]; + float sf2 = (float)ne2/src0->ne[2]; + float sf3 = (float)ne3/src0->ne[3]; - const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); + const int32_t mode_flags = ggml_get_op_params_i32(dst, 0); + const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF); if (mode == GGML_SCALE_MODE_NEAREST) { for (int64_t i3 = 0; i3 < ne3; i3++) { @@ -6639,8 +7595,12 @@ static void ggml_compute_forward_upscale_f32( } } } else if (mode == GGML_SCALE_MODE_BILINEAR) { - // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True - const float pixel_offset = 0.5f; + float pixel_offset = 0.5f; + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + pixel_offset = 0.0f; + sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1); + sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1); + } for (int64_t i3 = 0; i3 < ne3; i3++) { const int64_t i03 = i3 / sf3; @@ -7098,7 +8058,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; + ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float; ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float; @@ -7130,7 +8090,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( memset(VKQ32, 0, DV*sizeof(float)); } - const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL; // k indices const int ik3 = iq3 / rk3; @@ -7668,120 +8628,210 @@ void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const ggml_compute_params * params, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; // s - const ggml_tensor * src1 = dst->src[1]; // x - const ggml_tensor * src2 = dst->src[2]; // dt - const ggml_tensor * src3 = dst->src[3]; // A - const ggml_tensor * src4 = dst->src[4]; // B - const ggml_tensor * src5 = dst->src[5]; // C + const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+} + const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} + const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} + const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} + const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // dim + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; + const int64_t nt = src1->ne[2]; // number of tokens per sequence + const int64_t ns = src1->ne[3]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + // can't use ggml_nbytes because src1 is not necessarily contiguous + const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1); + + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); + // allows optimizing the modulo since n_group should be a power of 2 + GGML_ASSERT((ng & -ng) == ng); - // rows per thread - const int dr = (nr + nth - 1)/nth; + // heads per thread + const int dh = (nh + nth - 1)/nth; - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; + // head range for this thread + const int ih0 = dh*ith; + const int ih1 = MIN(ih0 + dh, nh); - #ifdef __ARM_FEATURE_SVE - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + const int32_t * ids = (const int32_t *) src6->data; - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } + for (int i3 = 0; i3 < ns; ++i3) { + const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} + float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); - svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); - svfloat32_t r1_vector = GGML_F32_VEC_ZERO; + for (int i2 = 0; i2 < nt; ++i2) { + const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} + const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} + float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} - for (int64_t k = 0; k < nc; k += svcntw()) { - svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]); - svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]); - svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]); - svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]); + if (src3->ne[0] == 1) { + // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop - svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); - t1 = exp_ps_sve(svptrue_b32(), t1); - svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dA = expf(dt_soft_plus * A[h]); - vs0 = GGML_F32_VEC_FMA(vs0, t1, t2); - r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; + float sumf = 0.0f; +#if defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + const int ggml_f32_epr = svcntw(); + const int ggml_f32_step = 1 * ggml_f32_epr; - GGML_F32_VEC_STORE(&s[i1*nc + k], vs0); - } - y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector); - } - } - } + const int np = (nc & ~(ggml_f32_step - 1)); + + GGML_F32_VEC sum = GGML_F32_VEC_ZERO; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + for (int i = 0; i < np; i += ggml_f32_step) { + // TODO: maybe unroll more? + for (int j = 0; j < 1; j++) { + GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc); + GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + + t0 = GGML_F32_VEC_MUL(t0, adA); + t1 = GGML_F32_VEC_MUL(t1, axdt); + + t0 = GGML_F32_VEC_ADD(t0, t1); + + sum = GGML_F32_VEC_FMA(sum, t0, t2); + + GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0); + } + } + + sumf = GGML_F32xt_REDUCE_ONE(sum); #else - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + const int np = (nc & ~(GGML_F32_STEP - 1)); - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC az[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + + ax[j] = GGML_F32_VEC_MUL(ax[j], adA); + ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); + + ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]); + + GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + #endif +#else + const int np = 0; +#endif + // d_state + for (int i0 = np; i0 < nc; ++i0) { + const int i = i0 + ii*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[i] * dA) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[i] = state; + } + y[ii] = sumf; + } + } + } else { + // Mamba-1 has an element-wise decay factor for the states + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; +#if defined(__ARM_FEATURE_SVE) + svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); + svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); + svfloat32_t r1_vector = GGML_F32_VEC_ZERO; + + // d_state + // TODO: what happens when (d_state % svcntw()) != 0? + for (int64_t k = 0; k < nc; k += svcntw()) { + svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]); + svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]); + svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]); + svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]); + + svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); + t1 = exp_ps_sve(svptrue_b32(), t1); + svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); + + vs0 = GGML_F32_VEC_FMA(t2, vs0, t1); + r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); + + GGML_F32_VEC_STORE(&s[ii*nc + k], vs0); + } + y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector); +#else + float sumf = 0.0f; + // NOTE: can't really use GGML_SIMD here because d_state is usually 16 + // and also because expf is used within the loop. + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int i = i0 + ii*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[i] = state; + } + y[ii] = sumf; +#endif } - y[i1] = sumf; } } + // use the output as the source when it's not the first token-wise iteration + s0 = s; } - #endif + } } void ggml_compute_forward_ssm_scan( @@ -7999,6 +9049,42 @@ void ggml_compute_forward_unary( } } +//ggml_compute_forward_glu + +void ggml_compute_forward_glu( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_glu_op op = ggml_get_glu_op(dst); + + switch (op) { + case GGML_GLU_OP_REGLU: + { + ggml_compute_forward_reglu(params, dst); + } break; + case GGML_GLU_OP_GEGLU: + { + ggml_compute_forward_geglu(params, dst); + } break; + case GGML_GLU_OP_SWIGLU: + { + ggml_compute_forward_swiglu(params, dst); + } break; + case GGML_GLU_OP_GEGLU_ERF: + { + ggml_compute_forward_geglu_erf(params, dst); + } break; + case GGML_GLU_OP_GEGLU_QUICK: + { + ggml_compute_forward_geglu_quick(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_get_rel_pos static void ggml_compute_forward_get_rel_pos_f16( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 2d8544d7d3..3a32ec20db 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -20,6 +20,9 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); +// Work buffer size for im2col operations in CONV2D +#define GGML_IM2COL_WORK_SIZE (16 * 1024 * 1024) + #ifdef __cplusplus extern "C" { #endif @@ -53,6 +56,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -64,6 +68,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -93,6 +98,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -105,6 +111,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index b68ac0dd68..b4ad68c9fd 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) -#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c) +#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a) #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 5e34d79a16..a8156011eb 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = 0; i < np; i += ggml_f32_step) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2); + sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3); + sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4); + sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5); + sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6); + sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7); + sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8); + sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8); } // leftovers // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop @@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); } // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only if (np2 < n) { @@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) { } } +void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i))); + } +#endif + for (; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]) * g[i]; + } +} + ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 84f6c0e6d2..1f5857a23e 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2); + ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx); GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3); + ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx); GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4); + ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx); GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5); + ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx); GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6); + ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx); GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7); + ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx); GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8); + ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx); GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8); } @@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); } @@ -905,6 +905,100 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con } } +inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f; + } +} + +inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + for (int i = 0; i < n; ++i) { + float v = GGML_CPU_FP16_TO_FP32(x[i]); + y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v * GGML_CPU_FP16_TO_FP32(g[i]) : 0.f); + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) { + uint16_t t; + for (int i = 0; i < n; ++i) { + if (x[i] <= -10.0f) { + y[i] = 0.0f; + } else if (x[i] >= 10.0f) { + y[i] = x[i] * g[i]; + } else { + ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i]; + } + } +} +#else +inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]) * g[i]; + } +} +#endif + +inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + float v = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v); + } +} + +void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g); + +inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + for (int i = 0; i < n; ++i) { + float v = GGML_CPU_FP16_TO_FP32(x[i]); + float w = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w); + } +} + +inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + float xi = x[i]; + y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i]; + } +} + +inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + for (int i = 0; i < n; ++i) { + float xi = GGML_CPU_FP16_TO_FP32(x[i]); + float gi = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi); + } +} + +#ifdef GGML_GELU_QUICK_FP16 +inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i]; + } +} +#else +inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_quick_f32(x[i]) * g[i]; + } +} +#endif + +inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + float v = GGML_CPU_FP16_TO_FP32(g[i]); + y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v); + } +} + inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE ggml_float sum = 0.0; diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ea20355023..1a2708ec9d 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -175,6 +175,23 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \ + const int id = ggml_cuda_get_device(); \ + if (!shared_memory_limit_raised[id]) { \ + CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \ + shared_memory_limit_raised[id] = true; \ + } \ + } while (0) +#else +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + GGML_UNUSED(nbytes); \ + } while (0) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index c6dec4276b..eeaa14bf57 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -728,3 +728,25 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { return nullptr; } } + +to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cuda; + case GGML_TYPE_F16: + return convert_unary_cuda; + default: + return nullptr; + } +} + +to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F16: + return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_unary_cuda; + default: + return nullptr; + } +} diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index b65b98e08e..f04214be17 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -22,5 +22,10 @@ using to_t_nc_cuda_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream); +typedef to_t_nc_cuda_t to_fp32_nc_cuda_t; typedef to_t_nc_cuda_t to_fp16_nc_cuda_t; +typedef to_t_nc_cuda_t to_bf16_nc_cuda_t; + +to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type); to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type); +to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type); diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index 0ce4afbb22..0c8b081972 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); - shared_memory_limit_raised[id] = true; - } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32), smpbo); cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); } else { cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); @@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten const size_t smpbo = ggml_cuda_info().devices[id].smpbo; if (nbytes_shared <= smpbo) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); - shared_memory_limit_raised[id] = true; - } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32), smpbo); cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); } else { cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index cfab2b5eba..075f14a49e 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -851,7 +853,8 @@ void launch_fattn( scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, Q->nb[1], Q->nb[2], Q->nb[3], nb11, nb12, nb13, nb21, nb22, nb23, diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e230f6d494..709589854f 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : + (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; + const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr : + (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1); float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); @@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); - GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 9283560d5c..0c967f178e 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -6,7 +6,7 @@ template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 2) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f16( const char * __restrict__ Q, @@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -64,7 +66,7 @@ static __global__ void flash_attn_tile_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); @@ -288,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 32673adb57..908c76dbdd 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -6,7 +6,7 @@ template // D == head size #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 2) #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f32( const char * __restrict__ Q, @@ -30,7 +30,9 @@ static __global__ void flash_attn_tile_ext_f32( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -58,8 +60,8 @@ static __global__ void flash_attn_tile_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -76,7 +78,7 @@ static __global__ void flash_attn_tile_ext_f32( const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0); const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio)); const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const int stride_KV2 = nb11 / sizeof(half2); @@ -297,14 +299,14 @@ static __global__ void flash_attn_tile_ext_f32( GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 35e649cb3c..e78fb18191 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16( K += nb12*(blockIdx.z / gqa_ratio); V += nb22*(blockIdx.z / gqa_ratio); - const half * maskh = (const half *) mask + ne11*ic0; + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 9539679177..b2f1724c95 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f32( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -51,8 +53,8 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); @@ -79,7 +81,8 @@ static __global__ void flash_attn_vec_ext_f32( Q += nb02* blockIdx.z + nb01*ic0; K += nb12*(blockIdx.z / gqa_ratio); V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1); @@ -334,13 +337,15 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); - GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); - GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); - GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index f3b794c364..c95ca7b1f2 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -46,7 +46,9 @@ static __global__ void flash_attn_ext_f16( const int ne12, const int ne13, const int ne31, + const int ne32, const int nb31, + const int nb32, const int nb01, const int nb02, const int nb03, @@ -94,11 +96,11 @@ static __global__ void flash_attn_ext_f16( constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; - const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0); + const half2 * mask2 = (const half2 *) maskh; const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); @@ -440,7 +442,7 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 963e4d03dd..f77b2629a1 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_float((const float *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_I32: + get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_BF16: get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); @@ -210,6 +214,10 @@ void get_rows_cuda( ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_I32: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_F16: ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b30c13c62f..da1e8f8f4e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1749,7 +1749,7 @@ static void ggml_cuda_op_mul_mat( } static __global__ void k_compute_batched_ptrs( - const half * src0_as_f16, const half * src1_as_f16, char * dst, + const void * src0_as_f16, const void * src1_as_f16, char * dst, const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23, @@ -1772,83 +1772,131 @@ static __global__ void k_compute_batched_ptrs( ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; } -static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +// Type traits for mapping ggml types to CUDA/cuBLAS types +template +struct batched_mul_mat_traits; + +template<> +struct batched_mul_mat_traits { + using cuda_type = float; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static inline const cudaDataType_t data_type = CUDA_R_32F; + static inline const ggml_type ggml_type_val = GGML_TYPE_F32; + static inline const float alpha = 1.0f; + static inline const float beta = 0.0f; + static inline const void* get_alpha() { static const float val = alpha; return &val; } + static inline const void* get_beta() { static const float val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); } +}; + +template<> +struct batched_mul_mat_traits { + using cuda_type = nv_bfloat16; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static inline const cudaDataType_t data_type = CUDA_R_16BF; + static inline const ggml_type ggml_type_val = GGML_TYPE_BF16; + static inline const float alpha = 1.0f; + static inline const float beta = 0.0f; + static inline const void* get_alpha() { static const float val = alpha; return &val; } + static inline const void* get_beta() { static const float val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); } +}; + +template<> +struct batched_mul_mat_traits { + using cuda_type = half; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + static inline const cudaDataType_t data_type = CUDA_R_16F; + static inline const ggml_type ggml_type_val = GGML_TYPE_F16; + static inline const half alpha = 1.0; + static inline const half beta = 0.0; + static inline const void* get_alpha() { static const half val = alpha; return &val; } + static inline const void* get_beta() { static const half val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); } +}; + +template +static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + using traits = batched_mul_mat_traits; + using cuda_t = typename traits::cuda_type; + GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft)); - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == src0_type); + GGML_ASSERT(ggml_is_contiguous(dst)); // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst. // As long as dst is contiguous this does not matter though. - GGML_ASSERT(ggml_is_contiguous(dst)); GGML_TENSOR_BINARY_OP_LOCALS const int64_t ne_dst = ggml_nelements(dst); - cudaStream_t main_stream = ctx.stream(); - CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream)); - const half * src0_f16 = (const half *) src0->data; float * dst_ddf = (float *) dst->data; - - const half * src1_f16 = (const half *) src1->data; const size_t ts_src1 = ggml_type_size(src1->type); GGML_ASSERT(nb10 == ts_src1); int64_t s11 = nb11 / ts_src1; int64_t s12 = nb12 / ts_src1; int64_t s13 = nb13 / ts_src1; - ggml_cuda_pool_alloc src1_f16_alloc(ctx.pool()); - // convert src1 to fp16 - if (src1->type != GGML_TYPE_F16) { - const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type); + const cuda_t * src0_ptr = nullptr; + const cuda_t * src1_ptr = nullptr; + + ggml_cuda_pool_alloc src0_alloc(ctx.pool()); + ggml_cuda_pool_alloc src1_alloc(ctx.pool()); + + // Handle src0 + src0_ptr = (const cuda_t *) src0->data; + + // Handle src1 - convert if necessary + if (src1->type == src0_type) { + src1_ptr = (const cuda_t *) src1->data; + } else { + // Convert src1 to target type using traits conversion functions const int64_t ne_src1 = ggml_nelements(src1); - src1_f16_alloc.alloc(ne_src1); - GGML_ASSERT(to_fp16_cuda != nullptr); + src1_alloc.alloc(ne_src1); - to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream); - - src1_f16 = src1_f16_alloc.get(); + const auto convert_func = traits::get_nc_converter(src1->type); + GGML_ASSERT(convert_func != nullptr); + convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream); + src1_ptr = src1_alloc.get(); s11 = ne10; s12 = ne11*s11; s13 = ne12*s12; } - ggml_cuda_pool_alloc dst_f16(ctx.pool()); + // Setup destination buffer + ggml_cuda_pool_alloc dst_temp(ctx.pool()); char * dst_t; - - cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; - cudaDataType_t cu_data_type = CUDA_R_16F; - - // dst strides size_t nbd2 = dst->nb[2]; size_t nbd3 = dst->nb[3]; - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; - + cublasComputeType_t cu_compute_type = traits::compute_type; + cudaDataType_t cu_data_type = traits::data_type; + cudaDataType_t cu_data_type_a = traits::data_type; + cudaDataType_t cu_data_type_b = traits::data_type; + const void * alpha = traits::get_alpha(); + const void * beta = traits::get_beta(); const float alpha_f32 = 1.0f; - const float beta_f32 = 0.0f; - - const void * alpha = &alpha_f16; - const void * beta = &beta_f16; + const float beta_f32 = 0.0f; if (dst->op_params[0] == GGML_PREC_DEFAULT) { - dst_t = (char *) dst_f16.alloc(ne_dst); - - nbd2 /= sizeof(float) / sizeof(half); - nbd3 /= sizeof(float) / sizeof(half); + if constexpr (src0_type == GGML_TYPE_F32) { + dst_t = (char *) dst_ddf; // Direct F32 output + } else { + dst_t = (char *) dst_temp.alloc(ne_dst); + nbd2 /= sizeof(float) / sizeof(cuda_t); + nbd3 /= sizeof(float) / sizeof(cuda_t); + } } else { dst_t = (char *) dst_ddf; - cu_compute_type = CUBLAS_COMPUTE_32F; - cu_data_type = CUDA_R_32F; - + cu_data_type = CUDA_R_32F; alpha = &alpha_f32; - beta = &beta_f32; + beta = &beta_f32; } int id = ggml_cuda_get_device(); @@ -1856,7 +1904,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { cu_compute_type = CUBLAS_COMPUTE_32F; alpha = &alpha_f32; - beta = &beta_f32; + beta = &beta_f32; } GGML_ASSERT(ne12 % ne02 == 0); @@ -1866,35 +1914,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; -#if 0 - // use cublasGemmEx - { - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - int i03 = i13 / r3; - int i02 = i12 / r2; - - CUBLAS_CHECK( - cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half), - src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11, - beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - } - } - } -#else if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx CUBLAS_CHECK( cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA - src1_f16, CUDA_R_16F, s11, s12, // strideB - beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC + alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA + src1_ptr, cu_data_type_b, s11, s12, // strideB + beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC ne12*ne13, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); @@ -1905,34 +1933,55 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co ggml_cuda_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); + size_t src1_stride_size = sizeof(cuda_t); + dim3 block_dims(ne13, ne12); k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( - src0_f16, src1_f16, dst_t, + src0_ptr, src1_ptr, dst_t, ptrs_src.get(), ptrs_dst.get(), ne12, ne13, ne23, nb02, nb03, - src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half), - src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half), + (src1->type == src0_type) ? nb12 : s12*src1_stride_size, + (src1->type == src0_type) ? nb13 : s13*src1_stride_size, nbd2, nbd3, r2, r3); + CUDA_CHECK(cudaGetLastError()); CUBLAS_CHECK( cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00, - (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11, - beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0, + alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00, + (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11, + beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0, ne23, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } -#endif - if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) { - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream); + // Convert output back to F32 if needed + if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val); + to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream); + } +} + +static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32); + + switch (src0->type) { + case GGML_TYPE_F32: + ggml_cuda_mul_mat_batched_cublas_impl(ctx, src0, src1, dst); + break; + case GGML_TYPE_BF16: + ggml_cuda_mul_mat_batched_cublas_impl(ctx, src0, src1, dst); + break; + case GGML_TYPE_F16: + ggml_cuda_mul_mat_batched_cublas_impl(ctx, src0, src1, dst); + break; + default: + GGML_ABORT("Unsupported type"); } } @@ -1984,6 +2033,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); + //TODO update for generic tensor parallelism + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16); + bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc); + bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32; + if (!split && use_mul_mat_vec) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) @@ -1992,8 +2047,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst); } else if (!split && use_mul_mat_q) { ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && - !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32) + && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // general KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_mul_mat_vec) { @@ -2248,6 +2303,27 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_REGLU: + ggml_cuda_op_reglu(ctx, dst); + break; + case GGML_GLU_OP_GEGLU: + ggml_cuda_op_geglu(ctx, dst); + break; + case GGML_GLU_OP_SWIGLU: + ggml_cuda_op_swiglu(ctx, dst); + break; + case GGML_GLU_OP_GEGLU_ERF: + ggml_cuda_op_geglu_erf(ctx, dst); + break; + case GGML_GLU_OP_GEGLU_QUICK: + ggml_cuda_op_geglu_quick(ctx, dst); + break; + default: + return false; + } + break; case GGML_OP_NORM: ggml_cuda_op_norm(ctx, dst); break; @@ -3041,6 +3117,18 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous_1(op->src[0]); + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { @@ -3112,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (op->src[0]->type) { case GGML_TYPE_F16: case GGML_TYPE_F32: + case GGML_TYPE_BF16: + case GGML_TYPE_I32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3241,12 +3331,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_LOG: - case GGML_OP_SSM_SCAN: - case GGML_OP_SSM_CONV: return true; + case GGML_OP_SSM_SCAN: { + if (op->src[3]->ne[0] == 1) { + // Mamba2 + // (kernel only supports d_state == 128 && d_head % 16 == 0) + return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0; + } else { + // Mamba + // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1) + return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1; + } + } + case GGML_OP_SSM_CONV: { + // assumes d_inner % threads == 0 + return op->src[0]->ne[1] % 128 == 0; + } case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; case GGML_OP_DIAG_MASK_INF: + return true; case GGML_OP_SOFT_MAX: return true; case GGML_OP_SOFT_MAX_BACK: { @@ -3271,7 +3375,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_PAD: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: @@ -3295,6 +3398,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 192) { return false; } + // TODO: support broadcast + // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but + // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505 if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 80baf459c1..9696a32046 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc); -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); - shared_memory_limit_raised[id] = true; - } -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 18f691b2d3..d058504cd6 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -50,21 +50,19 @@ static __global__ void rope_norm( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0; const int ix = channel_x*s2 + row_x*s1 + i0; + if (i0 >= n_dims) { + dst[idst + 0] = x[ix + 0]; + dst[idst + 1] = x[ix + 1]; + + return; + } + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -94,21 +92,19 @@ static __global__ void rope_neox( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0/2; const int ix = channel_x*s2 + row_x*s1 + i0/2; + if (i0 >= n_dims) { + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; + + return; + } + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -138,21 +134,19 @@ static __global__ void rope_multi( const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = row_dst*ne0 + i0/2; const int ix = channel_x*s2 + row_x*s1 + i0/2; + if (i0 >= n_dims) { + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; + + return; + } + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index aac6e09998..14543e978c 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -2,6 +2,7 @@ #include "ggml.h" #include "softmax.cuh" #include +#include template static __device__ __forceinline__ float t2f32(T val) { @@ -13,6 +14,29 @@ __device__ float __forceinline__ t2f32(half val) { return __half2float(val); } +struct soft_max_params { + + int64_t nheads; + uint32_t n_head_log2; + int64_t ncols; + int64_t nrows_x; + int64_t nrows_y; + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + int64_t nb11; + int64_t nb12; + int64_t nb13; + + int64_t ne12; + int64_t ne13; + float scale; + float max_bias; + float m0; + float m1; +}; + // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. #ifdef __clang__ @@ -21,16 +45,24 @@ __device__ float __forceinline__ t2f32(half val) { #endif // __clang__ template static __global__ void soft_max_f32( - const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, - const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { - const int ncols = ncols_template == 0 ? ncols_par : ncols_template; + const float * x, const T * mask, float * dst, const soft_max_params p) { + const int ncols = ncols_template == 0 ? p.ncols : ncols_template; const int tid = threadIdx.x; - const int rowx = blockIdx.x; - const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension + + const int64_t i03 = blockIdx.z; + const int64_t i02 = blockIdx.y; + const int64_t i01 = blockIdx.x; + + //TODO: noncontigous inputs/outputs + const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + + const int64_t i11 = i01; + const int64_t i12 = i02 % p.ne12; + const int64_t i13 = i03 % p.ne13; x += int64_t(rowx)*ncols; - mask += int64_t(rowy)*ncols * (mask != nullptr); + mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); dst += int64_t(rowx)*ncols; const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; @@ -38,7 +70,7 @@ static __global__ void soft_max_f32( const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1); + const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication @@ -55,7 +87,7 @@ static __global__ void soft_max_f32( break; } - const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f); + const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -150,64 +182,58 @@ static __global__ void soft_max_back_f32( } } +template +static void launch_soft_max_kernels(const float * x, const T * mask, float * dst, + const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared) +{ + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + auto launch_kernel = [=](auto I) -> bool { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 1024 : ncols); + + if (p.ncols == ncols) { + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>> + (x, mask, dst, p); + return true; + } + return false; + }; + + // unary fold over launch_kernel + if ((launch_kernel(std::integral_constant{}) || ...)) { + return; + } + + //default case + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32), smpbo); + soft_max_f32<<>>(x, mask, dst, p); +} + + template -static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) { int nth = WARP_SIZE; + const int64_t ncols_x = params.ncols; + while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); - const dim3 block_nums(nrows_x, 1, 1); + const dim3 block_nums(params.ne01, params.ne02, params.ne03); const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; - // FIXME: this limit could be raised by ~2-4x on Ampere or newer - if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { - switch (ncols_x) { - case 32: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 64: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 128: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 256: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 512: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 1024: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 2048: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - case 4096: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - default: - soft_max_f32<<>> - (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); - break; - } + + if (nbytes_shared <= smpbo) { + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared); } else { const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, params); } } @@ -235,10 +261,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_y = src0->ne[1]; + const int64_t ne00 = src0->ne[0]; + float scale = 1.0f; float max_bias = 0.0f; @@ -247,10 +274,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + + soft_max_params params = {}; + params.nheads = src0->ne[2]; + params.n_head_log2 = n_head_log2; + params.ncols = ne00; + params.nrows_x = nrows_x; + params.nrows_y = nrows_y; + params.ne00 = src0->ne[0]; + params.ne01 = src0->ne[1]; + params.ne02 = src0->ne[2]; + params.ne03 = src0->ne[3]; + params.nb11 = nb11; + params.nb12 = nb12; + params.nb13 = nb13; + params.ne12 = ne12; + params.ne13 = ne13; + params.scale = scale; + params.max_bias = max_bias; + params.m0 = m0; + params.m1 = m1; + if (use_f16) { - soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream); } else { - soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream); } } diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 2d34b83605..dc3b1a9a8c 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -4,16 +4,15 @@ template __global__ void __launch_bounds__(splitD, 2) ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, - const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2, - const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, - const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, - float * __restrict__ dst, const int64_t L) { - GGML_UNUSED(src1_nb0); - GGML_UNUSED(src2_nb0); + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t d_inner, const int64_t L) { constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - const int bidx = blockIdx.x; // split along B - const int bidy = blockIdx.y; // split along D + const int bidx = blockIdx.x; // split along B (sequences) + const int bidy = blockIdx.y; // split along D (d_inner) const int tid = threadIdx.x; const int wid = tid / 32; const int wtid = tid % 32; @@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2) float * smem_A = smem; float * smem_s0 = smem_A + splitD * stride_sA; - const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1); - const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); + const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2); + const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float)); const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float)); const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1); - const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2)); - const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2)); - float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float)); - float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1); + const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3)); + const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3)); + float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float)); + float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2); - const int stride_s0 = src0_nb1 / sizeof(float); - const int stride_x = src1_nb1 / sizeof(float); + const int stride_s0 = src0_nb2 / sizeof(float); + const int stride_x = src1_nb2 / sizeof(float); const int stride_dt = src2_nb1 / sizeof(float); const int stride_A = src3_nb1 / sizeof(float); - const int stride_B = src4_nb1 / sizeof(float); - const int stride_C = src5_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); const int stride_s = stride_s0; - const int stride_y = stride_x; + const int stride_y = d_inner; // can N not be 16? for example 32? if (N == 16) { @@ -84,24 +83,156 @@ __global__ void __launch_bounds__(splitD, 2) } } +// assumes as many threads as d_state +template +__global__ void __launch_bounds__(d_state, 1) + ssm_scan_f32_group( + const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, + const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { + + const int head_idx = (blockIdx.x * splitH) / d_head; + const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); + const int seq_idx = blockIdx.y; + + const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float); + + const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); + const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH; + float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + + // strides across n_seq_tokens + const int stride_x = src1_nb2 / sizeof(float); + const int stride_dt = src2_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); + const int stride_y = n_head * d_head; + + float state[splitH]; + // for the parallel accumulation + __shared__ float stateC[splitH * d_state]; + +#pragma unroll + for (int j = 0; j < splitH; j++) { + state[j] = s0_block[j * d_state + threadIdx.x]; + } + + for (int64_t i = 0; i < n_tok; i++) { + // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements + // TODO: only calculate B and C once per head group + // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here. + float dt_soft_plus = dt_block[i * stride_dt]; + if (dt_soft_plus <= 20.0f) { + dt_soft_plus = log1pf(expf(dt_soft_plus)); + } + const float dA = expf(dt_soft_plus * A_block[0]); + const float B = B_block[i * stride_B + threadIdx.x]; + const float C = C_block[i * stride_C + threadIdx.x]; + + // across d_head +#pragma unroll + for (int j = 0; j < splitH; j++) { + const float x_dt = x_block[i * stride_x + j] * dt_soft_plus; + + state[j] = (state[j] * dA) + (B * x_dt); + + stateC[j * d_state + threadIdx.x] = state[j] * C; + } + + __syncthreads(); + + // parallel accumulation for stateC + // TODO: simplify + { + static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2"); + static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2"); + + // reduce until w matches the warp size + // TODO: does this work even when the physical warp size is 64? +#pragma unroll + for (int w = d_state; w > WARP_SIZE; w >>= 1) { + // (assuming there are d_state threads) +#pragma unroll + for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) { + // TODO: check for bank conflicts + const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1)); + stateC[k] += stateC[k + (w >> 1)]; + + } + __syncthreads(); + } + + static_assert(splitH >= d_state / WARP_SIZE); + +#pragma unroll + for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) { + float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)]; + y = warp_reduce_sum(y); + + // store the above accumulations + if (threadIdx.x % WARP_SIZE == 0) { + const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE); + y_block[i * stride_y + k] = y; + } + } + } + } + + // write back the state +#pragma unroll + for (int j = 0; j < splitH; j++) { + s_block[j * d_state + threadIdx.x] = state[j]; + } +} + static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, - const float * src4, const float * src5, const int src0_nb1, const int src0_nb2, - const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3, - const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1, - const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2, - float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B, + const float * src4, const float * src5, const int32_t * src6, float * dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, + const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, + const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, + const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { const int threads = 128; - // todo: consider D cannot be divided,does this situation exist? - GGML_ASSERT(D % threads == 0); - const dim3 blocks(B, (D + threads - 1) / threads, 1); - const int smem_size = (threads * (N + 1) * 2) * sizeof(float); - if (N == 16) { - ssm_scan_f32<128, 16><<>>( - src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0, - src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L); + // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! + if (src3_nb1 == sizeof(float)) { + // Mamba-2 + if (d_state == 128) { + GGML_ASSERT(d_state % threads == 0); + // NOTE: can be any power of two between 4 and 64 + const int splitH = 16; + GGML_ASSERT(head_dim % splitH == 0); + const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); + ssm_scan_f32_group<16, 128><<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); + } else { + GGML_ABORT("doesn't support d_state!=128."); + } } else { - GGML_ABORT("doesn't support N!=16."); + // Mamba-1 + GGML_ASSERT(n_head % threads == 0); + GGML_ASSERT(head_dim == 1); + GGML_ASSERT(n_group == 1); + const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); + const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); + if (d_state == 16) { + ssm_scan_f32<128, 16><<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + } else { + GGML_ABORT("doesn't support d_state!=16."); + } } } @@ -112,30 +243,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const struct ggml_tensor * src3 = dst->src[3]; // A const struct ggml_tensor * src4 = dst->src[4]; // B const struct ggml_tensor * src5 = dst->src[5]; // C - - // const int64_t d_state = src0->ne[0]; - // const int64_t d_inner = src0->ne[1]; - // const int64_t l = src1->ne[1]; - // const int64_t b = src0->ne[2]; + const struct ggml_tensor * src6 = dst->src[6]; // ids const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nr = src0->ne[1]; // head_dim or 1 + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; // n_group + const int64_t n_t = src1->ne[2]; // number of tokens per sequence + const int64_t n_s = src1->ne[3]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + const int64_t s_off = ggml_nelements(src1) * sizeof(float); + + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; @@ -143,13 +269,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const float * src3_d = (const float *) src3->data; const float * src4_d = (const float *) src4->data; const float * src5_d = (const float *) src5->data; + const int32_t * src6_d = (const int32_t *) src6->data; float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src6->type == GGML_TYPE_I32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0], - src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1], - src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream); + ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d, + src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], + src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3], + s_off, nc, nr, nh, ng, n_t, n_s, stream); } diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 2c0375fbe3..f9c7b83c40 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -196,6 +196,103 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +/* gated ops */ + +template +static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { + const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + // perform base op and multiply with gate (either offset in same tensor or a separate one) + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + + dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); +} + +template +static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { + const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; + unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); +} + +template +void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + void * dst_d = dst->data; + const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src0->nb[0] == ggml_element_size(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } + + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + + if (src0->type == GGML_TYPE_F16) { + half * src0_p = (half *) src0_d; + half * src1_p = (half *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream); + } else { + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream); + } +} + +void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + +void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated(ctx, dst); +} + /* silu_back */ static __device__ __forceinline__ float op_silu_back(float grad, float x) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 6686fc17e9..289d690e5c 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -15,6 +15,7 @@ #define CUDA_SQRT_BLOCK_SIZE 256 #define CUDA_SIN_BLOCK_SIZE 256 #define CUDA_COS_BLOCK_SIZE 256 +#define CUDA_GLU_BLOCK_SIZE 256 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -57,3 +58,13 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index 524e979574..ef48aa5f97 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -22,17 +22,88 @@ static __global__ void upscale_f32(const float * x, float * dst, dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) ); } +static __global__ void upscale_f32_bilinear(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + const int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + int y0_src = (int)floorf(y_src_f); + int y1_src = y0_src + 1; + + y0_src = max(0, min(y0_src, ne01_src - 1)); + y1_src = max(0, min(y1_src, ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = max(0.0f, min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + int x0_src = (int)floorf(x_src_f); + int x1_src = x0_src + 1; + + x0_src = max(0, min(x0_src, ne00_src - 1)); + x1_src = max(0, min(x1_src, ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = max(0.0f, min(dx, 1.0f)); + + const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst[index] = result; +} + static void upscale_f32_cuda(const float * x, float * dst, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int ne13, const float sf0, const float sf1, const float sf2, const float sf3, cudaStream_t stream) { - int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + const int64_t dst_size = ne10 * ne11 * ne12 * ne13; + const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; upscale_f32<<>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); } +static void upscale_f32_bilinear_cuda(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset, cudaStream_t stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + + upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); +} + void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -42,10 +113,25 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const float sf0 = (float)dst->ne[0]/src0->ne[0]; - const float sf1 = (float)dst->ne[1]/src0->ne[1]; - const float sf2 = (float)dst->ne[2]/src0->ne[2]; + const int mode_flags = dst->op_params[0]; + const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF); + + float sf0 = (float)dst->ne[0]/src0->ne[0]; + float sf1 = (float)dst->ne[1]/src0->ne[1]; + float sf2 = (float)dst->ne[2]/src0->ne[2]; const float sf3 = (float)dst->ne[3]/src0->ne[3]; - upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); + if (mode == GGML_SCALE_MODE_NEAREST) { + upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + float pixel_offset = 0.5f; + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); + sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + pixel_offset = 0.0f; + } + upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, stream); + } } diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 57761644f4..4972558c98 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -301,6 +301,7 @@ struct ggml_cgraph { struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes struct ggml_tensor ** grad_accs; // accumulators for node gradients struct ggml_tensor ** leafs; // tensors with constant data + int32_t * use_counts;// number of uses of each tensor, indexed by hash table slot struct ggml_hash_set visited_hash_set; @@ -467,13 +468,76 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) { #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x) #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x) +// return true if the node's results are only used by N other nodes +// and can be fused into their calculations. +static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) { + const struct ggml_tensor * node = cgraph->nodes[node_idx]; + + // check the use count against how many we're replacing + size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); + if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) { + return false; + } + + // if node is a view, some other node might be using the intermediate result + // via the view source. + if (node->view_src) { + return false; + } + + // If the user requested output for the node, can't fuse + if (node->flags & GGML_TENSOR_FLAG_OUTPUT) { + return false; + } + + return true; +} + +// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[] +// and are fusable. Nodes are considered fusable according to this function if: +// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses). +// - all nodes except the last are a src of the following node. +// - all nodes are the same shape. +// TODO: Consider allowing GGML_OP_NONE nodes in between +static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) { + if (node_idx + num_ops > cgraph->n_nodes) { + return false; + } + + for (int i = 0; i < num_ops; ++i) { + struct ggml_tensor * node = cgraph->nodes[node_idx + i]; + if (node->op != ops[i]) { + return false; + } + if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) { + return false; + } + if (i > 0) { + struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1]; + if (node->src[0] != prev && node->src[1] != prev) { + return false; + } + if (!ggml_are_same_shape(node, prev)) { + return false; + } + } + } + return true; +} + #ifdef __cplusplus } #endif #ifdef __cplusplus +#include #include +// nicer C++ syntax for ggml_can_fuse +inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size()); +} + // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); diff --git a/ggml/src/ggml-kompute/CMakeLists.txt b/ggml/src/ggml-kompute/CMakeLists.txt deleted file mode 100644 index c9109d5e8e..0000000000 --- a/ggml/src/ggml-kompute/CMakeLists.txt +++ /dev/null @@ -1,166 +0,0 @@ - -find_package(Vulkan COMPONENTS glslc REQUIRED) -find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc) - -if (NOT glslc_executable) - message(FATAL_ERROR "glslc not found") -endif() - -ggml_add_backend_library(ggml-kompute - ggml-kompute.cpp - ../../include/ggml-kompute.h - ) - -target_link_libraries(ggml-kompute PRIVATE ggml-base kompute) -target_include_directories(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) - -add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1) - -function(compile_shader) - set(options) - set(oneValueArgs) - set(multiValueArgs SOURCES) - cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - foreach(source ${compile_shader_SOURCES}) - get_filename_component(filename ${source} NAME) - set(spv_file ${filename}.spv) - add_custom_command( - OUTPUT ${spv_file} - DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source} - ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp - ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp - ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp - ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp - COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source} - COMMENT "Compiling ${source} to ${spv_file}" - ) - - get_filename_component(RAW_FILE_NAME ${spv_file} NAME) - set(FILE_NAME "shader${RAW_FILE_NAME}") - string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME}) - string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE) - string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}") - set(OUTPUT_HEADER_FILE "${HEADER_FILE}") - message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}") - if(CMAKE_GENERATOR MATCHES "Visual Studio") - add_custom_command( - OUTPUT ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_BINARY_DIR}/bin/$/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - DEPENDS ${spv_file} xxd - COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$/xxd" - ) - else() - add_custom_command( - OUTPUT ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - DEPENDS ${spv_file} xxd - COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd" - ) - endif() - endforeach() -endfunction() - -if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt") - message(STATUS "Kompute found") - set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level") - add_subdirectory(kompute) - - # Compile our shaders - compile_shader(SOURCES - kompute-shaders/op_scale.comp - kompute-shaders/op_scale_8.comp - kompute-shaders/op_add.comp - kompute-shaders/op_addrow.comp - kompute-shaders/op_mul.comp - kompute-shaders/op_silu.comp - kompute-shaders/op_relu.comp - kompute-shaders/op_gelu.comp - kompute-shaders/op_softmax.comp - kompute-shaders/op_norm.comp - kompute-shaders/op_rmsnorm.comp - kompute-shaders/op_diagmask.comp - kompute-shaders/op_mul_mat_mat_f32.comp - kompute-shaders/op_mul_mat_f16.comp - kompute-shaders/op_mul_mat_q8_0.comp - kompute-shaders/op_mul_mat_q4_0.comp - kompute-shaders/op_mul_mat_q4_1.comp - kompute-shaders/op_mul_mat_q4_k.comp - kompute-shaders/op_mul_mat_q6_k.comp - kompute-shaders/op_getrows_f32.comp - kompute-shaders/op_getrows_f16.comp - kompute-shaders/op_getrows_q4_0.comp - kompute-shaders/op_getrows_q4_1.comp - kompute-shaders/op_getrows_q6_k.comp - kompute-shaders/op_rope_norm_f16.comp - kompute-shaders/op_rope_norm_f32.comp - kompute-shaders/op_rope_neox_f16.comp - kompute-shaders/op_rope_neox_f32.comp - kompute-shaders/op_cpy_f16_f16.comp - kompute-shaders/op_cpy_f16_f32.comp - kompute-shaders/op_cpy_f32_f16.comp - kompute-shaders/op_cpy_f32_f32.comp - ) - - # Create a custom target for our generated shaders - add_custom_target(generated_shaders DEPENDS - shaderop_scale.h - shaderop_scale_8.h - shaderop_add.h - shaderop_addrow.h - shaderop_mul.h - shaderop_silu.h - shaderop_relu.h - shaderop_gelu.h - shaderop_softmax.h - shaderop_norm.h - shaderop_rmsnorm.h - shaderop_diagmask.h - shaderop_mul_mat_mat_f32.h - shaderop_mul_mat_f16.h - shaderop_mul_mat_q8_0.h - shaderop_mul_mat_q4_0.h - shaderop_mul_mat_q4_1.h - shaderop_mul_mat_q4_k.h - shaderop_mul_mat_q6_k.h - shaderop_getrows_f32.h - shaderop_getrows_f16.h - shaderop_getrows_q4_0.h - shaderop_getrows_q4_1.h - shaderop_getrows_q6_k.h - shaderop_rope_norm_f16.h - shaderop_rope_norm_f32.h - shaderop_rope_neox_f16.h - shaderop_rope_neox_f32.h - shaderop_cpy_f16_f16.h - shaderop_cpy_f16_f32.h - shaderop_cpy_f32_f16.h - shaderop_cpy_f32_f32.h - ) - - # Create a custom command that depends on the generated_shaders - add_custom_command( - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp - COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp - DEPENDS generated_shaders - COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp" - ) - - # Add the stamp to the main sources to ensure dependency tracking - target_sources(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp) -else() - message(WARNING "Kompute not found") -endif() diff --git a/ggml/src/ggml-kompute/ggml-kompute.cpp b/ggml/src/ggml-kompute/ggml-kompute.cpp deleted file mode 100644 index 5057922718..0000000000 --- a/ggml/src/ggml-kompute/ggml-kompute.cpp +++ /dev/null @@ -1,2251 +0,0 @@ -#include "ggml-impl.h" -#include "ggml-backend.h" -#include "ggml-backend-impl.h" -#include "ggml-kompute.h" - -// These are generated at build time by cmake custom command -#include "shaderop_scale.h" -#include "shaderop_scale_8.h" -#include "shaderop_add.h" -#include "shaderop_addrow.h" -#include "shaderop_mul.h" -#include "shaderop_silu.h" -#include "shaderop_relu.h" -#include "shaderop_gelu.h" -#include "shaderop_softmax.h" -#include "shaderop_norm.h" -#include "shaderop_rmsnorm.h" -#include "shaderop_diagmask.h" -#include "shaderop_mul_mat_f16.h" -#include "shaderop_mul_mat_q8_0.h" -#include "shaderop_mul_mat_q4_0.h" -#include "shaderop_mul_mat_q4_1.h" -#include "shaderop_mul_mat_q4_k.h" -#include "shaderop_mul_mat_q6_k.h" -#include "shaderop_mul_mat_mat_f32.h" -#include "shaderop_getrows_f32.h" -#include "shaderop_getrows_f16.h" -#include "shaderop_getrows_q4_0.h" -#include "shaderop_getrows_q4_1.h" -#include "shaderop_getrows_q6_k.h" -#include "shaderop_rope_norm_f16.h" -#include "shaderop_rope_norm_f32.h" -#include "shaderop_rope_neox_f16.h" -#include "shaderop_rope_neox_f32.h" -#include "shaderop_cpy_f16_f16.h" -#include "shaderop_cpy_f16_f32.h" -#include "shaderop_cpy_f32_f16.h" -#include "shaderop_cpy_f32_f32.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#ifdef __linux__ -#include // for setenv -#endif - -#define QK4_0 32 -#define QR4_0 2 -#define QK4_1 32 -#define QK_NL 16 - -typedef ggml_fp16_t half; - -static std::string ggml_kompute_format_name(int device) { - return "Kompute" + std::to_string(device); -} - -struct ggml_kompute_context { - int device; - std::string name; - std::shared_ptr pool; - - ggml_kompute_context(int device) - : device(device), name(ggml_kompute_format_name(device)) {} -}; - -// FIXME: It would be good to consolidate the kompute manager and the kompute context into one object -// and consolidate the init functions and simplify object lifetime management. As it currently stands, -// we *have* to have the kompute manager no matter what for device discovery, but the kompute context -// is only created when a device is set and vulkan is explicitly turned on. -static ggml_kompute_context *s_kompute_context = nullptr; - -class kompute_manager { - kp::Manager *s_mgr = nullptr; - -public: - kp::Manager *operator()() { - if (s_mgr && !s_mgr->hasInstance()) { - destroy(); - } - if (!s_mgr) { - s_mgr = new kp::Manager; - } - return s_mgr; - } - - void destroy() { - delete s_mgr; - s_mgr = nullptr; - } -}; - -static kompute_manager komputeManager; - -struct ggml_vk_memory { - void *data = nullptr; - size_t size = 0; - vk::DeviceMemory *primaryMemory = nullptr; - vk::Buffer *primaryBuffer = nullptr; - vk::DeviceMemory *stagingMemory = nullptr; - vk::Buffer *stagingBuffer = nullptr; -}; - -#ifdef __linux__ -__attribute__((constructor)) -static void enable_sam() { - setenv("RADV_PERFTEST", "sam", false); -} -#endif - -static bool ggml_vk_checkPhysicalDeviceFeatures(vk::PhysicalDevice physical_device) { - vk::PhysicalDeviceFeatures availableFeatures; - physical_device.getFeatures(&availableFeatures); - - if (!availableFeatures.shaderInt16) - return false; - - vk::PhysicalDeviceVulkan11Features availableFeatures11; - vk::PhysicalDeviceVulkan12Features availableFeatures12; - - availableFeatures11.pNext = &availableFeatures12; - availableFeatures12.pNext = nullptr; - - vk::PhysicalDeviceFeatures2 features2; - features2.pNext = &availableFeatures11; - - physical_device.getFeatures2(&features2); - - if (!availableFeatures11.uniformAndStorageBuffer16BitAccess || - !availableFeatures11.storageBuffer16BitAccess) { - return false; - } - - if (!availableFeatures12.storageBuffer8BitAccess || - !availableFeatures12.uniformAndStorageBuffer8BitAccess || - !availableFeatures12.shaderFloat16 || - !availableFeatures12.shaderInt8) { - return false; - } - - return true; -} - -static const char * ggml_vk_getVendorName(uint32_t vendorID) { - switch (vendorID) { - case 0x10DE: - return "nvidia"; - case 0x1002: - return "amd"; - case 0x8086: - return "intel"; - default: - return "unknown"; - } -} - -static std::vector ggml_vk_available_devices_internal(size_t memoryRequired) { - std::vector results; - if (!komputeManager()->hasVulkan() || !komputeManager()->hasInstance()) - return results; - - std::vector physical_devices; - try { - physical_devices = komputeManager()->listDevices(); - } catch (vk::SystemError & err) { - std::cerr << __func__ << ": ignoring Vulkan exception: " << err.what() << "\n"; - return results; - } - - uint32_t deviceCount = physical_devices.size(); - if (deviceCount == 0) - return results; - - std::unordered_map count_by_name; - - for (uint32_t i = 0; i < deviceCount; i++) { - const auto & physical_device = physical_devices[i]; - - VkPhysicalDeviceProperties dev_props = physical_device.getProperties(); - VkPhysicalDeviceMemoryProperties memoryProperties = physical_device.getMemoryProperties(); - const uint32_t major = VK_VERSION_MAJOR(dev_props.apiVersion); - const uint32_t minor = VK_VERSION_MINOR(dev_props.apiVersion); - if (major < 1 || minor < 2) - continue; - - if (!ggml_vk_checkPhysicalDeviceFeatures(physical_device)) - continue; - - size_t heapSize = 0; - for (uint32_t j = 0; j < memoryProperties.memoryHeapCount; ++j) { - VkMemoryHeap heap = memoryProperties.memoryHeaps[j]; - if (heap.flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) { - heapSize = heap.size; - break; - } - } - - if (heapSize < memoryRequired) - continue; - - auto ext_props = physical_device.enumerateDeviceExtensionProperties(); - bool has_maintenance4 = false; - - // Check if maintenance4 is supported - for (const auto & properties : ext_props) { - if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { - has_maintenance4 = true; - } - } - - vk::PhysicalDeviceSubgroupProperties subgroup_props; - vk::PhysicalDeviceProperties2 dev_props2; - vk::PhysicalDeviceMaintenance3Properties dev_props3; - vk::PhysicalDeviceMaintenance4Properties dev_props4; - dev_props2.pNext = &dev_props3; - dev_props3.pNext = &subgroup_props; - if (has_maintenance4) { - subgroup_props.pNext = &dev_props4; - } - physical_device.getProperties2(&dev_props2); - - if (subgroup_props.subgroupSize < 32) - continue; - - ggml_vk_device d; - d.index = i; - d.type = dev_props.deviceType; - d.heapSize = heapSize; - d.vendor = strdup(ggml_vk_getVendorName(dev_props.vendorID)); - d.subgroupSize = subgroup_props.subgroupSize; - d.bufferAlignment = dev_props.limits.minStorageBufferOffsetAlignment; - - if (has_maintenance4) { - d.maxAlloc = std::min(dev_props3.maxMemoryAllocationSize, dev_props4.maxBufferSize); - } else { - d.maxAlloc = dev_props3.maxMemoryAllocationSize; - } - - std::string name(dev_props.deviceName); - size_t n_idx = ++count_by_name[name]; - if (n_idx > 1) { - name += " (" + std::to_string(n_idx) + ")"; - } - d.name = strdup(name.c_str()); - - results.push_back(d); - } - - std::stable_sort(results.begin(), results.end(), - [](const ggml_vk_device& lhs, const ggml_vk_device& rhs) -> bool { - if (lhs.type != rhs.type) { - if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return true; - if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_DISCRETE_GPU) return false; - - if (lhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return true; - if (rhs.type == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU) return false; - } - return lhs.heapSize < rhs.heapSize; - } - ); - - return results; -} - -static std::vector& ggml_vk_available_devices() { - static std::vector devices = ggml_vk_available_devices_internal(0); - return devices; -} - -static void ggml_vk_filterByVendor(std::vector& devices, const std::string& targetVendor) { - devices.erase( - std::remove_if(devices.begin(), devices.end(), - [&targetVendor](const ggml_vk_device& device) { - return device.vendor != targetVendor; - }), - devices.end() - ); -} - -static void ggml_vk_filterByName(std::vector& devices, const std::string& targetName) { - devices.erase( - std::remove_if(devices.begin(), devices.end(), - [&targetName](const ggml_vk_device& device) { - return device.name != targetName; - }), - devices.end() - ); -} - -static bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const std::string & name) { - if (name.empty()) - return false; - - auto devices = ggml_vk_available_devices_internal(memoryRequired); - if (name == "amd" || name == "nvidia" || name == "intel") { - ggml_vk_filterByVendor(devices, name); - } else if (name != "gpu") { - ggml_vk_filterByName(devices, name); - } - - if (devices.empty()) - return false; - - *device = devices.front(); - return true; -} - -bool ggml_vk_get_device(ggml_vk_device * device, size_t memoryRequired, const char * name) { - return ggml_vk_get_device(device, memoryRequired, std::string(name)); -} - -bool ggml_vk_has_vulkan() { - return komputeManager()->hasVulkan(); -} - -bool ggml_vk_has_device() { - return komputeManager()->hasDevice(); -} - -ggml_vk_device ggml_vk_current_device() { - if (!komputeManager()->hasDevice()) - return ggml_vk_device(); - - auto devices = ggml_vk_available_devices(); - ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data()); - GGML_ASSERT(!devices.empty()); - return devices.front(); -} - -static -void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t size) { - std::vector descriptorPoolSizes = { - vk::DescriptorPoolSize( - vk::DescriptorType::eStorageBuffer, - 4 * size // Descriptor count is number of possible tensors to pass into an algorithm - ) - }; - - vk::DescriptorPoolCreateInfo descriptorPoolInfo( - vk::DescriptorPoolCreateFlags(), - size, // Max sets - static_cast(descriptorPoolSizes.size()), - descriptorPoolSizes.data()); - - ctx->pool = std::make_shared(); - vk::Result r = komputeManager()->device()->createDescriptorPool( - &descriptorPoolInfo, nullptr, ctx->pool.get()); - if (r != vk::Result::eSuccess) - std::cerr << "Error allocating descriptor pool" << vk::to_string(r); -} - -static -void ggml_vk_free_descriptor_pool(struct ggml_kompute_context * ctx) { - if (ctx->pool) { - komputeManager()->device()->destroy( - *ctx->pool, - (vk::Optional)nullptr); - ctx->pool = nullptr; - } -} - -static -vk::Buffer *ggml_vk_allocate_buffer(size_t size) { - vk::BufferCreateInfo bufferCreateInfo; - bufferCreateInfo.size = size; - bufferCreateInfo.usage = vk::BufferUsageFlagBits::eStorageBuffer | - vk::BufferUsageFlagBits::eTransferSrc | - vk::BufferUsageFlagBits::eTransferDst; - bufferCreateInfo.sharingMode = vk::SharingMode::eExclusive; - - vk::Buffer *vkBuffer = new vk::Buffer; - vk::Result r = komputeManager()->device()->createBuffer(&bufferCreateInfo, nullptr, vkBuffer); - if (r != vk::Result::eSuccess) - std::cerr << "Error allocating buffer " << vk::to_string(r) << std::endl; - return vkBuffer; -} - -static -vk::DeviceMemory *ggml_vk_allocate(size_t size, vk::MemoryPropertyFlags flags, vk::MemoryRequirements requirements, bool *isHostVisible) { - - uint32_t memoryTypeIndex = -1; - bool memoryTypeIndexFound = false; - vk::PhysicalDeviceMemoryProperties memoryProperties = komputeManager()->physicalDevice()->getMemoryProperties(); - for (uint32_t i = 0; i < memoryProperties.memoryTypeCount; i++) { - const vk::MemoryType &memoryType = memoryProperties.memoryTypes[i]; - const vk::MemoryHeap &memoryHeap = memoryProperties.memoryHeaps[memoryType.heapIndex]; - if (memoryHeap.size < size) { - continue; - } - - if (requirements.memoryTypeBits & (1 << i)) { - if (((memoryProperties.memoryTypes[i]).propertyFlags & - flags) == flags) { - memoryTypeIndex = i; - memoryTypeIndexFound = true; - if (isHostVisible && (memoryProperties.memoryTypes[i].propertyFlags & vk::MemoryPropertyFlagBits::eHostVisible)) { - *isHostVisible = true; - } - break; - } - } - } - if (!memoryTypeIndexFound) { - throw std::runtime_error( - "Memory type index for buffer creation not found"); - } - - vk::MemoryAllocateInfo allocInfo; - allocInfo.allocationSize = size; - allocInfo.memoryTypeIndex = memoryTypeIndex; - vk::DeviceMemory *vkDeviceMemory = new vk::DeviceMemory; - vk::Result r = komputeManager()->device()->allocateMemory(&allocInfo, nullptr, vkDeviceMemory); - if (r != vk::Result::eSuccess) { - std::cerr << "Error allocating memory " << vk::to_string(r) << std::endl; - throw std::runtime_error("Error allocating vulkan memory."); - } - return vkDeviceMemory; -} - -static size_t ggml_vk_aligned_offset(ggml_backend_buffer_t buffer, size_t offset) { - size_t minStorageBufferOffsetAlignment = ggml_backend_buffer_get_alignment(buffer); - - // If offset is already aligned, return it directly - if (offset % minStorageBufferOffsetAlignment == 0) { - return offset; - } - - // Otherwise, return the largest multiple of minStorageBufferOffsetAlignment less than offset - return (offset / minStorageBufferOffsetAlignment) * minStorageBufferOffsetAlignment; -} - -static ggml_vk_memory ggml_vk_allocate(size_t size) { - ggml_vk_memory memory; - bool isHostVisible = false; - { - memory.primaryBuffer = ggml_vk_allocate_buffer(size); - vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.primaryBuffer); - vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eDeviceLocal; - memory.primaryMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible); - komputeManager()->device()->bindBufferMemory(*memory.primaryBuffer, *memory.primaryMemory, 0); - if (isHostVisible) { - vk::Result r = komputeManager()->device()->mapMemory(*memory.primaryMemory, 0, size, vk::MemoryMapFlags(), &memory.data); - if (r != vk::Result::eSuccess) - std::cerr << "Error mapping memory" << vk::to_string(r); - } - } - - if (!isHostVisible) { - memory.stagingBuffer = ggml_vk_allocate_buffer(size); - vk::MemoryRequirements memoryRequirements = komputeManager()->device()->getBufferMemoryRequirements(*memory.stagingBuffer); - vk::MemoryPropertyFlags memoryPropertyFlags = vk::MemoryPropertyFlagBits::eHostVisible | - vk::MemoryPropertyFlagBits::eHostCoherent | - vk::MemoryPropertyFlagBits::eHostCached; - memory.stagingMemory = ggml_vk_allocate(size, memoryPropertyFlags, memoryRequirements, &isHostVisible); - komputeManager()->device()->bindBufferMemory(*memory.stagingBuffer, *memory.stagingMemory, 0); - vk::Result r = komputeManager()->device()->mapMemory(*memory.stagingMemory, 0, size, vk::MemoryMapFlags(), &memory.data); - if (r != vk::Result::eSuccess) - std::cerr << "Error mapping memory" << vk::to_string(r); - } - - memory.size = size; - return memory; -} - -static void ggml_vk_free_memory(ggml_vk_memory &memory) -{ - komputeManager()->device()->destroy( - *memory.primaryBuffer, - (vk::Optional)nullptr); - if (memory.stagingBuffer) { - komputeManager()->device()->destroy( - *memory.stagingBuffer, - (vk::Optional)nullptr); - } - komputeManager()->device()->freeMemory( - *memory.primaryMemory, - (vk::Optional)nullptr); - if (memory.stagingMemory) { - komputeManager()->device()->freeMemory( - *memory.stagingMemory, - (vk::Optional)nullptr); - } -} - -static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft); - -static -ggml_vk_memory * ggml_vk_find_tensor(const struct ggml_tensor * t, uint64_t & offset) { - ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; - - // compatibility with ggml-backend - GGML_ASSERT(buffer && buffer->buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name); - - ggml_vk_memory * buf_ctx = static_cast(buffer->context); - - const intptr_t ioffs = intptr_t(t->data) - intptr_t(buf_ctx->data); - - GGML_ASSERT(ioffs >= 0 && ioffs + int64_t(ggml_nbytes(t)) <= int64_t(buffer->size)); - - offset = uint64_t(ioffs); - return buf_ctx; -} - -static -const std::shared_ptr ggml_vk_get_tensor(const struct ggml_tensor * t, uint32_t * alignedOffset = nullptr) { - uint64_t originalOffset = 0; - auto * res = ggml_vk_find_tensor(t, originalOffset); - if (!res) { - static std::shared_ptr nullTensor = nullptr; - return nullTensor; - } - - // Create a tensor whose memory will be composed of our buffers at the correct offset - const size_t nelements = ggml_nelements(t); - size_t nbytes = ggml_nbytes(t); - - size_t vulkanOffset = ggml_vk_aligned_offset(t->buffer, originalOffset); - if (alignedOffset) { - *alignedOffset = originalOffset - vulkanOffset; - nbytes += *alignedOffset; - } - - return komputeManager()->tensor( - t->data, - nelements, - nbytes, kp::Tensor::TensorDataTypes::eFloat, - res->primaryMemory, res->primaryBuffer, - res->stagingMemory, res->stagingBuffer, - vulkanOffset); -} - -static std::vector getSpirvShader(const unsigned char* rawData, size_t size) { - if (size % sizeof(uint32_t) != 0) { - throw std::runtime_error("Invalid size: must be divisible by sizeof(uint32_t)"); - } - - const uint32_t* data_ptr = reinterpret_cast(rawData); - size_t count = size / sizeof(uint32_t); - return std::vector(data_ptr, data_ptr + count); -} - -inline static -uint32_t safe_divide(uint32_t a, uint32_t b) { - if (b <= 1) { - return a; - } - if ((a % b) != 0) { - fprintf(stderr, "((%u %% %u) == %u) != 0\n", a, b, a % b); - GGML_ABORT("safe_divide result would've had remainder"); - } - return a / b; -} - -static void ggml_vk_add( - kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03, - int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03, - int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, - int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13, - int32_t ne0, - int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3 -) { - const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv, - kp::shader_data::op_add_comp_spv_len); - - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - int32_t ne00; - int32_t nb00, nb01, nb02, nb03; - int32_t ne10, ne11, ne12, ne13; - int32_t nb10, nb11, nb12, nb13; - int32_t ne0; - int32_t nb0, nb1, nb2, nb3; - } const pushConsts { - safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, - nb00, nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb10, nb11, nb12, nb13, - ne0, - nb0, nb1, nb2, nb3 - }; - - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) { - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_addrow(kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - uint32_t size, uint32_t row = 0) { - - const static auto spirv = getSpirvShader(kp::shader_data::op_addrow_comp_spv, - kp::shader_data::op_addrow_comp_spv_len); - - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - uint32_t row; - } const pushConsts { - safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), - row - }; - - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts}); - else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({size}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_mul( - kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03, - int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03, - int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, - int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13, - int32_t ne0, - int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3 -) { - const static auto spirv = getSpirvShader(kp::shader_data::op_mul_comp_spv, - kp::shader_data::op_mul_comp_spv_len); - - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - int32_t ne00; - int32_t nb00, nb01, nb02, nb03; - int32_t ne10, ne11, ne12, ne13; - int32_t nb10, nb11, nb12, nb13; - int32_t ne0; - int32_t nb0, nb1, nb2, nb3; - } const pushConsts { - safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, - nb00, nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb10, nb11, nb12, nb13, - ne0, - nb0, nb1, nb2, nb3 - }; - - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) { - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_scale(kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - uint32_t size, float scale) { - const static auto spirv_1 = getSpirvShader( - kp::shader_data::op_scale_comp_spv, kp::shader_data::op_scale_comp_spv_len - ); - const static auto spirv_8 = getSpirvShader( - kp::shader_data::op_scale_8_comp_spv, kp::shader_data::op_scale_8_comp_spv_len - ); - - struct PushConstants { - uint32_t inOff, outOff; - float scale; - } const pushConsts { - safe_divide(inOff, 4), safe_divide(outOff, 4), - scale - }; - - const auto * spirv = &spirv_1; - std::string name(__func__); - if (size % 8 == 0) { - size /= 8; - name += "_8"; - spirv = &spirv_8; - } - - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(name)) { - s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, *spirv, {size}, {}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(name); - s_algo->setTensors({in, out}); - s_algo->setWorkgroup({size}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_xxlu( - const std::vector& spirv, const char * suffix, kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - uint32_t size -) { - struct PushConstants { - uint32_t inOff, outOff; - } const pushConsts { - safe_divide(inOff, 4), safe_divide(outOff, 4), - }; - - auto name = std::string(__func__) + "_" + suffix; - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(name)) { - s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, spirv, {size}, {}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(name); - s_algo->setTensors({in, out}); - s_algo->setWorkgroup({size}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -template -static void ggml_vk_silu(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_silu_comp_spv, - kp::shader_data::op_silu_comp_spv_len); - - ggml_vk_xxlu(spirv, "silu", std::forward(args)...); -} - -template -static void ggml_vk_relu(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_relu_comp_spv, - kp::shader_data::op_relu_comp_spv_len); - - ggml_vk_xxlu(spirv, "relu", std::forward(args)...); -} - -template -static void ggml_vk_gelu(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_gelu_comp_spv, - kp::shader_data::op_gelu_comp_spv_len); - - ggml_vk_xxlu(spirv, "gelu", std::forward(args)...); -} - -static void ggml_vk_soft_max( - kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03, - float scale, float max_bias, float m0, float m1, - uint32_t n_head_log2 -) { - const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv, - kp::shader_data::op_softmax_comp_spv_len); - - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - int32_t ne00, ne01, ne02; - float scale, max_bias, m0, m1; - uint32_t n_head_log2; - int32_t mask; - } pushConsts { - safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, ne01, ne02, - scale, max_bias, m0, m1, - n_head_log2, - bool(inB) - }; - - auto & inB_ = inB ? inB : inA; - - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) { - // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device - const uint32_t local_x = 32; - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({inA, inB_, out}); - s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_norm_( - const std::vector& spirv, const char * suffix, kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - int32_t ne00, int32_t nb01, - int32_t nrows, float epsilon -) { - GGML_ASSERT(nb01%sizeof(float) == 0); - GGML_ASSERT(ne00%sizeof(float) == 0); - - struct PushConstants { - uint32_t inOff, outOff; - uint32_t ne00, nb01; - float eps; - } pushConsts { - safe_divide(inOff, 4), safe_divide(outOff, 4), - (uint32_t)ne00, (uint32_t)nb01, epsilon - }; - - auto name = std::string(__func__) + "_" + suffix; - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(name)) { - s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, spirv, {(uint32_t)nrows}, {}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(name); - s_algo->setTensors({in, out}); - s_algo->setWorkgroup({(uint32_t)nrows}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -template -static void ggml_vk_norm(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_norm_comp_spv, - kp::shader_data::op_norm_comp_spv_len); - - ggml_vk_norm_(spirv, "norm", std::forward(args)...); -} - -template -static void ggml_vk_rms_norm(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_rmsnorm_comp_spv, - kp::shader_data::op_rmsnorm_comp_spv_len); - - ggml_vk_norm_(spirv, "rms", std::forward(args)...); -} - -static void ggml_vk_diag_mask_inf(kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - uint32_t n_past, - int32_t ne00, int32_t ne01, int32_t ne02) { - const static auto spirv = getSpirvShader(kp::shader_data::op_diagmask_comp_spv, - kp::shader_data::op_diagmask_comp_spv_len); - - struct PushConstants { - uint32_t inOff, outOff; - uint32_t n_past; - int32_t ne00, ne01; - } pushConsts { - safe_divide(inOff, 4), safe_divide(outOff, 4), - n_past, - ne00, ne01 - }; - - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne00), unsigned(ne01), unsigned(ne02)}, {}, {pushConsts}); - else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({in, out}); - s_algo->setWorkgroup({unsigned(ne00), unsigned(ne01), unsigned(ne02)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_mul_mat_f16( - kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, - uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, - int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, - uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13, - int32_t ne0, int32_t ne1, - uint32_t r2, uint32_t r3 -) { - const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_f16_comp_spv, - kp::shader_data::op_mul_mat_f16_comp_spv_len); - - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - int32_t ne00, ne01, ne02; - uint32_t nb00, nb01, nb02, nb03; - int32_t ne10, ne11, ne12; - uint32_t nb10, nb11, nb12, nb13; - int32_t ne0, ne1; - uint32_t r2, r3; - } pushConsts { - safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, ne01, ne02, - nb00, nb01, nb02, nb03, - ne10, ne11, ne12, - nb10, nb11, nb12, nb13, - ne0, ne1, - r2, r3 - }; - - const unsigned ny = unsigned((ne11 + 4 - 1)/4); - - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) { - const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2; - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), ny, unsigned(ne12*ne13)}, {local_x}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned(ne01), ny, unsigned(ne12*ne13)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, - uint32_t nb01, uint32_t nb02, - int32_t ne11, int32_t ne12, - uint32_t nb11, uint32_t nb12, - uint32_t nb1, uint32_t nb2) { - const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_mat_f32_comp_spv, - kp::shader_data::op_mul_mat_mat_f32_comp_spv_len); - - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - int32_t ne00, ne01, ne02, ne11, ne12; - uint32_t nb01, nb02; - uint32_t nb11, nb12; - uint32_t nb1, nb2; - } pushConsts { - safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, ne01, ne02, ne11, ne12, - nb01, nb02, nb11, nb12, - nb1, nb2 - }; - - const uint32_t local_x = ggml_vk_current_device().subgroupSize; - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) { - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), - {inA, inB, out}, spirv, - {unsigned(ne01), - unsigned(ne11), - unsigned(std::max(ne12, ne02)) - }, - {local_x}, - {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned(ne01), - unsigned(ne11), - unsigned(std::max(ne12, ne02)), - }); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_mul_mat_impl( - const std::vector& spirv, const char * suffix, uint32_t block_size, kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, - int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, - int32_t ne0, int32_t ne1, - uint32_t nb01, uint32_t nb02, uint32_t nb03, - uint32_t nb11, uint32_t nb12, uint32_t nb13, - uint32_t r2, uint32_t r3 -) { - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - int32_t ne00, ne01, ne02; - int32_t ne10, ne12; - int32_t ne0, ne1; - uint32_t nb01, nb02, nb03; - uint32_t nb11, nb12, nb13; - uint32_t r2, r3; - } pushConsts { - safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, ne01, ne02, - ne10, ne12, - ne0, ne1, - nb01, nb02, nb03, - nb11, nb12, nb13, - r2, r3 - }; - - auto name = std::string(__func__) + "_" + suffix; - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(name)) { - const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8; - s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(name); - s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -template -static void ggml_vk_mul_mat_q4_0(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_0_comp_spv, - kp::shader_data::op_mul_mat_q4_0_comp_spv_len); - - ggml_vk_mul_mat_impl(spirv, "q4_0", 1/*We access blocks unaligned*/, std::forward(args)...); -} - -template -static void ggml_vk_mul_mat_q4_1(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_1_comp_spv, - kp::shader_data::op_mul_mat_q4_1_comp_spv_len); - - ggml_vk_mul_mat_impl(spirv, "q4_1", 1/*We access blocks unaligned*/, std::forward(args)...); -} - -template -static void ggml_vk_mul_mat_q8_0(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q8_0_comp_spv, - kp::shader_data::op_mul_mat_q8_0_comp_spv_len); - - ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward(args)...); -} - -static void ggml_vk_mul_mat_q4_k( - kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, - int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, - int32_t ne0, int32_t ne1, - uint32_t nb01, uint32_t nb02, uint32_t nb03, - uint32_t nb11, uint32_t nb12, uint32_t nb13, - uint32_t r2, uint32_t r3 -) { - const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv, - kp::shader_data::op_mul_mat_q4_k_comp_spv_len); - - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12; - uint32_t nb01, nb02, nb03, nb11, nb12, nb13; - uint32_t r2, r3; - } pushConsts { - inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, ne10, ne0, ne1, ne01, ne02, ne12, - nb01, nb02, nb03, nb11, nb12, nb13, - r2, r3 - }; - - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) { - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_mul_mat_q6_k( - kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, - int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, - int32_t ne0, int32_t ne1, - uint32_t nb01, uint32_t nb02, uint32_t nb03, - uint32_t nb11, uint32_t nb12, uint32_t nb13, - uint32_t r2, uint32_t r3 -) { - const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv, - kp::shader_data::op_mul_mat_q6_k_comp_spv_len); - - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12; - uint32_t nb01, nb02, nb03, nb11, nb12, nb13; - uint32_t r2, r3; - } pushConsts { - inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, ne10, ne0, ne1, ne01, ne02, ne12, - nb01, nb02, nb03, nb11, nb12, nb13, - r2, r3 - }; - - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) { - const uint32_t local_x = 2; - const uint32_t local_y = ggml_vk_current_device().subgroupSize; - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(__func__); - s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_get_rows( - const std::vector& spirv, - const char * suffix, - unsigned element_size, unsigned qk, - kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t nb01, int32_t nb1, - uint32_t size -) { - GGML_ASSERT(nb01%element_size == 0); - GGML_ASSERT(nb1%sizeof(float) == 0); - if (qk) GGML_ASSERT(ne00%qk == 0); - - struct PushConstants { - uint32_t inAOff, inBOff, outOff; - int32_t ne00, nb01, nb1; - } pushConsts { - safe_divide(inAOff, element_size), safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, nb01, nb1 - }; - - auto name = std::string(__func__) + "_" + suffix; - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(name)) { - s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts}); - } else { - s_algo = komputeManager()->getAlgorithm(name); - s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({size}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -template -static void ggml_vk_get_rows_f32(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv, - kp::shader_data::op_getrows_f32_comp_spv_len); - - ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward(args)...); -} - -template -static void ggml_vk_get_rows_f16(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv, - kp::shader_data::op_getrows_f16_comp_spv_len); - - ggml_vk_get_rows(spirv, "f16", sizeof(half), 0, std::forward(args)...); -} - -template -static void ggml_vk_get_rows_q4_0(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_0_comp_spv, - kp::shader_data::op_getrows_q4_0_comp_spv_len); - - ggml_vk_get_rows(spirv, "q4_0", 1/*We access blocks unaligned*/, QK4_0, std::forward(args)...); -} - -template -static void ggml_vk_get_rows_q4_1(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q4_1_comp_spv, - kp::shader_data::op_getrows_q4_1_comp_spv_len); - - ggml_vk_get_rows(spirv, "q4_1", 1/*We access blocks unaligned*/, QK4_1, std::forward(args)...); -} - -template -static void ggml_vk_get_rows_q6_k(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_q6_k_comp_spv, - kp::shader_data::op_getrows_q6_k_comp_spv_len); - ggml_vk_get_rows(spirv, "q6_k", 1/*We access blocks unaligned*/, QK_NL, std::forward(args)...); -} - -static void ggml_vk_rope( - kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& inC, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff, - ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig, - float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow, - int32_t ne01, int32_t ne02, int32_t ne03, - uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, - int32_t ne0, - uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3 -) { - GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32); - - static const auto spirv_norm_f16 = getSpirvShader( - kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len - ); - static const auto spirv_norm_f32 = getSpirvShader( - kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len - ); - static const auto spirv_neox_f16 = getSpirvShader( - kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len - ); - static const auto spirv_neox_f32 = getSpirvShader( - kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len - ); - - int type_size = src0t == GGML_TYPE_F16 ? 2 : 4; - - GGML_ASSERT(nb03 % type_size == 0); - GGML_ASSERT(nb02 % type_size == 0); - GGML_ASSERT(nb01 % type_size == 0); - GGML_ASSERT(nb00 % type_size == 0); - GGML_ASSERT(nb3 % type_size == 0); - GGML_ASSERT(nb2 % type_size == 0); - GGML_ASSERT(nb1 % type_size == 0); - GGML_ASSERT(nb0 % type_size == 0); - - struct PushConstants { - uint32_t inAOff, inBOff, inCOff, outOff; - int32_t n_dims, mode, n_ctx_orig; - float freq_base, freq_scale; - bool has_freq_factors; - float ext_factor, attn_factor, beta_fast, beta_slow; - uint32_t nb00, nb01, nb02, nb03; - int32_t ne0; - uint32_t nb0, nb1, nb2, nb3; - } pushConsts { - safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size), - n_dims, mode, n_ctx_orig, - freq_base, freq_scale, - has_freq_factors, - ext_factor, attn_factor, beta_fast, beta_slow, - nb00, nb01, nb02, nb03, - ne0, - nb0, nb1, nb2, nb3 - }; - - auto & inC_ = inC ? inC : inA; - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - const bool is_f16 = src0t == GGML_TYPE_F16; - - auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32"); - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(name)) { - auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32; - s_algo = komputeManager()->algorithm( - name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv, - {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts} - ); - } else { - s_algo = komputeManager()->getAlgorithm(name); - s_algo->setTensors({inA, inB, inC_, out}); - s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -static void ggml_vk_cpy( - const std::vector& spirv, - uint32_t in_element_size, uint32_t out_element_size, - kp::Sequence& seq, - const std::shared_ptr& in, - const std::shared_ptr& out, - uint32_t inOff, uint32_t outOff, - int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03, - uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, - int32_t ne0, int32_t ne1, int32_t ne2, - uint32_t nb0, uint32_t nb1, uint32_t nb2, uint32_t nb3 -) { - struct PushConstants { - uint32_t inOff, outOff; - int32_t ne00, ne01, ne02; - uint32_t nb00, nb01, nb02, nb03; - int32_t ne0, ne1, ne2; - uint32_t nb0, nb1, nb2, nb3; - } pushConsts { - safe_divide(inOff, in_element_size), safe_divide(outOff, out_element_size), - ne00, ne01, ne02, - nb00, nb01, nb02, nb03, - ne0, ne1, ne2, - nb0, nb1, nb2, nb3 - }; - - std::string name = std::string(__func__) - + "_i_" + std::to_string(in_element_size) - + "_o_" + std::to_string(out_element_size); - std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(name)) - s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}); - else { - s_algo = komputeManager()->getAlgorithm(name); - s_algo->setTensors({in, out}); - s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); - s_algo->setPushConstants({pushConsts}); - s_algo->updateDescriptors(s_kompute_context->pool.get()); - } - seq.record(s_algo); -} - -template -static void ggml_vk_cpy_f32_f16(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f16_comp_spv, - kp::shader_data::op_cpy_f32_f16_comp_spv_len); - ggml_vk_cpy(spirv, 4, 2, std::forward(args)...); -} - -template -static void ggml_vk_cpy_f32_f32(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f32_f32_comp_spv, - kp::shader_data::op_cpy_f32_f32_comp_spv_len); - ggml_vk_cpy(spirv, 4, 4, std::forward(args)...); -} - -template -static void ggml_vk_cpy_f16_f16(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f16_comp_spv, - kp::shader_data::op_cpy_f16_f16_comp_spv_len); - ggml_vk_cpy(spirv, 2, 2, std::forward(args)...); -} - -template -static void ggml_vk_cpy_f16_f32(Args&&... args) { - const static auto spirv = getSpirvShader(kp::shader_data::op_cpy_f16_f32_comp_spv, - kp::shader_data::op_cpy_f16_f32_comp_spv_len); - ggml_vk_cpy(spirv, 2, 4, std::forward(args)...); -} - -static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - int64_t n = ggml_nelements(op); - switch (op->op) { - case GGML_OP_UNARY: - if (n % 4 != 0) return false; - switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_GELU: - if (n % 8 != 0) return false; - // fall through - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_SILU: - return ggml_is_contiguous(op->src[0]); - default: - ; - } - break; - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - case GGML_OP_ADD: - case GGML_OP_MUL: - case GGML_OP_SCALE: - case GGML_OP_SOFT_MAX: - case GGML_OP_RMS_NORM: - case GGML_OP_NORM: - return true; - case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return true; - } - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - switch (op->src[0]->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - break; - default: - return false; - } - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - break; - default: - return false; - } - return true; - case GGML_OP_DIAG_MASK_INF: - return op->ne[3] == 1; - case GGML_OP_GET_ROWS: - switch (op->src[0]->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q6_K: - return op->ne[2] == 1 && op->ne[3] == 1; - default: - ; - } - return false; - case GGML_OP_MUL_MAT: - if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1])) - return false; - - switch (op->src[0]->type) { - case GGML_TYPE_F32: - return op->ne[3] == 1; - case GGML_TYPE_Q6_K: - case GGML_TYPE_F16: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q4_K: - return true; - default: - ; - } - default: - ; - } - return false; - - GGML_UNUSED(dev); -} - -static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) { - const int n_seq = 8; - - // FIXME: Figure out if we can somehow optimize the size of the pool... right now we're setting - // it to the size of the graph, but I think it can be made smaller? - ggml_vk_allocate_descriptor_pool(ctx, gf->n_nodes); - - std::vector> sequences(n_seq); - - for (auto& sequence : sequences) { - sequence = komputeManager()->sequence(); - } - for (int seq_idx = 0; seq_idx < n_seq; ++seq_idx) { - const int n_nodes_per_seq = (gf->n_nodes + n_seq - 1) / n_seq; - - auto& seq = *sequences[seq_idx]; - - const int node_start = (seq_idx + 0) * n_nodes_per_seq; - const int node_end = std::min((seq_idx == n_seq - 1) ? gf->n_nodes : (seq_idx + 1) * n_nodes_per_seq, gf->n_nodes); - - bool any_commands_recorded = false; - - for (int i = node_start; i < node_end; ++i) { - struct ggml_tensor * src0 = gf->nodes[i]->src[0]; - struct ggml_tensor * src1 = gf->nodes[i]->src[1]; - struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2); - struct ggml_tensor * dst = gf->nodes[i]; - GGML_ASSERT(dst->data != nullptr); - - if (ggml_is_empty(dst)) { - continue; - } - - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - continue; // noop -> next node - default: - break; - } - - any_commands_recorded = true; - - const int32_t ne00 = src0 ? src0->ne[0] : 0; - const int32_t ne01 = src0 ? src0->ne[1] : 0; - const int32_t ne02 = src0 ? src0->ne[2] : 0; - const int32_t ne03 = src0 ? src0->ne[3] : 0; - - const uint32_t nb00 = src0 ? src0->nb[0] : 0; - const uint32_t nb01 = src0 ? src0->nb[1] : 0; - const uint32_t nb02 = src0 ? src0->nb[2] : 0; - const uint32_t nb03 = src0 ? src0->nb[3] : 0; - - const int32_t ne10 = src1 ? src1->ne[0] : 0; - const int32_t ne11 = src1 ? src1->ne[1] : 0; - const int32_t ne12 = src1 ? src1->ne[2] : 0; - const int32_t ne13 = src1 ? src1->ne[3] : 0; - - const uint32_t nb10 = src1 ? src1->nb[0] : 0; - const uint32_t nb11 = src1 ? src1->nb[1] : 0; - const uint32_t nb12 = src1 ? src1->nb[2] : 0; - const uint32_t nb13 = src1 ? src1->nb[3] : 0; - - const int32_t ne0 = dst ? dst->ne[0] : 0; - const int32_t ne1 = dst ? dst->ne[1] : 0; - const int32_t ne2 = dst ? dst->ne[2] : 0; -// const int32_t ne3 = dst ? dst->ne[3] : 0; - - const uint32_t nb0 = dst ? dst->nb[0] : 0; - const uint32_t nb1 = dst ? dst->nb[1] : 0; - const uint32_t nb2 = dst ? dst->nb[2] : 0; - const uint32_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - const static std::shared_ptr nullTensor = nullptr; - uint32_t off_src0 = 0; - uint32_t off_src1 = 0; - uint32_t off_src2 = 0; - uint32_t off_dst = 0; - const std::shared_ptr& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor; - const std::shared_ptr& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor; - const std::shared_ptr& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor; - const std::shared_ptr& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor; - - switch (dst->op) { - case GGML_OP_ADD: - { - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - // src1 is a row - ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00); - } else { - ggml_vk_add( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne03, - nb00, nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb10, nb11, nb12, nb13, - ne0, - nb0, nb1, nb2, nb3 - ); - } - } break; - case GGML_OP_MUL: - { - ggml_vk_mul( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne03, - nb00, nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb10, nb11, nb12, nb13, - ne0, - nb0, nb1, nb2, nb3 - ); - } break; - case GGML_OP_SCALE: - { - float scale; memcpy(&scale, dst->op_params, sizeof(float)); - - ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst), scale); - } break; - case GGML_OP_UNARY: - { - int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - switch (ggml_get_unary_op(gf->nodes[i])) { - case GGML_UNARY_OP_SILU: - { - ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4); - } break; - case GGML_UNARY_OP_RELU: - { - ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4); - } break; - case GGML_UNARY_OP_GELU: - { - GGML_ASSERT(n % 8 == 0); - ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8); - } break; - default: - { - fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } - } break; - case GGML_OP_SOFT_MAX: - { - float scale; - float max_bias; - - memcpy(&scale, (float *)dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float)); - -#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") - GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); - - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2); - } break; - case GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((int32_t *)(dst->op_params))[0]; - ggml_vk_diag_mask_inf(seq, id_src0, id_dst, off_src0, off_dst, n_past, ne00, ne01, ne02); - } break; - case GGML_OP_NORM: - { - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps); - } break; - case GGML_OP_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps); - } break; - case GGML_OP_MUL_MAT: - { - GGML_ASSERT(ne00 == ne10); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - const uint32_t r2 = ne12/ne02; - const uint32_t r3 = ne13/ne03; - - if (src1t != GGML_TYPE_F32) { - fprintf(stderr, "%s: %s: Unsupported src1 type: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t); - goto not_implemented; - } - - if (ggml_is_transposed(src0) || - ggml_is_transposed(src1)) { - fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t); - goto not_implemented; - } - - switch (src0t) { - case GGML_TYPE_F32: - ggml_vk_mul_mat_mat_f32( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, nb01, nb02, ne11, ne12, nb11, nb12, nb1, nb2 - ); - break; - case GGML_TYPE_F16: - ggml_vk_mul_mat_f16( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, nb00, nb01, nb02, nb03, - ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, - ne0, ne1, r2, r3 - ); - break; - case GGML_TYPE_Q8_0: - ggml_vk_mul_mat_q8_0( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, - nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 - ); - break; - case GGML_TYPE_Q4_0: - ggml_vk_mul_mat_q4_0( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, - nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 - ); - break; - case GGML_TYPE_Q4_1: - ggml_vk_mul_mat_q4_1( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, - nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 - ); - break; - case GGML_TYPE_Q4_K: - ggml_vk_mul_mat_q4_k( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, - nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 - ); - break; - case GGML_TYPE_Q6_K: - ggml_vk_mul_mat_q6_k( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, - nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 - ); - break; - default: { - fprintf(stderr, "%s: %s: Unsupported quantization: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t); - goto not_implemented; - } - } - - } break; - case GGML_OP_GET_ROWS: - { - if (src0t == GGML_TYPE_F32) { - ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); - } else if (src0t == GGML_TYPE_F16) { - ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); - } else if (src0t == GGML_TYPE_Q4_0) { - ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); - } else if (src0t == GGML_TYPE_Q4_1) { - ggml_vk_get_rows_q4_1(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); - } else if (src0t == GGML_TYPE_Q6_K) { - ggml_vk_get_rows_q6_k(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); - } else { - fprintf(stderr, "%s: %s: Unsupported quantization: %u\n", __func__, ggml_op_name(dst->op), src0t); - goto not_implemented; - } - } break; - case GGML_OP_ROPE: - { - GGML_ASSERT(ne10 == ne02); - GGML_ASSERT(src0t == dstt); - // const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - - const bool has_freq_factors = dst->src[2] != nullptr; - - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - ggml_vk_rope( - seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig, - freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow, - ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3 - ); - } break; - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - { - switch (src0t) { - case GGML_TYPE_F32: - { - switch (dstt) { - case GGML_TYPE_F16: ggml_vk_cpy_f32_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break; - case GGML_TYPE_F32: ggml_vk_cpy_f32_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break; - default: goto not_implemented; - } - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F16: ggml_vk_cpy_f16_f16(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break; - case GGML_TYPE_F32: ggml_vk_cpy_f16_f32(seq, id_src0, id_dst, off_src0, off_dst, ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, ne1, ne2, nb0, nb1, nb2, nb3); break; - default: goto not_implemented; - } break; - default: goto not_implemented; - } - } - } break; - default: goto not_implemented; - } - continue; - not_implemented: {} - fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - //GGML_ABORT("fatal error"); - } - - // Evaluate sequence - if (any_commands_recorded) { - seq.evalAsync(); - } - } - - // Wait for all sequences to finish - for (auto& sequence : sequences) { - if (sequence->isRunning()) - sequence->evalAwait(); - } - - ggml_vk_free_descriptor_pool(ctx); -} - -template<> -kp::Tensor::TensorDataTypes -kp::TensorT::dataType() -{ - return TensorDataTypes::eFloat; -} - -template<> -kp::Tensor::TensorDataTypes -kp::TensorT::dataType() -{ - return TensorDataTypes::eUnsignedInt; -} - -//////////////////////////////////////////////////////////////////////////////// - -// backend interface - -struct ggml_backend_kompute_buffer_type_context { - int device; - int device_ref = 0; - uint64_t buffer_alignment; - uint64_t max_alloc; - std::string name; - - ggml_backend_kompute_buffer_type_context(int device, uint64_t buffer_alignment, uint64_t max_alloc) - : device(device), buffer_alignment(buffer_alignment), max_alloc(max_alloc), name(ggml_kompute_format_name(device)) {} -}; - -static void ggml_backend_kompute_device_ref(ggml_backend_buffer_type_t buft) { - auto * ctx = static_cast(buft->context); - - if (!ctx->device_ref) { - komputeManager()->initializeDevice( - ctx->device, {}, { - "VK_KHR_shader_float16_int8", "VK_KHR_8bit_storage", - "VK_KHR_16bit_storage", "VK_KHR_shader_non_semantic_info" - } - ); - } - - assert(ggml_vk_has_device()); - ctx->device_ref++; -} - -static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) { - auto * ctx = static_cast(buft->context); - - assert(ctx->device_ref > 0); - - ctx->device_ref--; - - if (!ctx->device_ref) { - komputeManager.destroy(); - } -} - -static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) { - auto * memory = (ggml_vk_memory *)buffer->context; - if (ggml_vk_has_device()) { - ggml_vk_free_memory(*memory); - } - delete memory; -} - -static void * ggml_backend_kompute_buffer_get_base(ggml_backend_buffer_t buffer) { - return ((ggml_vk_memory *)buffer->context)->data; -} - -static void ggml_backend_kompute_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_UNUSED(buffer); - - const auto res = ggml_vk_get_tensor(tensor); - GGML_ASSERT(res); - - memcpy((char *)tensor->data + offset, data, size); - - komputeManager()->sequence()->eval({res}); -} - -static void ggml_backend_kompute_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_UNUSED(buffer); - - const auto res = ggml_vk_get_tensor(tensor); - GGML_ASSERT(res); - - komputeManager()->sequence()->eval({res}); - - memcpy(data, (const char *)tensor->data + offset, size); -} - -static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - auto * memory = (ggml_vk_memory *)buffer->context; - memset(memory->data, value, buffer->size); - - if (memory->stagingBuffer) - komputeManager()->sequence()->eval(memory->primaryBuffer, memory->stagingBuffer, memory->size); -} - -static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = { - /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer, - /* .get_base = */ ggml_backend_kompute_buffer_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ NULL, - /* .set_tensor = */ ggml_backend_kompute_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_kompute_buffer_get_tensor, - /* .cpy_tensor = */ NULL, - /* .clear = */ ggml_backend_kompute_buffer_clear, - /* .reset = */ NULL, -}; - -// default buffer type - -static const char * ggml_backend_kompute_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - auto * ctx = static_cast(buft->context); - return ctx->name.c_str(); -} - -static ggml_backend_buffer_t ggml_backend_kompute_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_kompute_device_ref(buft); - auto * ctx = new ggml_vk_memory(ggml_vk_allocate(size)); - return ggml_backend_buffer_init(buft, ggml_backend_kompute_buffer_i, ctx, size); -} - -static size_t ggml_backend_kompute_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - auto * ctx = static_cast(buft->context); - return ctx->buffer_alignment; -} - -static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { - auto * ctx = static_cast(buft->context); - return ctx->max_alloc; -} - -static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = { - /* .get_name = */ ggml_backend_kompute_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_kompute_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_kompute_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ NULL, -}; - -ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) { - static std::mutex mutex; - std::lock_guard lock(mutex); - - auto devices = ggml_vk_available_devices(); - int32_t device_count = (int32_t) devices.size(); - GGML_ASSERT(device < device_count); - GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES); - - static ggml_backend_buffer_type - ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES]; - - static bool ggml_backend_kompute_buffer_type_initialized = false; - - if (!ggml_backend_kompute_buffer_type_initialized) { - for (int32_t i = 0; i < device_count; i++) { - ggml_backend_kompute_buffer_types[i] = { - /* .iface = */ ggml_backend_kompute_buffer_type_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i), - /* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc }, - }; - } - ggml_backend_kompute_buffer_type_initialized = true; - } - - return &ggml_backend_kompute_buffer_types[device]; -} - -// backend - -static const char * ggml_backend_kompute_name(ggml_backend_t backend) { - auto * ctx = static_cast(backend->context); - return ctx->name.c_str(); -} - -static void ggml_backend_kompute_free(ggml_backend_t backend) { - auto * ctx = static_cast(backend->context); - - assert(ctx == s_kompute_context); - s_kompute_context = nullptr; - if (ctx != nullptr) { - delete ctx; - } - - delete backend; -} - -static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - auto * ctx = static_cast(backend->context); - ggml_vk_graph_compute(ctx, cgraph); - return GGML_STATUS_SUCCESS; -} - -static struct ggml_backend_i kompute_backend_i = { - /* .get_name = */ ggml_backend_kompute_name, - /* .free = */ ggml_backend_kompute_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_kompute_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, -}; - -static ggml_guid_t ggml_backend_kompute_guid() { - static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, 0xfb, 0x35, 0xfa, 0x9b, 0x18, 0x31, 0x1d, 0xca }; - return &guid; -} - -ggml_backend_t ggml_backend_kompute_init(int device) { - GGML_ASSERT(s_kompute_context == nullptr); - s_kompute_context = new ggml_kompute_context(device); - - ggml_backend_t kompute_backend = new ggml_backend { - /* .guid = */ ggml_backend_kompute_guid(), - /* .interface = */ kompute_backend_i, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device), - /* .context = */ s_kompute_context, - }; - - return kompute_backend; -} - -bool ggml_backend_is_kompute(ggml_backend_t backend) { - return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid()); -} - -static size_t ggml_backend_kompute_get_device_count() { - auto devices = ggml_vk_available_devices(); - return devices.size(); -} - -static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) { - auto devices = ggml_vk_available_devices(); - GGML_ASSERT((size_t) device < devices.size()); - snprintf(description, description_size, "%s", devices[device].name); -} - -static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) { - auto devices = ggml_vk_available_devices(); - GGML_ASSERT((size_t) device < devices.size()); - *total = devices[device].heapSize; - *free = devices[device].heapSize; -} - -////////////////////////// - -struct ggml_backend_kompute_device_context { - int device; - std::string name; - std::string description; -}; - -static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) { - ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; - return ctx->name.c_str(); -} - -static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) { - ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; - return ctx->description.c_str(); -} - -static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; - ggml_backend_kompute_get_device_memory(ctx->device, free, total); -} - -static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) { - ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; - return ggml_backend_kompute_buffer_type(ctx->device); -} - -static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) { - return false; - } - - ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; - ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context; - - return buft_ctx->device == ctx->device; -} - -static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) { - GGML_UNUSED(dev); - return GGML_BACKEND_DEVICE_TYPE_GPU; -} - -static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { - props->name = ggml_backend_kompute_device_get_name(dev); - props->description = ggml_backend_kompute_device_get_description(dev); - props->type = ggml_backend_kompute_device_get_type(dev); - ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total); - props->caps = { - /* async = */ false, - /* host_buffer = */ false, - /* .buffer_from_host_ptr = */ false, - /* events = */ false, - }; -} - -static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) { - GGML_UNUSED(params); - ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; - return ggml_backend_kompute_init(ctx->device); -} - -static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; - - return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || - (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); - - GGML_UNUSED(dev); -} - -static const struct ggml_backend_device_i ggml_backend_kompute_device_i = { - /* .get_name = */ ggml_backend_kompute_device_get_name, - /* .get_description = */ ggml_backend_kompute_device_get_description, - /* .get_memory = */ ggml_backend_kompute_device_get_memory, - /* .get_type = */ ggml_backend_kompute_device_get_type, - /* .get_props = */ ggml_backend_kompute_device_get_props, - /* .init_backend = */ ggml_backend_kompute_device_init, - /* .get_buffer_type = */ ggml_backend_kompute_device_get_buffer_type, - /* .get_host_buffer_type = */ NULL, - /* .buffer_from_host_ptr = */ NULL, - /* .supports_op = */ ggml_backend_kompute_device_supports_op, - /* .supports_buft = */ ggml_backend_kompute_device_supports_buft, - /* .offload_op = */ ggml_backend_kompute_device_offload_op, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, -}; - -static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) { - GGML_UNUSED(reg); - return "Kompute"; -} - -static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) { - GGML_UNUSED(reg); - return ggml_backend_kompute_get_device_count(); -} - -static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) { - static std::vector devices; - - static bool initialized = false; - - { - static std::mutex mutex; - std::lock_guard lock(mutex); - if (!initialized) { - for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) { - ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context; - char desc[256]; - ggml_backend_kompute_get_device_description(i, desc, sizeof(desc)); - ctx->device = i; - ctx->name = "Kompute" + std::to_string(i); - ctx->description = desc; - devices.push_back(new ggml_backend_device { - /* .iface = */ ggml_backend_kompute_device_i, - /* .reg = */ reg, - /* .context = */ ctx, - }); - } - initialized = true; - } - } - - GGML_ASSERT(device < devices.size()); - return devices[device]; -} - -static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = { - /* .get_name = */ ggml_backend_kompute_reg_get_name, - /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count, - /* .get_device = */ ggml_backend_kompute_reg_get_device, - /* .get_proc_address = */ NULL, -}; - -ggml_backend_reg_t ggml_backend_kompute_reg() { - static ggml_backend_reg reg = { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_kompute_reg_i, - /* .context = */ nullptr, - }; - - return ® -} - -GGML_BACKEND_DL_IMPL(ggml_backend_kompute_reg) diff --git a/ggml/src/ggml-kompute/kompute b/ggml/src/ggml-kompute/kompute deleted file mode 160000 index 4565194ed7..0000000000 --- a/ggml/src/ggml-kompute/kompute +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 4565194ed7c32d1d2efa32ceab4d3c6cae006306 diff --git a/ggml/src/ggml-kompute/kompute-shaders/common.comp b/ggml/src/ggml-kompute/kompute-shaders/common.comp deleted file mode 100644 index dbe4cf804e..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/common.comp +++ /dev/null @@ -1,112 +0,0 @@ -#extension GL_EXT_shader_16bit_storage: require -#extension GL_EXT_shader_8bit_storage: require -#extension GL_EXT_shader_explicit_arithmetic_types_float16: require -#extension GL_EXT_shader_explicit_arithmetic_types_int8: require -#extension GL_EXT_shader_explicit_arithmetic_types_int16: require -#extension GL_EXT_shader_explicit_arithmetic_types_int64: require -#extension GL_EXT_control_flow_attributes: enable -#extension GL_KHR_shader_subgroup_arithmetic : require -#extension GL_EXT_debug_printf : enable - -#define QK4_0 32 -#define QK4_1 32 - -#define GELU_COEF_A 0.044715 -#define SQRT_2_OVER_PI 0.79788456080286535587989211986876 -#define TWOPI_F 6.283185307179586f - -#define QK_K 256 -#define K_SCALE_SIZE 12 - -#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx]) -#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx) -#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx]) -#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx) - -#define sizeof_block_q4_0 0x12 -struct block_q4_0 { - float16_t d; - uint8_t qs[QK4_0 / 2]; -}; -mat4 dequantize_q4_0(const block_q4_0 xb, uint il) { - const float d1 = il != 0 ? (xb.d / 16.f) : xb.d; - const float d2 = d1 / 256.f; - const float md = -8.f * xb.d; - const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F); - const uint16_t mask1 = mask0 << 8; - - mat4 reg; - for (int i=0;i<8;i++) { - uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]); - reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md; - reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md; - } - return reg; -} - -#define sizeof_block_q4_1 0x14 -struct block_q4_1 { - float16_t d; - float16_t m; - uint8_t qs[QK4_1 / 2]; -}; -mat4 dequantize_q4_1(const block_q4_1 xb, uint il) { - const float d1 = il != 0 ? (xb.d / 16.f) : xb.d; - const float d2 = d1 / 256.f; - const float m = xb.m; - const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F); - const uint16_t mask1 = mask0 << 8; - - mat4 reg; - for (int i=0;i<8;i++) { - uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]); - reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m; - reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m; - } - return reg; -} - -#define sizeof_block_q4_k 144 -struct block_q4_k { - float16_t d; - float16_t dmin; - uint8_t scales[K_SCALE_SIZE]; - uint8_t qs[QK_K/2]; -}; - -#define sizeof_block_q6_k 210 -struct block_q6_k { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - float16_t d; // super-block scale -}; -mat4 dequantize_q6_k(const block_q6_k xb, uint il) { - const float16_t d_all = xb.d; - - const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1); - const uint qhIndex = 32*(il/8) + 16*(il&1); - float16_t sc = xb.scales[(il%2) + 2 * ((il/2))]; - il = (il/2) & 3; - - const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3); - const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F); - const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f); - const float16_t ml = float16_t(d_all * sc * 32.f); - const float16_t dl = float16_t(d_all * sc * coef); - mat4 reg; - for (int i = 0; i < 16; ++i) { - const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2)) - : ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4)); - reg[i/4][i%4] = dl * q - ml; - } - return reg; -} - - -#define QK8_0 32 -// struct block_q8_0 { -// float16_t d; // delta -// int8_t qs[QK8_0]; // quants -// }; -#define sizeof_block_q8_0 34 diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_add.comp b/ggml/src/ggml-kompute/kompute-shaders/op_add.comp deleted file mode 100644 index b7b76a79db..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +++ /dev/null @@ -1,58 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1024) in; - -layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; }; -layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; - -layout(push_constant) uniform PushConstants { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int nb00; - int nb01; - int nb02; - int nb03; - int ne10; - int ne11; - int ne12; - int ne13; - int nb10; - int nb11; - int nb12; - int nb13; - int ne0; - int nb0; - int nb1; - int nb2; - int nb3; - //int offs; // TODO: needed for GGML_OP_ACC, see metal code -} pcs; - -// general-purpose kernel for addition of two tensors -// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 -// cons: not very efficient -void main() { - const uint i03 = gl_WorkGroupID.z; - const uint i02 = gl_WorkGroupID.y; - const uint i01 = gl_WorkGroupID.x; - - const uint i13 = i03 % pcs.ne13; - const uint i12 = i02 % pcs.ne12; - const uint i11 = i01 % pcs.ne11; - - int offs = 0; // TMP (see above) - - uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4); - uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 ) / 4); - uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + offs) / 4); - - for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) { - const uint i10 = i0 % pcs.ne10; - out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10]; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp b/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp deleted file mode 100644 index 2376a6b8f0..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +++ /dev/null @@ -1,25 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1) in; - -layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; }; -layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; - -layout(push_constant) uniform PushConstants { - uint inAOff; - uint inBOff; - uint outOff; - uint row; -} pcs; - -void main() { - const uint baseIndex = gl_WorkGroupID.x * 4; - - for (uint x = 0; x < 4; x++) { - const uint i = baseIndex + x; - out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff]; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp deleted file mode 100644 index d57247d2dc..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +++ /dev/null @@ -1,52 +0,0 @@ -#version 450 - -#include "common.comp" - -#define IN_TYPE float16_t -#define IN_TYPE_SIZE 2 -#define OUT_TYPE float16_t -#define OUT_TYPE_SIZE 2 - -layout(local_size_x = 1024) in; - -layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; }; -layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; }; - -layout (push_constant) uniform parameter { - uint inOff; - uint outOff; - int ne00; - int ne01; - int ne02; - uint nb00; - uint nb01; - uint nb02; - uint nb03; - int ne0; - int ne1; - int ne2; - uint nb0; - uint nb1; - uint nb2; - uint nb3; -} pcs; - -void main() { - const uint i03 = gl_WorkGroupID.z; - const uint i02 = gl_WorkGroupID.y; - const uint i01 = gl_WorkGroupID.x; - - const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00; - - const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0); - const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0); - const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0; - const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0); - - const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_ - - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_ - out_[dst_data+i00] = OUT_TYPE(in_[src]); - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp deleted file mode 100644 index b568bcd7b2..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +++ /dev/null @@ -1,52 +0,0 @@ -#version 450 - -#include "common.comp" - -#define IN_TYPE float16_t -#define IN_TYPE_SIZE 2 -#define OUT_TYPE float -#define OUT_TYPE_SIZE 4 - -layout(local_size_x = 1024) in; - -layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; }; -layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; }; - -layout (push_constant) uniform parameter { - uint inOff; - uint outOff; - int ne00; - int ne01; - int ne02; - uint nb00; - uint nb01; - uint nb02; - uint nb03; - int ne0; - int ne1; - int ne2; - uint nb0; - uint nb1; - uint nb2; - uint nb3; -} pcs; - -void main() { - const uint i03 = gl_WorkGroupID.z; - const uint i02 = gl_WorkGroupID.y; - const uint i01 = gl_WorkGroupID.x; - - const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00; - - const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0); - const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0); - const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0; - const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0); - - const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_ - - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_ - out_[dst_data+i00] = OUT_TYPE(in_[src]); - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp deleted file mode 100644 index 99b2283430..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +++ /dev/null @@ -1,52 +0,0 @@ -#version 450 - -#include "common.comp" - -#define IN_TYPE float -#define IN_TYPE_SIZE 4 -#define OUT_TYPE float16_t -#define OUT_TYPE_SIZE 2 - -layout(local_size_x = 1024) in; - -layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; }; -layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; }; - -layout (push_constant) uniform parameter { - uint inOff; - uint outOff; - int ne00; - int ne01; - int ne02; - uint nb00; - uint nb01; - uint nb02; - uint nb03; - int ne0; - int ne1; - int ne2; - uint nb0; - uint nb1; - uint nb2; - uint nb3; -} pcs; - -void main() { - const uint i03 = gl_WorkGroupID.z; - const uint i02 = gl_WorkGroupID.y; - const uint i01 = gl_WorkGroupID.x; - - const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00; - - const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0); - const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0); - const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0; - const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0); - - const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_ - - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_ - out_[dst_data+i00] = OUT_TYPE(in_[src]); - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp deleted file mode 100644 index 2fc998492b..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +++ /dev/null @@ -1,52 +0,0 @@ -#version 450 - -#include "common.comp" - -#define IN_TYPE float -#define IN_TYPE_SIZE 4 -#define OUT_TYPE float -#define OUT_TYPE_SIZE 4 - -layout(local_size_x = 1024) in; - -layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; }; -layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; }; - -layout (push_constant) uniform parameter { - uint inOff; - uint outOff; - int ne00; - int ne01; - int ne02; - uint nb00; - uint nb01; - uint nb02; - uint nb03; - int ne0; - int ne1; - int ne2; - uint nb0; - uint nb1; - uint nb2; - uint nb3; -} pcs; - -void main() { - const uint i03 = gl_WorkGroupID.z; - const uint i02 = gl_WorkGroupID.y; - const uint i01 = gl_WorkGroupID.x; - - const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00; - - const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0); - const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0); - const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0; - const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0); - - const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_ - - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_ - out_[dst_data+i00] = OUT_TYPE(in_[src]); - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp b/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp deleted file mode 100644 index 291c3fc189..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +++ /dev/null @@ -1,30 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1) in; - -layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; -layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; - -layout(push_constant) uniform PushConstants { - uint inOff; - uint outOff; - uint n_past; - int ne00; - int ne01; -} pcs; - -void main() { - const uint i02 = gl_WorkGroupID.z; - const uint i01 = gl_WorkGroupID.y; - const uint i00 = gl_WorkGroupID.x; - - const uint index = i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00; - - if (i00 > pcs.n_past + i01) { - out_[index + pcs.outOff] = uintBitsToFloat(0xFF800000); - } else { - out_[index + pcs.outOff] = in_[index + pcs.inOff]; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp b/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp deleted file mode 100644 index 9d8c53710a..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +++ /dev/null @@ -1,22 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1) in; - -layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; -layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; -layout(push_constant) uniform PushConstants { - uint inOff; - uint outOff; -} pcs; - -void main() { - const uint baseIndex = gl_WorkGroupID.x * 8; - - for (uint x = 0; x < 8; x++) { - const uint i = baseIndex + x; - const float y = in_[i + pcs.inOff]; - out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(clamp(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y), -15.0, 15.0))); - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp deleted file mode 100644 index 1a5581b23a..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +++ /dev/null @@ -1,17 +0,0 @@ -void main() { - const uint i = gl_WorkGroupID.x; - const int r = inB[i + pcs.inBOff]; - - int z = 0; - for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) { - const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK; - const mat4 result = dequantize_block(inIndex, ind%NL); - for (uint j = 0; j < 4; ++j) { - for (uint k = 0; k < 4; ++k) { - const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z; - out_[outIndex] = result[j][k]; - ++z; - } - } - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp deleted file mode 100644 index 48c9361081..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +++ /dev/null @@ -1,31 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1) in; - -layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { int inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int nb01; - int nb1; -} pcs; - -void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) { - for (int j = 0; j < k; j++) { - out_[y + j] = inA[x + j]; - } -} - -void main() { - const uint i = gl_WorkGroupID.x; - const int r = inB[i + pcs.inBOff]; - - dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00); -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp deleted file mode 100644 index 9d7acdaf8a..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +++ /dev/null @@ -1,31 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1) in; - -layout (binding = 0) readonly buffer tensorInA { float inA[]; }; -layout (binding = 1) readonly buffer tensorInB { int inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int nb01; - int nb1; -} pcs; - -void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) { - for (int j = 0; j < k; j++) { - out_[y + j] = inA[x + j]; - } -} - -void main() { - const uint i = gl_WorkGroupID.x; - const int r = inB[i + pcs.inBOff]; - - dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00); -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp deleted file mode 100644 index 32b2e891e8..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +++ /dev/null @@ -1,38 +0,0 @@ -#version 450 - -#include "common.comp" - -#define NL 2 -#define BYTES_FOR_TYPE 4 /*bytes for float*/ -#define SIZE_OF_BLOCK sizeof_block_q4_0 - -layout(local_size_x = 1) in; - -layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { int inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int nb01; - int nb1; -} pcs; - -block_q4_0 get_unaligned_block_q4_0(uint index) { - block_q4_0 fres; - fres.d = u8BufToFloat16(inA, index); - [[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) { - fres.qs[it] = inA[index+2+it]; - } - return fres; -} - -mat4 dequantize_block(uint index, uint il) { - const block_q4_0 block = get_unaligned_block_q4_0(index); - return dequantize_q4_0(block, il); -} - -#include "op_getrows.comp" diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp deleted file mode 100644 index 87f2fbe17b..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +++ /dev/null @@ -1,39 +0,0 @@ -#version 450 - -#include "common.comp" - -#define NL 2 -#define BYTES_FOR_TYPE 4 /*bytes for float*/ -#define SIZE_OF_BLOCK sizeof_block_q4_1 - -layout(local_size_x = 1) in; - -layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { int inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int nb01; - int nb1; -} pcs; - -block_q4_1 get_unaligned_block_q4_1(uint index) { - block_q4_1 fres; - fres.d = u8BufToFloat16(inA, index); - fres.m = u8BufToFloat16(inA, index+2); - [[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) { - fres.qs[it] = inA[index+4+it]; - } - return fres; -} - -mat4 dequantize_block(uint index, uint il) { - const block_q4_1 block = get_unaligned_block_q4_1(index); - return dequantize_q4_1(block, il); -} - -#include "op_getrows.comp" diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp deleted file mode 100644 index 9ce3545d1e..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +++ /dev/null @@ -1,44 +0,0 @@ -#version 450 - -#include "common.comp" - -#define NL 16 -#define BYTES_FOR_TYPE 4 /*bytes for float*/ -#define SIZE_OF_BLOCK sizeof_block_q6_k - -layout(local_size_x = 1) in; - -layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { int inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int nb01; - int nb1; -} pcs; - -block_q6_k get_unaligned_block_q6_k(uint index) { - block_q6_k fres; - [[unroll]] for (uint it = 0; it != QK_K / 2; it++) { - fres.ql[it] = inA[index + it]; - } - [[unroll]] for (uint it = 0; it != QK_K / 4; it++) { - fres.qh[it] = inA[index + QK_K/2 + it]; - } - [[unroll]] for (uint it = 0; it != QK_K / 16; it++) { - fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]); - } - fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16); - return fres; -} - -mat4 dequantize_block(uint index, uint il) { - const block_q6_k block = get_unaligned_block_q6_k(index); - return dequantize_q6_k(block, il); -} - -#include "op_getrows.comp" diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp deleted file mode 100644 index c92647c4db..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +++ /dev/null @@ -1,52 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1024) in; - -layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; }; -layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; - -layout(push_constant) uniform PushConstants { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int nb00; - int nb01; - int nb02; - int nb03; - int ne10; - int ne11; - int ne12; - int ne13; - int nb10; - int nb11; - int nb12; - int nb13; - int ne0; - int nb0; - int nb1; - int nb2; - int nb3; -} pcs; - -void main() { - const uint i03 = gl_WorkGroupID.z; - const uint i02 = gl_WorkGroupID.y; - const uint i01 = gl_WorkGroupID.x; - - const uint i13 = i03 % pcs.ne13; - const uint i12 = i02 % pcs.ne12; - const uint i11 = i01 % pcs.ne11; - - uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4); - uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4); - uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1) / 4); - - for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) { - const uint i10 = i0 % pcs.ne10; - out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10]; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp deleted file mode 100644 index 0ab1b2fc20..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +++ /dev/null @@ -1,69 +0,0 @@ -#version 450 - -#include "common.comp" - -#extension GL_KHR_shader_subgroup_arithmetic : require - -layout(local_size_x_id = 0) in; - -layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { float inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int ne01; - int ne02; - uint nb00; - uint nb01; - uint nb02; - uint nb03; - int ne10; - int ne11; - int ne12; - uint nb10; - uint nb11; - uint nb12; - uint nb13; - int ne0; - int ne1; - uint r2; - uint r3; -} pcs; - -#define N_F16_F32 4 - -void main() { - const uint r0 = gl_WorkGroupID.x; - const uint rb = gl_WorkGroupID.y*N_F16_F32; - const uint im = gl_WorkGroupID.z; - - const uint i12 = im%pcs.ne12; - const uint i13 = im/pcs.ne12; - - const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03; - - const uint x = offset0 / 2 + pcs.inAOff; // Based from inA - - for (uint row = 0; row < N_F16_F32; ++row) { - uint r1 = rb + row; - if (r1 >= pcs.ne11) { - break; - } - - const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; - - float sumf = 0; - for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) { - sumf += float(inA[x+i]) * float(inB[y+i]); - } - - const float all_sum = subgroupAdd(sumf); - if (subgroupElect()) { - out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum; - } - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp deleted file mode 100644 index d1ca4ad6c2..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +++ /dev/null @@ -1,51 +0,0 @@ -#version 450 - -#include "common.comp" - -#extension GL_KHR_shader_subgroup_arithmetic : require -#extension GL_EXT_debug_printf : enable - -// device subgroup size -layout (local_size_x_id = 0) in; - -layout(binding = 0) readonly buffer tensorInA { float inA[]; }; -layout(binding = 1) readonly buffer tensorInB { float inB[]; }; -layout(binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout(push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int ne01; - int ne02; - int ne11; - int ne12; - uint nb01; - uint nb02; - uint nb11; - uint nb12; - uint nb1; - uint nb2; -} -pcs; - - -void main() { - uvec3 gid = gl_WorkGroupID; - - uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z; - uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z; - - const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA - const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB - float sum = 0.0f; - for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) { - sum += float(inA[x+i]) * float(inB[y+i]); - } - - const float all_sum = subgroupAdd(sum); - if (subgroupElect()) { - out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp deleted file mode 100644 index b0cea8bbe6..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +++ /dev/null @@ -1,33 +0,0 @@ -#version 450 - -#include "common.comp" - -#define BLOCKS_IN_QUANT QK4_0 -#define SIZE_OF_BLOCK sizeof_block_q4_0 -#define N_ROWS 4 - -#include "op_mul_mv_q_n_pre.comp" - -// The q4_0 version of this function -float block_q_n_dot_y(uint block_index, uint yb, uint il) { - vec2 acc = vec2(0.0, 0.0); - const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff; - float d = float(u8BufToFloat16(inA, index)); - float sumy = 0.0f; - for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) { - const uint16_t b = u8BufToU16(inA, index + 2 + il + i); - - const float yl0 = inB[yb + i]; - const float yl1 = inB[yb + i + 1]; - const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2]; - const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1]; - - sumy += yl0 + yl1 + yl8 + yl9; - - acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00); - acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000); - } - return d * (sumy * -8.f + acc[0] + acc[1]); -} - -#include "op_mul_mv_q_n.comp" diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp deleted file mode 100644 index 8582c61a3b..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +++ /dev/null @@ -1,35 +0,0 @@ -#version 450 - -#include "common.comp" - -#define BLOCKS_IN_QUANT QK4_1 -#define SIZE_OF_BLOCK sizeof_block_q4_1 -#define N_ROWS 4 - -#include "op_mul_mv_q_n_pre.comp" - -// The q4_1 version of this function -float block_q_n_dot_y(uint block_index, uint yb, uint il) { - vec2 acc = vec2(0.0, 0.0); - const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff; - float d = float(u8BufToFloat16(inA, index)); - float m = float(u8BufToFloat16(inA, index+2)); - - float sumy = 0.0f; - for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) { - const uint16_t b = u8BufToU16(inA, index + 4 + il + i); - - const float yl0 = inB[yb + i]; - const float yl1 = inB[yb + i + 1]; - const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2]; - const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1]; - - sumy += yl0 + yl1 + yl8 + yl9; - - acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00); - acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -#include "op_mul_mv_q_n.comp" diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp deleted file mode 100644 index a5752a3a00..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +++ /dev/null @@ -1,140 +0,0 @@ -#version 450 - -#include "common.comp" - -#define N_DST 4 -#define SIZE_OF_BLOCK sizeof_block_q4_k - -layout(local_size_x = 4) in; -layout(local_size_y = 8) in; -layout(local_size_z = 1) in; - -layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; }; -layout (binding = 1) readonly buffer tensorInB { float inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int ne10; - int ne0; - int ne1; - int ne01; - int ne02; - int ne12; - uint nb01; - uint nb02; - uint nb03; - uint nb11; - uint nb12; - uint nb13; - uint r2; - uint r3; -} pcs; - -void main() { - const uint16_t kmask1 = uint16_t(0x3f3f); - const uint16_t kmask2 = uint16_t(0x0f0f); - const uint16_t kmask3 = uint16_t(0xc0c0); - - const uint ix = gl_SubgroupInvocationID/8; // 0...3 - const uint it = gl_SubgroupInvocationID%8; // 0...7 - const uint iq = it/4; // 0 or 1 - const uint ir = it%4; // 0...3 - - const uint nb = pcs.ne00/QK_K; - - const uint r0 = gl_WorkGroupID.x; - const uint r1 = gl_WorkGroupID.y; - const uint im = gl_WorkGroupID.z; - - const uint first_row = r0 * N_DST; - const uint ib_row = first_row * nb; - - const uint i12 = im%pcs.ne12; - const uint i13 = im/pcs.ne12; - - const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); - const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13; - - const uint xblk = offset0 + pcs.inAOff; - const uint y = (offset1 / 4) + pcs.inBOff; - - float yl[16]; - float yh[16]; - float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f}; - float all_sum = 0.f; - - uint y4 = y + ix * QK_K + 64 * iq + 8 * ir; - - for (uint ib = ix; ib < nb; ib += 4) { - const uint blk_idx = ib + xblk; - - float sumy[4] = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0]; - yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8]; - yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0]; - yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8]; - } - - for (int row = 0; row < N_DST; row++) { - uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK); - - uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0); - uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2); - uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4); - uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6); - uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8); - - uint16_t sc16[4]; - sc16[0] = sc_0 & kmask1; - sc16[1] = sc_2 & kmask1; - sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2); - sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2); - - float acc1[4] = {0.f, 0.f, 0.f, 0.f}; - float acc2[4] = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i); - uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i); - acc1[0] += yl[i+0] * (q1 & 0x000F); - acc1[1] += yl[i+1] * (q1 & 0x0F00); - acc1[2] += yl[i+8] * (q1 & 0x00F0); - acc1[3] += yl[i+9] * (q1 & 0xF000); - acc2[0] += yh[i+0] * (q2 & 0x000F); - acc2[1] += yh[i+1] * (q2 & 0x0F00); - acc2[2] += yh[i+8] * (q2 & 0x00F0); - acc2[3] += yh[i+9] * (q2 & 0xF000); - } - - uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF); - uint8_t sc8_1 = uint8_t(sc16[0] >> 8 ); - uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF); - uint8_t sc8_3 = uint8_t(sc16[1] >> 8 ); - uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF); - uint8_t sc8_5 = uint8_t(sc16[2] >> 8 ); - uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF); - uint8_t sc8_7 = uint8_t(sc16[3] >> 8 ); - - float dall = float(inA[blk_idx + row_idx].d); - float dmin = float(inA[blk_idx + row_idx].dmin); - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) - - dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7); - } - - y4 += 4 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = subgroupAdd(sumf[row]); - if (subgroupElect()) { - out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum; - } - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp deleted file mode 100644 index d331d1a705..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +++ /dev/null @@ -1,106 +0,0 @@ -#version 450 - -#include "common.comp" - -#define SIZE_OF_BLOCK sizeof_block_q6_k - -layout(local_size_x_id = 0) in; -layout(local_size_y_id = 1) in; -layout(local_size_z = 1) in; - -layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { float inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int ne10; - int ne0; - int ne1; - int ne01; - int ne02; - int ne12; - uint nb01; - uint nb02; - uint nb03; - uint nb11; - uint nb12; - uint nb13; - uint r2; - uint r3; -} pcs; - -void main() { - const uint8_t kmask1 = uint8_t(0x03); - const uint8_t kmask2 = uint8_t(0x0C); - const uint8_t kmask3 = uint8_t(0x30); - const uint8_t kmask4 = uint8_t(0xC0); - - const uint nb = pcs.ne00/QK_K; - - const uint r0 = gl_WorkGroupID.x; - const uint r1 = gl_WorkGroupID.y; - const uint im = gl_WorkGroupID.z; - - const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID); - - const uint i12 = im%pcs.ne12; - const uint i13 = im/pcs.ne12; - - const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); - const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; - - float sumf = 0; - - // bits of invocation ID for gl_SubgroupSize=32: - // x x x x x - // 4 3 2 1 0 - // ( tid ) ix - // ip ( il ) - - const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes - const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0 - const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1 - const uint ip = tid/8; // first or second half of block (0 or 1) - const uint il = tid%8; // each half has 8 parts, one per scale - const uint n = 4; // 4 scales at a time (and 4 sums) - const uint l0 = n*il; // offset into half-block, 0..28 - const uint is = 8*ip + l0/16; // 0, 1, 8, 9 - - const uint y_offset = 128*ip + l0; - const uint q_offset_l = 64*ip + l0; - const uint q_offset_h = 32*ip + l0; - - for (uint i = ix; i < nb; i += block_stride) { - - const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff; - - const uint qlIndex = q_offset_l; - const uint q2Index = qlIndex + QK_K/8; - const uint qhIndex = q_offset_h; - const uint y = yy + i * QK_K + y_offset; - - float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - for (uint l = 0; l < n; ++l) { - const uint8_t currentQ1 = inA[baseIndex + qlIndex + l]; - const uint8_t currentQ2 = inA[baseIndex + q2Index + l]; - const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l]; - - sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32); - sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32); - sums[2] += inB[y+l+64] * (int8_t((currentQ1 >> 4) | ((currentQh & kmask3) << 0)) - 32); - sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32); - } - - float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16); - sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is])); - } - - const float tot = subgroupAdd(sumf); - if (subgroupElect()) { - out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp deleted file mode 100644 index 34d015e90b..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +++ /dev/null @@ -1,73 +0,0 @@ -#version 450 - -#include "common.comp" - -#include "op_mul_mv_q_n_pre.comp" - -#define SIZE_OF_D 2 - -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 - -#define NB_Q8_0 8 - -void main() { - // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64 - if (gl_SubgroupInvocationID > 31) - return; - - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - - const int nb = pcs.ne00/QK8_0; - const uint r0 = gl_WorkGroupID.x; - const uint r1 = gl_WorkGroupID.y; - const uint im = gl_WorkGroupID.z; - - const uint first_row = (r0 * nsg + gl_SubgroupID) * nr; - - const uint i12 = im%pcs.ne12; - const uint i13 = im/pcs.ne12; - - const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02); - - const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA - const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB - - float yl[NB_Q8_0]; - float sumf[N_DST]={0.f, 0.f, 0.f, 0.f}; - - const uint ix = gl_SubgroupInvocationID.x/4; - const uint il = gl_SubgroupInvocationID.x%4; - - uint yb = y + ix * QK8_0 + NB_Q8_0*il; - - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (uint ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { - yl[i] = inB[yb + i]; - } - - for (int row = 0; row < nr; row++) { - const uint block_offset = (ib+row*nb) * sizeof_block_q8_0; - float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { - const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]); - sumq += qs_iq * yl[iq]; - } - const float16_t d = u8BufToFloat16(inA, x + block_offset); - sumf[row] += sumq*d; - } - - yb += NB_Q8_0 * nw; - } - - for (int row = 0; row < nr; ++row) { - const float tot = subgroupAdd(sumf[row]); - if (subgroupElect() && first_row + row < pcs.ne01) { - out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot; - } - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp deleted file mode 100644 index a6517cc1f1..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +++ /dev/null @@ -1,52 +0,0 @@ -void main() { - // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64 - if (gl_SubgroupInvocationID > 31) - return; - - const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT); - - const uint r0 = gl_WorkGroupID.x; - const uint r1 = gl_WorkGroupID.y; - const uint im = gl_WorkGroupID.z; - - const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS; - - const uint i12 = im%pcs.ne12; - const uint i13 = im/pcs.ne12; - - // pointers to src0 rows - uint ax[N_ROWS]; - for (int row = 0; row < N_ROWS; ++row) { - const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); - - ax[row] = offset0 + pcs.inAOff; - } - - const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; - - float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f}; - - const uint ix = gl_SubgroupInvocationID/2; - const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2); - - uint yb = y + ix * BLOCKS_IN_QUANT + il; - - //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n", - // gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize, - // gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z); - - for (uint ib = ix; ib < nb; ib += 16) { - for (int row = 0; row < N_ROWS; row++) { - sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il); - } - - yb += BLOCKS_IN_QUANT * 16; - } - - for (int row = 0; row < N_ROWS; ++row) { - const float tot = subgroupAdd(sumf[row]); - if (first_row + row < pcs.ne01 && subgroupElect()) { - out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot; - } - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp deleted file mode 100644 index a9a2f22180..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +++ /dev/null @@ -1,28 +0,0 @@ -layout(local_size_x_id = 0) in; -layout(local_size_y = 8) in; -layout(local_size_z = 1) in; - -layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; -layout (binding = 1) readonly buffer tensorInB { float inB[]; }; -layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int ne01; - int ne02; - int ne10; - int ne12; - int ne0; - int ne1; - uint nb01; - uint nb02; - uint nb03; - uint nb11; - uint nb12; - uint nb13; - uint r2; - uint r3; -} pcs; diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp b/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp deleted file mode 100644 index ad0c3c01b9..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +++ /dev/null @@ -1,84 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 256) in; - -layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; -layout(binding = 1) buffer restrict tensorOut { float out_[]; }; - -layout(push_constant) uniform PushConstants { - uint inOff; - uint outOff; - uint ne00; - uint nb01; - float eps; -} pcs; - -shared float sum[gl_WorkGroupSize.x]; - -void main() { - const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_ - // MEAN - // parallel sum - sum[gl_LocalInvocationID.x] = 0.0; - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - sum[gl_LocalInvocationID.x] += in_[x+i00]; - } - - // reduce - barrier(); - memoryBarrierShared(); - [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) { - if (gl_LocalInvocationID.x < i) { - sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i]; - } - barrier(); - memoryBarrierShared(); - } - - // broadcast - if (gl_LocalInvocationID.x == 0) { - sum[0] /= float(pcs.ne00); - } - barrier(); - memoryBarrierShared(); - const float mean = sum[0]; - - // recenter - const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_ - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - out_[y+i00] = in_[x+i00] - mean; - } - - // VARIANCE - // parallel sum - sum[gl_LocalInvocationID.x] = 0.0; - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - sum[gl_LocalInvocationID.x] += out_[y+i00] * out_[y+i00]; - } - - // reduce - barrier(); - memoryBarrierShared(); - [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) { - if (gl_LocalInvocationID.x < i) { - sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i]; - } - barrier(); - memoryBarrierShared(); - } - - // broadcast - if (gl_LocalInvocationID.x == 0) { - sum[0] /= float(pcs.ne00); - } - barrier(); - memoryBarrierShared(); - const float variance = sum[0]; - - const float scale = 1.0f/sqrt(variance + pcs.eps); - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - out_[y+i00] *= scale; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp b/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp deleted file mode 100644 index 52a601fe6d..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +++ /dev/null @@ -1,21 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1) in; - -layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; -layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; -layout(push_constant) uniform PushConstants { - uint inOff; - uint outOff; -} pcs; - -void main() { - const uint baseIndex = gl_WorkGroupID.x * 4; - - for (uint x = 0; x < 4; x++) { - const uint i = baseIndex + x; - out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]); - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp deleted file mode 100644 index da658c1601..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +++ /dev/null @@ -1,53 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 512) in; - -layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; -layout(binding = 1) buffer restrict tensorOut { float out_[]; }; - -layout(push_constant) uniform PushConstants { - uint inOff; - uint outOff; - uint ne00; - uint nb01; - float eps; -} pcs; - -shared float sum[gl_WorkGroupSize.x]; - -void main() { - const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_ - - // parallel sum - sum[gl_LocalInvocationID.x] = 0.0; - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - sum[gl_LocalInvocationID.x] += in_[x+i00] * in_[x+i00]; - } - - // reduce - barrier(); - memoryBarrierShared(); - [[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) { - if (gl_LocalInvocationID.x < i) { - sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i]; - } - barrier(); - memoryBarrierShared(); - } - - // broadcast - if (gl_LocalInvocationID.x == 0) { - sum[0] /= float(pcs.ne00); - } - barrier(); - memoryBarrierShared(); - - const float scale = 1.0f/sqrt(sum[0] + pcs.eps); - - const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_ - for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) { - out_[y+i00] = in_[x+i00] * scale; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp deleted file mode 100644 index 63659cbfe5..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +++ /dev/null @@ -1,52 +0,0 @@ -#version 450 - -#include "rope_common.comp" - -layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; -layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; }; -layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; }; - -void main() { - const uint i3 = gl_WorkGroupID.z; - const uint i2 = gl_WorkGroupID.y; - const uint i1 = gl_WorkGroupID.x; - - float corr_dims[2]; - rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); - - const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); - - float theta_base = float(inB[pcs.inBOff + i2]); - float inv_ndims = -1.f/pcs.n_dims; - - float cos_theta; - float sin_theta; - - for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) { - if (i0 < pcs.n_dims) { - uint ic = i0/2; - - float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0); - - const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f; - - rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 2) + pcs.outOff; // Based from out_ - - const float x0 = float(inA[src]); - const float x1 = float(inA[src+pcs.n_dims/2]); - - out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta); - out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta); - } else { - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ - - out_[dst_data] = inA[src]; - out_[dst_data+1] = inA[src+1]; - } - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp deleted file mode 100644 index 4df56204d7..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +++ /dev/null @@ -1,52 +0,0 @@ -#version 450 - -#include "rope_common.comp" - -layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; -layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; }; -layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; }; - -void main() { - const uint i3 = gl_WorkGroupID.z; - const uint i2 = gl_WorkGroupID.y; - const uint i1 = gl_WorkGroupID.x; - - float corr_dims[2]; - rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); - - const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); - - float theta_base = float(inB[pcs.inBOff + i2]); - float inv_ndims = -1.f/pcs.n_dims; - - float cos_theta; - float sin_theta; - - for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) { - if (i0 < pcs.n_dims) { - uint ic = i0/2; - - float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0); - - const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f; - - rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - - const float x0 = inA[src]; - const float x1 = inA[src+pcs.n_dims/2]; - - out_[dst_data] = x0*cos_theta - x1*sin_theta; - out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - - out_[dst_data] = inA[src]; - out_[dst_data+1] = inA[src+1]; - } - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp deleted file mode 100644 index a3c0eda8bd..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +++ /dev/null @@ -1,52 +0,0 @@ -#version 450 - -#include "rope_common.comp" - -layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; -layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; }; -layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; }; - -void main() { - const uint i3 = gl_WorkGroupID.z; - const uint i2 = gl_WorkGroupID.y; - const uint i1 = gl_WorkGroupID.x; - - float corr_dims[2]; - rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); - - const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); - - float theta_base = float(inB[pcs.inBOff + i2]); - float inv_ndims = -1.f/pcs.n_dims; - - float cos_theta; - float sin_theta; - - for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) { - if (i0 < pcs.n_dims) { - uint ic = i0/2; - - float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0); - - const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f; - - rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ - - const float x0 = float(inA[src]); - const float x1 = float(inA[src+1]); - - out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta); - out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta); - } else { - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ - - out_[dst_data] = inA[src]; - out_[dst_data+1] = inA[src+1]; - } - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp deleted file mode 100644 index b7963ae725..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +++ /dev/null @@ -1,52 +0,0 @@ -#version 450 - -#include "rope_common.comp" - -layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; -layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; }; -layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; }; - -void main() { - const uint i3 = gl_WorkGroupID.z; - const uint i2 = gl_WorkGroupID.y; - const uint i1 = gl_WorkGroupID.x; - - float corr_dims[2]; - rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); - - const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); - - float theta_base = float(inB[pcs.inBOff + i2]); - float inv_ndims = -1.f/pcs.n_dims; - - float cos_theta; - float sin_theta; - - for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) { - if (i0 < pcs.n_dims) { - uint ic = i0/2; - - float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0); - - const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f; - - rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - - const float x0 = inA[src]; - const float x1 = inA[src+1]; - - out_[dst_data] = x0*cos_theta - x1*sin_theta; - out_[dst_data+1] = x0*sin_theta + x1*cos_theta; - } else { - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - - out_[dst_data] = inA[src]; - out_[dst_data+1] = inA[src+1]; - } - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp b/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp deleted file mode 100644 index bdae267382..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +++ /dev/null @@ -1,19 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1) in; - -layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; -layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; - -layout(push_constant) uniform PushConstants { - uint inOff; - uint outOff; - float scale; -} pcs; - -void main() { - const uint i = gl_WorkGroupID.x; - out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale; -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp b/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp deleted file mode 100644 index ada69754b2..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +++ /dev/null @@ -1,23 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1) in; - -layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; -layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; - -layout(push_constant) uniform PushConstants { - uint inOff; - uint outOff; - float scale; -} pcs; - -void main() { - const uint baseIndex = gl_WorkGroupID.x * 8; - - for (uint x = 0; x < 8; x++) { - const uint i = baseIndex + x; - out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp b/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp deleted file mode 100644 index 0fb8e4b740..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +++ /dev/null @@ -1,22 +0,0 @@ -#version 450 - -#include "common.comp" - -layout(local_size_x = 1) in; - -layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; -layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; }; -layout(push_constant) uniform PushConstants { - uint inOff; - uint outOff; -} pcs; - -void main() { - const uint baseIndex = gl_WorkGroupID.x * 4; - - for (uint x = 0; x < 4; x++) { - const uint i = baseIndex + x; - const float y = in_[i + pcs.inOff]; - out_[i + pcs.outOff] = y / (1.0 + exp(-y)); - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp b/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp deleted file mode 100644 index 4165295bf4..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +++ /dev/null @@ -1,72 +0,0 @@ -// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4) - -#version 450 - -#include "common.comp" - -layout(local_size_x_id = 0) in; - -layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; }; -layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; - -layout(push_constant) uniform PushConstants { - uint inAOff; - uint inBOff; - uint outOff; - int ne00; - int ne01; - int ne02; - float scale; - float max_bias; - float m0; - float m1; - uint n_head_log2; - int mask; -} pcs; - -void main() { - if (gl_SubgroupInvocationID > 31) - return; - - const uint i03 = gl_WorkGroupID.z; - const uint i02 = gl_WorkGroupID.y; - const uint i01 = gl_WorkGroupID.x; - - const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00; - const uint psrc0 = extra_off + pcs.inAOff; // Based from inA - const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB - const uint pdst = extra_off + pcs.outOff; // Based from out_ - - float slope = 1.0f; - - // ALiBi - if (pcs.max_bias > 0.0f) { - int64_t h = i02; - - float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1; - int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1; - - slope = pow(base, float(exp)); - } - - // parallel max - float localMax = uintBitsToFloat(0xFF800000); - for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f)); - } - float max_ = subgroupMax(localMax); - - // parallel sum - float localSum = 0.0f; - for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_); - localSum += exp_psrc0; - out_[pdst + i00] = exp_psrc0; - } - - const float sum = subgroupAdd(localSum); - for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - out_[pdst + i00] /= sum; - } -} diff --git a/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp b/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp deleted file mode 100644 index 0fca640dcc..0000000000 --- a/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +++ /dev/null @@ -1,71 +0,0 @@ -#include "common.comp" - -#define GGML_ROPE_TYPE_NEOX 2 - -// TODO: use a local size of 32 or more (Metal uses 1024) -layout(local_size_x = 1) in; - -layout (push_constant) uniform parameter { - uint inAOff; - uint inBOff; - uint inCOff; - uint outOff; - int n_dims; - int mode; - int n_ctx_orig; - float freq_base; - float freq_scale; - bool has_freq_factors; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; - uint nb00; - uint nb01; - uint nb02; - uint nb03; - int ne0; - uint nb0; - uint nb1; - uint nb2; - uint nb3; -} pcs; - -float rope_yarn_ramp(const float low, const float high, const float i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale, - out float cos_theta, out float sin_theta -) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); - } - cos_theta = cos(theta) * mscale; - sin_theta = sin(theta) * mscale; -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { - return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base)); -} - -void rope_yarn_corr_dims( - int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2] -) { - // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); -} diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index 77187efc17..0ca8a3c55e 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -71,7 +71,9 @@ else() # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1 # note: unfortunately, we have to call it default.metallib instead of ggml.metallib # ref: https://github.com/ggerganov/whisper.cpp/issues/1720 - set(XC_FLAGS -fno-fast-math -fno-inline -g) + # note: adding -g causes segmentation fault during compile + #set(XC_FLAGS -fno-fast-math -fno-inline -g) + set(XC_FLAGS -fno-fast-math -fno-inline) else() set(XC_FLAGS -O3) endif() @@ -90,7 +92,7 @@ else() add_custom_command( OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - | - xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal DEPENDS ggml-metal.metal ${METALLIB_COMMON} diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 17eab976f3..752d55c216 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -229,7 +229,11 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne32; + int32_t ne33; uint64_t nb31; + uint64_t nb32; + uint64_t nb33; int32_t ne1; int32_t ne2; float scale; @@ -422,6 +426,17 @@ typedef struct { int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources } ggml_metal_kargs_im2col; +typedef struct{ + int32_t ne00; + uint64_t nb01; + int32_t ne10; + uint64_t nb11; + int32_t ne0; + uint64_t nb1; + int32_t i00; + int32_t i10; +} ggml_metal_kargs_glu; + typedef struct { int64_t ne00; int64_t ne01; @@ -450,9 +465,21 @@ typedef struct { } ggml_metal_kargs_sum_rows; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float scale; float max_bias; float m0; @@ -488,26 +515,25 @@ typedef struct { typedef struct { int64_t d_state; int64_t d_inner; + int64_t n_head; + int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; - uint64_t nb00; uint64_t nb01; uint64_t nb02; - uint64_t nb10; + uint64_t nb03; uint64_t nb11; uint64_t nb12; uint64_t nb13; - uint64_t nb20; uint64_t nb21; uint64_t nb22; - uint64_t nb30; uint64_t nb31; - uint64_t nb40; uint64_t nb41; uint64_t nb42; - uint64_t nb50; + uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t nb53; } ggml_metal_kargs_ssm_scan; typedef struct { @@ -521,6 +547,22 @@ typedef struct { uint64_t nb2; } ggml_metal_kargs_get_rows; +typedef struct { + int32_t nk0; + int32_t ne01; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_set_rows; + typedef struct { int64_t ne00; int64_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 69b8a268bf..65093628fc 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -202,12 +202,22 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, + GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, + GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, + GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, + GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_L2_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, @@ -517,6 +527,11 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_NEG, + GGML_METAL_KERNEL_TYPE_REGLU, + GGML_METAL_KERNEL_TYPE_GEGLU, + GGML_METAL_KERNEL_TYPE_SWIGLU, + GGML_METAL_KERNEL_TYPE_GEGLU_ERF, + GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_MEAN, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, @@ -1169,12 +1184,22 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); @@ -1484,6 +1509,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); @@ -1635,6 +1665,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex const bool use_bfloat = ctx_dev->use_bfloat; if (!use_bfloat) { + if (op->type == GGML_TYPE_BF16) { + return false; + } + for (size_t i = 0, n = 3; i < n; ++i) { if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { return false; @@ -1658,6 +1692,17 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex default: return false; } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + default: + return false; + } case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -1688,7 +1733,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); @@ -1804,6 +1849,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex { return op->ne[3] == 1; } + case GGML_OP_SET_ROWS: + { + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + }; + } default: return false; } @@ -2377,6 +2443,68 @@ static bool ggml_metal_encode_node( GGML_ABORT("fatal error"); } } break; + case GGML_OP_GLU: + { + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + if (src1) { + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + } + + id pipeline = nil; + + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_REGLU: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline; + break; + case GGML_GLU_OP_GEGLU: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline; + break; + case GGML_GLU_OP_SWIGLU: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; + break; + case GGML_GLU_OP_GEGLU_ERF: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline; + break; + case GGML_GLU_OP_GEGLU_QUICK: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline; + break; + default: + GGML_ABORT("fatal error"); + } + + const int32_t swp = ((const int32_t *) dst->op_params)[1]; + + const int32_t i00 = swp ? ne0 : 0; + const int32_t i10 = swp ? 0 : ne0; + + ggml_metal_kargs_glu args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.ne10 =*/ src1 ? ne10 : ne00, + /*.nb11 =*/ src1 ? nb11 : nb01, + /*.ne0 =*/ ne0, + /*.nb1 =*/ nb1, + /*.i00 =*/ src1 ? 0 : i00, + /*.i10 =*/ src1 ? 0 : i10, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; + + const int64_t nrows = ggml_nrows(src0); + + const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_SQR: { GGML_ASSERT(ggml_is_contiguous(src0)); @@ -2531,10 +2659,7 @@ static bool ggml_metal_encode_node( memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -2594,6 +2719,18 @@ static bool ggml_metal_encode_node( /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, /*.scale =*/ scale, /*.max_bias =*/ max_bias, /*.m0 =*/ m0, @@ -2613,7 +2750,7 @@ static bool ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_DIAG_MASK_INF: { @@ -2687,71 +2824,91 @@ static bool ggml_metal_encode_node( struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; + struct ggml_tensor * src6 = node->src[6]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); + GGML_ASSERT(src6); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; + size_t offs_src6 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - const uint64_t nb30 = src3->nb[0]; + const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne41 = src4->ne[1]; const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - const uint64_t nb40 = src4->nb[0]; + const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; + const uint64_t nb43 = src4->nb[3]; const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - const uint64_t nb50 = src5->nb[0]; + const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; + const uint64_t nb53 = src5->nb[3]; + + const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); + + const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); const int64_t d_state = ne00; const int64_t d_inner = ne01; - const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; + const int64_t n_head = ne02; + const int64_t n_group = ne41; + const int64_t n_seq_tokens = ne12; + const int64_t n_seqs = ne13; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + id pipeline = nil; + + if (ne30 == 1) { + // Mamba-2 + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + } ggml_metal_kargs_ssm_scan args = { - /*.d_state =*/ d_state, - /*.d_inner =*/ d_inner, + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_head =*/ n_head, + /*.n_group =*/ n_group, /*.n_seq_tokens =*/ n_seq_tokens, - /*.n_seqs =*/ n_seqs, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb20 =*/ nb20, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb30 =*/ nb30, - /*.nb31 =*/ nb31, - /*.nb40 =*/ nb40, - /*.nb41 =*/ nb41, - /*.nb42 =*/ nb42, - /*.nb50 =*/ nb50, - /*.nb51 =*/ nb51, - /*.nb52 =*/ nb52, + /*.n_seqs =*/ n_seqs, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb31 =*/ nb31, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb43 =*/ nb43, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + /*.nb53 =*/ nb53, }; [encoder setComputePipelineState:pipeline]; @@ -2761,10 +2918,17 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - [encoder setBytes:&args length:sizeof(args) atIndex:7]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + [encoder setBytes:&args length:sizeof(args) atIndex:8]; - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + if (ne30 == 1) { + // Mamba-2 + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + GGML_ASSERT(d_inner == 1); + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_RWKV_WKV6: { @@ -3778,13 +3942,74 @@ static bool ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; + case GGML_OP_SET_ROWS: + { + id pipeline = nil; + + switch (dst->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break; + default: GGML_ABORT("not implemented"); + } + + const int32_t nk0 = ne0/ggml_blck_size(dst->type); + + int nth = 32; // SIMD width + + while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + int nrptg = 1; + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) { + nrptg--; + } + } + + nth = MIN(nth, nk0); + + ggml_metal_kargs_set_rows args = { + /*.nk0 =*/ nk0, + /*.ne01 =*/ ne01, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)]; + } break; case GGML_OP_RMS_NORM: { GGML_ASSERT(ne00 % 4 == 0); @@ -4805,7 +5030,11 @@ static bool ggml_metal_encode_node( /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, /*.ne1 =*/ ne1, /*.ne2 =*/ ne2, /*.scale =*/ scale, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index ae012b1c79..239ec31fbc 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = { -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f }; +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -97,6 +108,178 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } } +void quantize_q4_0(device const float * src, device block_q4_0 & dst) { +#pragma METAL fp math_mode(safe) + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_0/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + dst.qs[j] = xi0; + dst.qs[j] |= xi1 << 4; + } +} + +void quantize_q4_1(device const float * src, device block_q4_1 & dst) { +#pragma METAL fp math_mode(safe) + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + dst.m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst.qs[j] = xi0; + dst.qs[j] |= xi1 << 4; + } +} + +void quantize_q5_0(device const float * src, device block_q5_0 & dst) { +#pragma METAL fp math_mode(safe) + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + + for (int j = 0; j < 4; ++j) { + dst.qh[j] = qh8[j]; + } +} + +void quantize_q5_1(device const float * src, device block_q5_1 & dst) { +#pragma METAL fp math_mode(safe) + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + dst.m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + + for (int j = 0; j < 4; ++j) { + dst.qh[j] = qh8[j]; + } +} + +void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) { +#pragma METAL fp math_mode(safe) + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_NL; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst.qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst.d = sumq2 > 0 ? sumqx/sumq2 : d; +} + template void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 2); @@ -279,6 +462,27 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re } } +void quantize_q8_0(device const float * src, device block_q8_0 & dst) { +#pragma METAL fp math_mode(safe) + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst.qs[j] = round(x0); + } +} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -995,6 +1199,114 @@ kernel void kernel_neg( dst[tpig] = -src0[tpig]; } +kernel void kernel_reglu( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + dst_row[i0] = x0*x1*(x0 > 0.0f); + } +} + +kernel void kernel_geglu( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0))); + + dst_row[i0] = gelu*x1; + } +} + +kernel void kernel_swiglu( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float silu = x0 / (1.0f + exp(-x0)); + + dst_row[i0] = silu*x1; + } +} + +kernel void kernel_geglu_erf( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_erf = 0.5f*x0*(1.0f+erf_approx(x0*SQRT_2_INV)); + + dst_row[i0] = gelu_erf*x1; + } +} + +kernel void kernel_geglu_quick( + device const char * src0, + device const char * src1, + device char * dst, + constant ggml_metal_kargs_glu & args, + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + + for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0))); + + dst_row[i0] = gelu_quick*x1; + } +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, @@ -1057,24 +1369,28 @@ kernel void kernel_soft_max( device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (args.ne02*args.ne01); - const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; - const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + uint3 tptg[[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x; - device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; - device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + const int32_t i13 = i03%args.ne13; + const int32_t i12 = i02%args.ne12; + const int32_t i11 = i01; + + device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; // ALiBi if (args.max_bias > 0.0f) { - const int64_t h = i02; + const int32_t h = i02; const float base = h < args.n_head_log2 ? args.m0 : args.m1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; @@ -1085,13 +1401,13 @@ kernel void kernel_soft_max( // parallel max float lmax = -INFINITY; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = -INFINITY; } @@ -1110,7 +1426,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; @@ -1122,7 +1438,7 @@ kernel void kernel_soft_max( float sum = simd_sum(lsum); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; } @@ -1141,7 +1457,7 @@ kernel void kernel_soft_max( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { pdst[i00] *= inv_sum; } } @@ -1153,23 +1469,27 @@ kernel void kernel_soft_max_4( device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (args.ne02*args.ne01); - const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; - const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + uint3 tptg[[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x; - device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; - device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + const int32_t i13 = i03%args.ne13; + const int32_t i12 = i02%args.ne12; + const int32_t i11 = i01; + + device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; if (args.max_bias > 0.0f) { - const int64_t h = i02; + const int32_t h = i02; const float base = h < args.n_head_log2 ? args.m0 : args.m1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; @@ -1180,14 +1500,14 @@ kernel void kernel_soft_max_4( // parallel max float4 lmax4 = -INFINITY; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = -INFINITY; } @@ -1206,7 +1526,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; @@ -1220,7 +1540,7 @@ kernel void kernel_soft_max_4( float sum = simd_sum(lsum); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; } @@ -1239,7 +1559,7 @@ kernel void kernel_soft_max_4( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { pdst4[i00] *= inv_sum; } } @@ -1325,7 +1645,7 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part kernel void kernel_ssm_scan_f32( device const void * src0, device const void * src1, @@ -1333,46 +1653,119 @@ kernel void kernel_ssm_scan_f32( device const void * src3, device const void * src4, device const void * src5, + device const void * src6, device float * dst, constant ggml_metal_kargs_ssm_scan & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; + const int64_t i1 = 0; + const int64_t ir = tgpig.x; // current head + const int64_t i3 = tgpig.y; // current seq + + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); const int64_t nc = args.d_state; - // const int64_t nr = args.d_inner; + const int64_t nr = args.d_inner; + const int64_t nh = args.n_head; + const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - // const int64_t n_s = args.n_seqs; + + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src6; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); - device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13); + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} - if (i2 > 0) { - s0 = s; - } - - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } y[0] = sumf; + + // recurse + s0 = s; + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// TODO: optimize (e.g. by parallelizing over d_state) +kernel void kernel_ssm_scan_f32_group( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device float * dst, + constant ggml_metal_kargs_ssm_scan & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; + const int64_t ir = tgpig.y; // current head + const int64_t i3 = tgpig.z; // current seq + + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + + const int64_t nc = args.d_state; + const int64_t nr = args.d_inner; + const int64_t nh = args.n_head; + const int64_t ng = args.n_group; + const int64_t n_t = args.n_seq_tokens; + + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src6; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + const float dA = exp(dt_soft_plus * A[0]); + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * dA) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + + // recurse + s0 = s; } } @@ -3513,7 +3906,7 @@ kernel void kernel_flash_attn_ext( // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); const float m = pm[ic + tiisg]; @@ -3999,7 +4392,7 @@ kernel void kernel_flash_attn_ext_vec( const bool has_mask = mask != q; // pointer to the mask - device const half * pm = (device const half *) (mask + iq1*args.nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); float slope = 1.0f; @@ -4412,6 +4805,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; #endif +// TODO: templetify these kernels kernel void kernel_cpy_f32_q8_0( constant ggml_metal_kargs_cpy & args, device const char * src0, @@ -4435,23 +4829,7 @@ kernel void kernel_cpy_f32_q8_0( for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = src[j]; - amax = MAX(amax, fabs(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK8_0].d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = src[j]*id; - - dst_data[i00/QK8_0].qs[j] = round(x0); - } + quantize_q8_0(src, dst_data[i00/QK8_0]); } } @@ -4478,32 +4856,7 @@ kernel void kernel_cpy_f32_q4_0( for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_0].d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_0/2 + j]*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - - dst_data[i00/QK4_0].qs[j] = xi0; - dst_data[i00/QK4_0].qs[j] |= xi1 << 4; - } + quantize_q4_0(src, dst_data[i00/QK4_0]); } } @@ -4530,31 +4883,7 @@ kernel void kernel_cpy_f32_q4_1( for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < QK4_1; j++) { - const float v = src[j]; - if (min > v) min = v; - if (max < v) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_1].d = d; - dst_data[i00/QK4_1].m = min; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK4_1/2 + j] - min)*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - - dst_data[i00/QK4_1].qs[j] = xi0; - dst_data[i00/QK4_1].qs[j] |= xi1 << 4; - } + quantize_q4_1(src, dst_data[i00/QK4_1]); } } @@ -4581,38 +4910,7 @@ kernel void kernel_cpy_f32_q5_0( for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK5_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -16; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_0].d = d; - - uint32_t qh = 0; - for (int j = 0; j < QK5_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK5_0/2 + j]*id; - - const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); - - dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_0].qh[j] = qh8[j]; - } + quantize_q5_0(src, dst_data[i00/QK5_0]); } } @@ -4639,51 +4937,10 @@ kernel void kernel_cpy_f32_q5_1( for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float max = src[0]; - float min = src[0]; - - for (int j = 1; j < QK5_1; j++) { - const float v = src[j]; - min = v < min ? v : min; - max = v > max ? v : max; - } - - const float d = (max - min) / 31; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_1].d = d; - dst_data[i00/QK5_1].m = min; - - uint32_t qh = 0; - for (int j = 0; j < QK5_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK5_1/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_1].qh[j] = qh8[j]; - } + quantize_q5_1(src, dst_data[i00/QK5_1]); } } -static inline int best_index_int8(int n, constant float * val, float x) { - if (x <= val[0]) return 0; - if (x >= val[n-1]) return n-1; - int ml = 0, mu = n-1; - while (mu-ml > 1) { - int mav = (ml+mu)/2; - if (x < val[mav]) mu = mav; else ml = mav; - } - return x - val[mu-1] < val[mu] - x ? mu-1 : mu; -} - kernel void kernel_cpy_f32_iq4_nl( constant ggml_metal_kargs_cpy & args, device const char * src0, @@ -4707,40 +4964,7 @@ kernel void kernel_cpy_f32_iq4_nl( for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_NL; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / kvalues_iq4nl_f[0]; - const float id = d ? 1.0f/d : 0.0f; - - float sumqx = 0, sumq2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_NL/2 + j]*id; - - const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); - const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); - - dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); - - const float v0 = kvalues_iq4nl_f[xi0]; - const float v1 = kvalues_iq4nl_f[xi1]; - const float w0 = src[0 + j]*src[0 + j]; - const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; - sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; - sumq2 += w0*v0*v0 + w1*v1*v1; - - } - - dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; + quantize_iq4_nl(src, dst_data[i00/QK4_NL]); } } @@ -6421,10 +6645,10 @@ kernel void kernel_mul_mv_iq4_xs_f32( template kernel void kernel_get_rows_q( + constant ggml_metal_kargs_get_rows & args, device const void * src0, device const void * src1, device float * dst, - constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { @@ -6444,10 +6668,10 @@ kernel void kernel_get_rows_q( template kernel void kernel_get_rows_f( + constant ggml_metal_kargs_get_rows & args, device const void * src0, device const void * src1, device float * dst, - constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { @@ -6465,10 +6689,10 @@ kernel void kernel_get_rows_f( } kernel void kernel_get_rows_i32( + constant ggml_metal_kargs_get_rows & args, device const void * src0, device const void * src1, device int32_t * dst, - constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { @@ -6485,6 +6709,67 @@ kernel void kernel_get_rows_i32( } } +template +kernel void kernel_set_rows_q32( + constant ggml_metal_kargs_set_rows & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + + const int32_t i12 = i03%args.ne12; + const int32_t i11 = i02%args.ne11; + + const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x; + if (i01 >= args.ne01) { + return; + } + + const int32_t i10 = i01; + const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; + + device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + + for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { + quantize_func(src_row + 32*ind, dst_row[ind]); + } +} + +template +kernel void kernel_set_rows_f( + constant ggml_metal_kargs_set_rows & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + + const int32_t i12 = i03%args.ne12; + const int32_t i11 = i02%args.ne11; + + const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x; + if (i01 >= args.ne01) { + return; + } + + const int32_t i10 = i01; + const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; + + device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + + for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { + dst_row[ind] = (T) src_row[ind]; + } +} #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B @@ -6908,6 +7193,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; +// +// set rows +// + +typedef decltype(kernel_set_rows_f) set_rows_f_t; + +template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f; +#endif + +typedef decltype(kernel_set_rows_q32) set_rows_q32_t; + +template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32; + // // matrix-matrix multiplication // diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 0e2a419649..45a4883348 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -65,6 +65,7 @@ set(GGML_OPENCL_KERNELS gemv_noshuffle_general gemv_noshuffle get_rows + glu group_norm im2col_f32 im2col_f16 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 96e8a8588d..a9fc039038 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -351,6 +351,7 @@ struct ggml_backend_opencl_context { cl_program program_gemv_noshuffle_general; cl_program program_gemv_noshuffle; cl_program program_get_rows; + cl_program program_glu; cl_program program_im2col_f16; cl_program program_im2col_f32; cl_program program_mul_mat_Ab_Bi_8x4; @@ -397,10 +398,13 @@ struct ggml_backend_opencl_context { cl_kernel kernel_scale; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; + cl_kernel kernel_gelu_erf, kernel_gelu_erf_4; cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; cl_kernel kernel_clamp; + cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick, + kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; cl_kernel kernel_norm; cl_kernel kernel_rms_norm; cl_kernel kernel_group_norm; @@ -733,11 +737,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err)); CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_erf = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_erf_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf_4", &err), err)); CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err)); CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err)); GGML_LOG_CONT("."); } + // glu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "glu.cl.h" + }; +#else + const std::string kernel_src = read_file("glu.cl"); +#endif + backend_ctx->program_glu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err)); + CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err)); + CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_erf_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err)); + GGML_LOG_CONT("."); + } + // get_rows { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2163,7 +2194,7 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm // dependencies. sync_with_other_backends(backend); - if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } @@ -2198,6 +2229,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te default: return false; } + case GGML_OP_SET_ROWS: + { + // TODO: add support + // ref: https://github.com/ggml-org/llama.cpp/pull/14274 + return false; + } break; case GGML_OP_CPY: case GGML_OP_DUP: case GGML_OP_CONT: @@ -2232,6 +2269,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_UNARY_OP_SIGMOID: @@ -2242,6 +2280,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te default: return false; } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); + default: + return false; + } case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: @@ -3166,7 +3215,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso // Open file and dump. char fname[512]; - sprintf(fname, "./tensor-dumps/%s.txt", tensor->name); + snprintf(fname, sizeof(fname), "./tensor-dumps/%s.txt", tensor->name); FILE * f = fopen(fname, "w"); if (!f) { printf("Failed to open %s\n", fname); @@ -3825,6 +3874,44 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_erf_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu_erf; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -4420,7 +4507,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); + const int mode_flags = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); + const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF); cl_kernel kernel = nullptr; if (mode == GGML_SCALE_MODE_NEAREST) { @@ -4451,18 +4539,22 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; - const int ne00_src = src0->ne[0]; - const int ne01_src = src0->ne[1]; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - const int ne10_dst = dst->ne[0]; - const int ne11_dst = dst->ne[1]; - const int ne12_dst = dst->ne[2]; - const int ne13_dst = dst->ne[3]; + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; - const float sf0 = (float)dst->ne[0] / src0->ne[0]; - const float sf1 = (float)dst->ne[1] / src0->ne[1]; - const float sf2 = (float)dst->ne[2] / src0->ne[2]; - const float sf3 = (float)dst->ne[3] / src0->ne[3]; + float sf0 = (float)ne0 / ne00; + float sf1 = (float)ne1 / ne01; + float sf2 = (float)ne2 / ne02; + float sf3 = (float)ne3 / ne03; + + float pixel_offset = 0.5f; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); @@ -4474,29 +4566,36 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03)); if (mode == GGML_SCALE_MODE_NEAREST) { - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne10_dst)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11_dst)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12_dst)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13_dst)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne3)); CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3)); } else if (mode == GGML_SCALE_MODE_BILINEAR) { - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00_src)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01_src)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10_dst)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11_dst)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12_dst)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13_dst)); + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = (float)(ne0 - 1) / (ne00 - 1); + sf1 = (float)(ne1 - 1) / (ne01 - 1); + pixel_offset = 0.0f; + } + + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne3)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1)); CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &pixel_offset)); } - size_t dst_total_elements = (size_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst; + size_t dst_total_elements = (size_t)ne0 * ne1 * ne2 * ne3; if (dst_total_elements == 0) { return; } @@ -5712,19 +5811,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_long nb01 = src0->nb[1]; + const cl_long nb02 = src0->nb[2]; + const cl_long nb03 = src0->nb[3]; + + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; + + const cl_long nb11 = src1 ? src1->nb[1] : 0; + const cl_long nb12 = src1 ? src1->nb[2] : 0; + const cl_long nb13 = src1 ? src1->nb[3] : 0; + + const cl_long nb1 = dst->nb[1]; + const cl_long nb2 = dst->nb[2]; + const cl_long nb3 = dst->nb[3]; float scale, max_bias; memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - const int nrows_x = ggml_nrows(src0); - const int nrows_y = src0->ne[1]; - - const int n_head = nrows_x/nrows_y; + const int n_head = src0->ne[2]; const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -5769,13 +5880,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2)); size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; @@ -6143,6 +6263,105 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + if (src1) { + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + } + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + cl_kernel kernel; + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_GEGLU: + if (dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_geglu; + } else { + kernel = backend_ctx->kernel_geglu_f16; + } + break; + case GGML_GLU_OP_REGLU: + if (dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_reglu; + } else { + kernel = backend_ctx->kernel_reglu_f16; + } + break; + case GGML_GLU_OP_SWIGLU: + if (dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_swiglu; + } else { + kernel = backend_ctx->kernel_swiglu_f16; + } + break; + case GGML_GLU_OP_GEGLU_ERF: + if (dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_geglu_erf; + } else { + kernel = backend_ctx->kernel_geglu_erf_f16; + } + break; + case GGML_GLU_OP_GEGLU_QUICK: + if (dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_geglu_quick; + } else { + kernel = backend_ctx->kernel_geglu_quick_f16; + } + break; + default: + GGML_ABORT("Unsupported glu op"); + } + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; + + const int ne0 = dst->ne[0]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb11 = src1 ? src1->nb[1] : nb01; + + const cl_ulong nb1 = dst->nb[1]; + + const int swp = ((const int32_t *) dst->op_params)[1]; + const int ne00_off = src1 ? 0 : (swp ? ne0 : 0); + const int ne10_off = src1 ? 0 : (swp ? 0 : ne0); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), src1 ? &extra1->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne00_off)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10_off)); + + const size_t nrows = ggml_nrows(src0); + size_t nth = 512; + size_t global_work_size[] = {nrows*nth, 1, 1}; + size_t local_work_size[] = {nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + //------------------------------------------------------------------------------ // Op offloading //------------------------------------------------------------------------------ @@ -6211,6 +6430,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_gelu; break; + case GGML_UNARY_OP_GELU_ERF: + if (!any_on_device) { + return false; + } + func = ggml_cl_gelu_erf; + break; case GGML_UNARY_OP_GELU_QUICK: if (!any_on_device) { return false; @@ -6244,6 +6469,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor default: return false; } break; + case GGML_OP_GLU: + if (!any_on_device) { + return false; + } + func = ggml_cl_glu; + break; case GGML_OP_CLAMP: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/gelu.cl b/ggml/src/ggml-opencl/kernels/gelu.cl index 71c310cc9f..1ab426c774 100644 --- a/ggml/src/ggml-opencl/kernels/gelu.cl +++ b/ggml/src/ggml-opencl/kernels/gelu.cl @@ -6,6 +6,7 @@ #define GELU_COEF_A 0.044715f #define GELU_QUICK_COEF -1.702f #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f +#define SQRT_2_INV 0.70710678118654752440084436210484f kernel void kernel_gelu( global float * src0, @@ -35,6 +36,32 @@ kernel void kernel_gelu_4( dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } +kernel void kernel_gelu_erf( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV)); +} + +kernel void kernel_gelu_erf_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV)); +} + kernel void kernel_gelu_quick( global float * src0, ulong offset0, diff --git a/ggml/src/ggml-opencl/kernels/glu.cl b/ggml/src/ggml-opencl/kernels/glu.cl new file mode 100644 index 0000000000..7cca16e6a9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/glu.cl @@ -0,0 +1,337 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f +#define SQRT_2_INV 0.70710678118654752440084436210484f + +//------------------------------------------------------------------------------ +// geglu +//------------------------------------------------------------------------------ +kernel void kernel_geglu( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0))); + + dst_row[i0] = gelu*x1; + } +} + +kernel void kernel_geglu_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const half x0 = src0_row[i0]; + const half x1 = src1_row[i0]; + + const half gelu = 0.5f*x0*(1.0f + tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0))); + + dst_row[i0] = gelu*x1; + } +} + +//------------------------------------------------------------------------------ +// reglu +//------------------------------------------------------------------------------ +kernel void kernel_reglu( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + dst_row[i0] = x0*x1*(x0 > 0.0f); + } +} + +kernel void kernel_reglu_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const half x0 = src0_row[i0]; + const half x1 = src1_row[i0]; + + dst_row[i0] = x0*x1*(x0 > 0.0f); + } +} + +//------------------------------------------------------------------------------ +// swiglu +//------------------------------------------------------------------------------ +kernel void kernel_swiglu( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float silu = x0 / (1.0f + exp(-x0)); + + dst_row[i0] = silu*x1; + } +} + +kernel void kernel_swiglu_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const half x0 = src0_row[i0]; + const half x1 = src1_row[i0]; + + const half silu = x0 / (1.0f + exp(-x0)); + + dst_row[i0] = silu*x1; + } +} + +//------------------------------------------------------------------------------ +// geglu_erf +//------------------------------------------------------------------------------ +kernel void kernel_geglu_erf( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV)); + + dst_row[i0] = gelu_erf*x1; + } +} + +kernel void kernel_geglu_erf_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const half x0 = src0_row[i0]; + const half x1 = src1_row[i0]; + + const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV)); + + dst_row[i0] = gelu_erf*x1; + } +} + +//------------------------------------------------------------------------------ +// geglu_quick +//------------------------------------------------------------------------------ +kernel void kernel_geglu_quick( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const float x0 = src0_row[i0]; + const float x1 = src1_row[i0]; + + const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0))); + + dst_row[i0] = gelu_quick*x1; + } +} + +kernel void kernel_geglu_quick_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb11, + int ne0, + ulong nb1, + int ne00_off, + int ne10_off +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global char*)((global char*)dst + offsetd); + + global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off; + global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off; + global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const half x0 = src0_row[i0]; + const half x1 = src1_row[i0]; + + const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0))); + + dst_row[i0] = gelu_quick*x1; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl index 62c05369a8..a6d8ede670 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -22,32 +22,45 @@ REQD_SUBGROUP_SIZE_64 #endif kernel void kernel_soft_max_4_f16( - global float * src0, + global char * src0, ulong offset0, - global half * src1, + global char * src1, ulong offset1, - global float * dst, + global char * dst, ulong offsetd, int ne00, - int ne01, - int ne02, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, float scale, float max_bias, float m0, float m1, int n_head_log2 ) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; int i03 = get_group_id(2); int i02 = get_group_id(1); int i01 = get_group_id(0); - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + int i13 = i03%ne13; + int i12 = i02%ne12; + int i11 = i01; + + global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl index d562774eab..35b5573b46 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -22,32 +22,45 @@ REQD_SUBGROUP_SIZE_64 #endif kernel void kernel_soft_max_4( - global float * src0, + global char * src0, ulong offset0, - global float * src1, + global char * src1, ulong offset1, - global float * dst, + global char * dst, ulong offsetd, int ne00, - int ne01, - int ne02, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, float scale, float max_bias, float m0, float m1, int n_head_log2 ) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; int i03 = get_group_id(2); int i02 = get_group_id(1); int i01 = get_group_id(0); - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + int i13 = i03%ne13; + int i12 = i02%ne12; + int i11 = i01; + + global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl index d38d099671..9d292b5746 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -22,32 +22,45 @@ REQD_SUBGROUP_SIZE_64 #endif kernel void kernel_soft_max_f16( - global float * src0, + global char * src0, ulong offset0, - global half * src1, + global char * src1, ulong offset1, - global float * dst, + global char * dst, ulong offsetd, int ne00, - int ne01, - int ne02, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, float scale, float max_bias, float m0, float m1, int n_head_log2 ) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; int i03 = get_group_id(2); int i02 = get_group_id(1); int i01 = get_group_id(0); - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + int i13 = i03%ne13; + int i12 = i02%ne12; + int i11 = i01; + + global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl index 001b587abe..7c53dfbe5a 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -22,32 +22,45 @@ REQD_SUBGROUP_SIZE_64 #endif kernel void kernel_soft_max( - global float * src0, + global char * src0, ulong offset0, - global float * src1, + global char * src1, ulong offset1, - global float * dst, + global char * dst, ulong offsetd, int ne00, - int ne01, - int ne02, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, float scale, float max_bias, float m0, float m1, int n_head_log2 ) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; int i03 = get_group_id(2); int i02 = get_group_id(1); int i01 = get_group_id(0); - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + int i13 = i03%ne13; + int i12 = i02%ne12; + int i11 = i01; + + global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; diff --git a/ggml/src/ggml-opencl/kernels/upscale.cl b/ggml/src/ggml-opencl/kernels/upscale.cl index 219d31dbb9..25c68351ba 100644 --- a/ggml/src/ggml-opencl/kernels/upscale.cl +++ b/ggml/src/ggml-opencl/kernels/upscale.cl @@ -60,7 +60,8 @@ kernel void kernel_upscale_bilinear( float sf0, float sf1, float sf2, - float sf3 + float sf3, + float pixel_offset ) { global const char * src_base = (global const char *)p_src0 + off_src0; global float * dst_base = (global float *)((global char *)p_dst + off_dst); @@ -80,8 +81,6 @@ kernel void kernel_upscale_bilinear( int i02_src = (int)(i12_dst / sf2); int i03_src = (int)(i13_dst / sf3); - const float pixel_offset = 0.5f; - float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; long y0_src = (long)floor(y_src_f); long y1_src = y0_src + 1; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e389a46dbe..9a7d1b22d7 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -568,14 +568,14 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co } float iscale = nmax/(max - min); float scale = 1/iscale; - float best_mad = 0; + float best_error = 0; for (int i = 0; i < n; ++i) { int l = nearest_int(iscale*(x[i] - min)); L[i] = MAX(0, MIN(nmax, l)); float diff = scale * L[i] + min - x[i]; diff = use_mad ? fabsf(diff) : diff * diff; float w = weights[i]; - best_mad += w * diff; + best_error += w * diff; } if (nstep < 1) { *the_min = -min; @@ -601,18 +601,18 @@ static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, co this_min = 0; this_scale = sum_xl / sum_l2; } - float mad = 0; + float cur_error = 0; for (int i = 0; i < n; ++i) { float diff = this_scale * Laux[i] + this_min - x[i]; diff = use_mad ? fabsf(diff) : diff * diff; float w = weights[i]; - mad += w * diff; + cur_error += w * diff; } - if (mad < best_mad) { + if (cur_error < best_error) { for (int i = 0; i < n; ++i) { L[i] = Laux[i]; } - best_mad = mad; + best_error = cur_error; scale = this_scale; min = this_min; } diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index c56924ce83..0363b06a3e 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -1,12 +1,19 @@ #include "common.hpp" +#include "ggml-sycl/presets.hpp" #include "ggml.h" #include "element_wise.hpp" +#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \ + for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0)) + +#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \ + (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX)) + + static void acc_f32(const float * x, const float * y, float * dst, const int ne, const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); + const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) { + const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0); if (i >= ne) { return; } @@ -21,248 +28,280 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne, } } +/* Unary OP funcs */ template -static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { - dst[i] = x[i] > static_cast(0.f) ? static_cast(1.f) : ((x[i] < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); - } +static __dpct_inline__ T op_sgn(T x) { + return x > static_cast(0.f) ? static_cast(1.f) : ((x < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); } template -static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { - dst[i] = sycl::fabs(x[i]); - } +static __dpct_inline__ T op_abs(T x) { + return sycl::fabs(x); } template -static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { - dst[i] = (x[i] > static_cast(0.f)) ? x[i] : sycl::expm1(x[i]); - } +static __dpct_inline__ T op_elu(T x) { + return (x > static_cast(0.f)) ? x : sycl::expm1(x); } template -static void gelu(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { +static __dpct_inline__ T op_gelu(T x) { const T GELU_COEF_A = static_cast(0.044715f); const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - - float xi = x[i]; - dst[i] = static_cast(0.5f) * xi * - (static_cast(1.0f) + - sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast(1.0f) + GELU_COEF_A * xi * xi))); + return static_cast(0.5f) * x * + (static_cast(1.0f) + + sycl::tanh(SQRT_2_OVER_PI * x * (static_cast(1.0f) + GELU_COEF_A * x * x))); } template -static void silu(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = x[i] / (static_cast(1.0f) + sycl::native::exp(-x[i])); +static __dpct_inline__ T op_silu(T x) { + return x / (static_cast(1.0f) + sycl::native::exp(-x)); } template -static void gelu_quick(const T *x, T *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const float GELU_QUICK_COEF = -1.702f; - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; - } - dst[i] = x[i] * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i]))); +static __dpct_inline__ T op_gelu_quick(T x) { + const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); + return x * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x))); } template -static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { +static __dpct_inline__ T op_gelu_erf(T x) { const T SQRT_2_INV = static_cast(0.70710678118654752440084436210484f); - for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { - auto x_i = x[i]; - dst[i] = static_cast(0.5f) * x_i * (static_cast(1.0f) + sycl::erf(x_i * SQRT_2_INV)); + return static_cast(0.5f) * x * (static_cast(1.0f) + sycl::erf(x * SQRT_2_INV)); +} + +template +static __dpct_inline__ T op_tanh(T x) { + return sycl::tanh(x); +} + +template +static __dpct_inline__ T op_relu(T x) { + return sycl::fmax(x, static_cast(0)); +} + +template +static __dpct_inline__ T op_sigmoid(T x) { + return static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(-x)); +} + +template +static __dpct_inline__ T op_sqrt(T x) { + return sycl::sqrt(x); +} + +template +static __dpct_inline__ T op_sin(T x) { + return sycl::sin(x); +} + +template +static __dpct_inline__ T op_cos(T x) { + return sycl::cos(x); +} + +template +static __dpct_inline__ T op_hardsigmoid(T x) { + return sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); +} + +template +static __dpct_inline__ T op_hardswish(T x) { + return x * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); +} + +template +static __dpct_inline__ T op_exp(T x) { + return sycl::exp(x); +} + +template +static __dpct_inline__ T op_log(T x) { + if (x <= static_cast(0)) { + return neg_infinity(); + } + return sycl::log(x); +} + +template +static __dpct_inline__ T op_neg(T x) { + return -x; +} + +template +static __dpct_inline__ T op_step(T x) { + return (x > static_cast(0.0f)) ? static_cast(1.0f) : static_cast(0.0f); +} + +template +static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) { + T neg_slope_T = static_cast(negative_slope); + return sycl::fmax(x, static_cast(0)) + + sycl::fmin(x, static_cast(0.0f)) * neg_slope_T; +} + +template +static __dpct_inline__ T op_sqr(T x) { + return x * x; +} + +template +static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) { + return x < static_cast(min_val) ? static_cast(min_val) : (x > static_cast(max_val) ? static_cast(max_val) : x); +} + +template +static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_sgn(x[i]); } } template -static void tanh(const T *x, T *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; - } - dst[i] = sycl::tanh((x[i])); -} - -template -static void relu(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::fmax((x[i]), static_cast(0)); -} - -template -static void sigmoid(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = 1.0f / (static_cast(1.0f) + sycl::native::exp(-x[i])); -} - -template -static void sqrt(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::sqrt(x[i]); -} - -template -static void sin(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::sin(x[i]); -} - -template -static void cos(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::cos(x[i]); -} - -template -static void hardsigmoid(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); -} - -template -static void hardswish(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = x[i] * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x[i] + static_cast(3.0f)) / static_cast(6.0f))); -} - -template -static void exp(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::exp(x[i]); -} - -template -static void log(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - T xi = x[i]; - if (xi <= 0) { - dst[i] = neg_infinity(); - } else { - dst[i] = sycl::log(xi); +static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_abs(x[i]); } } template -static void neg(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; +static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_elu(x[i]); } - dst[i] = -x[i]; } template -static void step(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; +static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_gelu(x[i]); } - dst[i] = x[i] > static_cast(0.0f); } template -static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; +static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_silu(x[i]); } - dst[i] = sycl::fmax((x[i]), static_cast(0)) + - sycl::fmin((x[i]), static_cast(0.0f)) * negative_slope; } template -static void sqr(const T * x, T * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; +static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_gelu_quick(x[i]); + } +} + +template +static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_gelu_erf(x[i]); + } +} + +template +static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_tanh(x[i]); + } +} + +template +static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_relu(x[i]); + } +} + +template +static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_sigmoid(x[i]); + } +} + +template +static void unary_op_sqrt_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_sqrt(x[i]); + } +} + +template +static void unary_op_sin_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_sin(x[i]); + } +} + +template +static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_cos(x[i]); + } +} + +template +static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_hardsigmoid(x[i]); + } +} + +template +static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_hardswish(x[i]); + } +} + +template +static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_exp(x[i]); + } +} + +template +static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_log(x[i]); + } +} + +template +static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_neg(x[i]); + } +} + +template +static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_step(x[i]); + } +} + +template +static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_leaky_relu(x[i], negative_slope); + } +} + +template +static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_sqr(x[i]); + } +} + +template +static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = op_clamp(x[i], min_val, max_val); } - dst[i] = x[i] * x[i]; } template @@ -281,10 +320,10 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01, int i12 = (index / (ne10 * ne11)) % ne12; int i13 = (index / (ne10 * ne11 * ne12)) % ne13; - int i00 = i10 / sf0; - int i01 = i11 / sf1; - int i02 = i12 / sf2; - int i03 = i13 / sf3; + int i00 = static_cast(i10 / sf0); + int i01 = static_cast(i11 / sf1); + int i02 = static_cast(i12 / sf2); + int i03 = static_cast(i13 / sf3); dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); } @@ -292,8 +331,7 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01, template static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02, const sycl::nd_item<3> &item_ct1) { - int nidx = item_ct1.get_local_id(2) + - item_ct1.get_group(2) * item_ct1.get_local_range(2); + int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2); if (nidx >= ne0) { return; } @@ -310,246 +348,73 @@ static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne } } - template static void clamp(const T * x, T * dst, const float min, const float max, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; + const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? static_cast(max) : x[i]); } - - dst[i] = x[i] < static_cast(min) ? static_cast(min) : (x[i] > static_cast(max) ? static_cast(max) : x[i]); } +template +static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = op_gelu(x[j0]) * g[j1]; + } +} + +template +static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = op_relu(x[j0]) * g[j1]; + } +} + +template +static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = op_silu(x[j0]) * g[j1]; + } +} + +template +static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = op_gelu_erf(x[j0]) * g[j1]; + } +} + +template +static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) { + SYCL_GLOBAL_ID_LOOP(k, item_ct1) { + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + dst[i] = op_gelu_quick(x[j0]) * g[j1]; + } +} + +namespace ggml_sycl_detail { static void acc_f32_sycl(const float *x, const float *y, float *dst, const int n_elements, const int ne10, const int ne11, const int ne12, const int nb1, const int nb2, const int offset, queue_ptr stream) { - int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; + int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE); sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1); - }); -} - -template -static void gelu_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { gelu(x, dst, k, item_ct1); }); -} - -template -static void silu_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { silu(x, dst, k, item_ct1); }); -} - -template -static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) { - // hard code for now - const int num_blocks = ceil_div(k, 256); - sycl_parallel_for( - stream, sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), - [=](sycl::nd_item<3> item_ct1) { sgn(x, dst, k, item_ct1); }); -} - -template -static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) { - // hard code for now - const int num_blocks = ceil_div(k, 256); - sycl_parallel_for( - stream, - sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item_ct1) { abs_op(x, dst, k, item_ct1); }); -} - - -template -static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) { - // hard code for now - const int num_blocks = ceil_div(k, 256); - sycl_parallel_for( - stream, - sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), - [=](sycl::nd_item<3> item_ct1) { elu_op(x, dst, k, item_ct1); }); -} - -template -static void gelu_quick_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { gelu_quick(x, dst, k, item_ct1); }); -} - - -template -static void gelu_erf_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { gelu_erf(x, dst, k, item_ct1); }); -} - -template -static void tanh_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { tanh(x, dst, k, item_ct1); }); -} - -template -static void relu_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { relu(x, dst, k, item_ct1); }); -} - -template -static void hardsigmoid_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; - sycl_parallel_for( - stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { hardsigmoid(x, dst, k, item_ct1); }); -} - -template -static void hardswish_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; - sycl_parallel_for( - stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { hardswish(x, dst, k, item_ct1); }); -} - -template -static void exp_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { exp(x, dst, k, item_ct1); }); -} - -template -static void log_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { log(x, dst, k, item_ct1); }); -} - -template -static void neg_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { neg(x, dst, k, item_ct1); }); -} - -template -static void step_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { step(x, dst, k, item_ct1); }); -} - -template -static void sigmoid_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; - sycl_parallel_for( - stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { sigmoid(x, dst, k, item_ct1); }); -} - -template -static void sqrt_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { sqrt(x, dst, k, item_ct1); }); -} - -template -static void sin_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { sin(x, dst, k, item_ct1); }); -} - -template -static void cos_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { cos(x, dst, k, item_ct1); }); -} - -template -static void leaky_relu_sycl(const T *x, T *dst, const int k, - const float negative_slope, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { leaky_relu(x, dst, k, negative_slope, item_ct1); }); -} - -template -static void sqr_sycl(const T *x, T *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { sqr(x, dst, k, item_ct1); }); + sycl::nd_range<1>(sycl::range<1>(num_blocks) * + sycl::range<1>(SYCL_ACC_BLOCK_SIZE), + sycl::range<1>(SYCL_ACC_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, + item_ct1); + }); } template @@ -558,7 +423,7 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, const int ne12, const int ne13, const float sf0, const float sf1, const float sf2, const float sf3, queue_ptr stream) { int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE); sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); sycl_parallel_for<1>( stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { @@ -570,7 +435,7 @@ template static void pad_sycl(const T *x, T *dst, const int ne00, const int ne01, const int ne02, const int ne0, const int ne1, const int ne2, queue_ptr stream) { - int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; + int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE); sycl::range<3> gridDim(ne2, ne1, num_blocks); sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), @@ -578,115 +443,8 @@ static void pad_sycl(const T *x, T *dst, const int ne00, [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); }); } -template -static void clamp_sycl(const T *x, T *dst, const float min, - const float max, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE; - sycl_parallel_for(stream, - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); }); -} - -inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - - -inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +template +static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -702,14 +460,14 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst case GGML_TYPE_F16: { auto data_pts = cast_data(dst); - silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); break; } #endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); - silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); break; } default: @@ -717,7 +475,8 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } } -inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +template +static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -728,19 +487,66 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;; + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_is_contiguous_1(dst->src[0])); + GGML_ASSERT(ggml_is_contiguous(dst)); + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + void * dst_d = dst->data; + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } switch (dst->type) { #if defined (GGML_SYCL_F16) case GGML_TYPE_F16: { - auto data_pts = cast_data(dst); - gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + sycl::half * src0_p = (sycl::half *) src0_d; + sycl::half * src1_p = (sycl::half *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + kernel_invoker(src0_p, + src1_p, + (sycl::half *) dst_d, + ggml_nelements(dst), + nc, + src0_o / sizeof(sycl::half), + src1_o / sizeof(sycl::half), + main_stream, + std::forward(args)...); break; } #endif case GGML_TYPE_F32: { - auto data_pts = cast_data(dst); - gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + kernel_invoker(src0_p, + src1_p, + (float *) dst_d, + ggml_nelements(dst), + nc, + src0_o / sizeof(float), + src1_o / sizeof(float), + main_stream, + std::forward(args)...); break; } default: @@ -748,511 +554,8 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } } -inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - - -inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - - GGML_ASSERT(dst->src[0]->type == dst->type); - float negative_slope; - memcpy(&negative_slope, dst->op_params, sizeof(float)); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - #if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - -inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +template +static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -1274,18 +577,18 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * case GGML_TYPE_F16: { auto data_pts = cast_data(dst); - upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], - dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], + (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, + main_stream, std::forward(args)...); break; } #endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); - upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], - dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], + (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, + main_stream, std::forward(args)...); break; } default: @@ -1293,7 +596,8 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * } } -inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +template +static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); @@ -1302,7 +606,7 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) GGML_ASSERT(dst->type == GGML_TYPE_F32); #endif GGML_ASSERT(dst->src[0]->type == dst->type); - GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); switch (dst->type) { @@ -1310,16 +614,16 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) case GGML_TYPE_F16: { auto data_pts = cast_data(dst); - pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], - dst->ne[1], dst->ne[2], main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0], + (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward(args)...); break; } #endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); - pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0], - dst->ne[1], dst->ne[2], main_stream); + kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0], + (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward(args)...); break; } default: @@ -1327,45 +631,320 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } } -inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { -#if defined(GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else +} // namespace ggml_sycl_detail - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - float min; - float max; - memcpy(&min, dst->op_params, sizeof(float)); - memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - switch (dst->type) { -#if defined(GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } + +static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); } -inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} +static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, 256); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), + sycl::range<1>(256)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE), + sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE), + sycl::range<1>(SYCL_TANH_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE), + sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE), + sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), + sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), + sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE), + sycl::range<1>(SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE), + sycl::range<1>(SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE), + sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE), + sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), + sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), + sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) { + const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1); + }); + }, negative_slope); +} + +static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { + const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE), + sycl::range<1>(SYCL_SQR_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst, + [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03, + int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3, + queue_ptr stream) { + ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream); + }); +} + +static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst, + [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, + queue_ptr stream) { + ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream); + }); +} + +static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + float min_val; + float max_val; + memcpy(&min_val, dst->op_params, sizeof(float)); + memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float)); + ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, + [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) { + const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE); + sycl_parallel_for(stream, + sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE), + sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1); + }); + }, min_val, max_val); +} + +static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -1381,7 +960,62 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused int offset = dst->op_params[3] / 4; // offset in bytes - acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream); + ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream); +} + +static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); + sycl_parallel_for(main_stream, + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu + sycl_parallel_for(main_stream, + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu + sycl_parallel_for(main_stream, + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); + sycl_parallel_for(main_stream, + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); +} + +static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst, + [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { + const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); + sycl_parallel_for(main_stream, + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); + }); + }); } @@ -1509,3 +1143,28 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_elu(ctx, dst); } + +void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_geglu(ctx, dst); +} + +void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_reglu(ctx, dst); +} + +void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_swiglu(ctx, dst); +} + +void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_geglu_erf(ctx, dst); +} + +void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_geglu_quick(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index bd40113f09..50749e87d7 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -3,27 +3,30 @@ #include "common.hpp" #include "ggml.h" -#include +#include // For std::numeric_limits template T neg_infinity() { return -std::numeric_limits::infinity(); } -template +template struct typed_data { - const T * src; - T * dst; + const T_Src * src; + T_Dst * dst; }; -template -typed_data cast_data(ggml_tensor * dst) { +template +typed_data cast_data(ggml_tensor * dst) { return { - /* .src = */ static_cast(dst->src[0]->data), - /* .dst = */ static_cast(dst->data) + /* .src = */ static_cast(dst->src[0]->data), + /* .dst = */ static_cast(dst->data) }; } +const float GELU_QUICK_COEF = -1.702f; + + void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst); @@ -73,5 +76,11 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); -#endif // GGML_SYCL_ELEMENTWISE_HPP +void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_ELEMENTWISE_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 9cb36ae99e..21c81e99a1 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -83,7 +83,7 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].cc = 100 * prop.get_major_version() + 10 * prop.get_minor_version(); - info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); + info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); info.max_work_group_sizes[i] = prop.get_max_work_group_size(); } @@ -3676,6 +3676,27 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_REGLU: + ggml_sycl_reglu(ctx, dst); + break; + case GGML_GLU_OP_GEGLU: + ggml_sycl_geglu(ctx, dst); + break; + case GGML_GLU_OP_SWIGLU: + ggml_sycl_swiglu(ctx, dst); + break; + case GGML_GLU_OP_GEGLU_ERF: + ggml_sycl_geglu_erf(ctx, dst); + break; + case GGML_GLU_OP_GEGLU_QUICK: + ggml_sycl_geglu_quick(ctx, dst); + break; + default: + return false; + } + break; case GGML_OP_NORM: ggml_sycl_norm(ctx, dst); break; @@ -4212,6 +4233,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g default: return false; } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous_1(op->src[0]); + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { @@ -4260,6 +4293,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return false; } } + case GGML_OP_SET_ROWS: + { + // TODO: add support + // ref: https://github.com/ggml-org/llama.cpp/pull/14274 + return false; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; @@ -4370,9 +4409,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; - case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: - return true; + // TODO: support batching + if (op->src[0]->ne[3] != 1) { + return false; + } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 + return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); + case GGML_OP_DIAG_MASK_INF: case GGML_OP_ROPE: case GGML_OP_IM2COL: return true; diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index e44c6b6ef8..1b60226dcd 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row * ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row0 = row % ne1; const int channel0 = row / ne1; const int i = row * ne0 + i0; const int i2 = channel0 * s2 + row0 * s1 + i0; + if (i0 >= n_dims) { + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i2); + return; + } + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; @@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row * ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row0 = row % ne1; const int channel0 = row / ne1; const int i = row * ne0 + i0 / 2; const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; + if (i0 >= n_dims) { + *reinterpret_cast *>(dst + i + i0 / 2) = *reinterpret_cast *>(x + i2 + i0 / 2); + return; + } + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; @@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const } const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); - if (i0 >= n_dims) { - const int i = row_dst*ne0 + i0; - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); - return; - } - const int row_x = row_dst % ne1; const int channel_x = row_dst / ne1; const int idst = (row_dst * ne0) + (i0 / 2); const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + if (i0 >= n_dims) { + *reinterpret_cast *>(dst + idst + i0 / 2) = *reinterpret_cast *>(x + i0 / 2 + ix); + return; + } + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 0bf4cb14f8..b97e7bf995 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -99,6 +99,7 @@ if (Vulkan_FOUND) if (GGML_VULKAN_SHADER_DEBUG_INFO) add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON) endif() if (GGML_VULKAN_VALIDATE) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 99be5e45b2..2245a65549 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -224,6 +224,21 @@ enum vk_device_architecture { INTEL_XE2, }; +// HSK x HSV +enum FaHeadSizes { + FA_HEAD_SIZE_64, + FA_HEAD_SIZE_80, + FA_HEAD_SIZE_96, + FA_HEAD_SIZE_112, + FA_HEAD_SIZE_128, + FA_HEAD_SIZE_192, + FA_HEAD_SIZE_192_128, + FA_HEAD_SIZE_256, + FA_HEAD_SIZE_576_512, + FA_HEAD_SIZE_UNSUPPORTED, + FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED, +}; + static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { vk::PhysicalDeviceProperties props = device.getProperties(); @@ -305,7 +320,7 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& } struct vk_device_struct { - std::mutex mutex; + std::recursive_mutex mutex; vk::PhysicalDevice physical_device; vk::PhysicalDeviceProperties properties; @@ -425,17 +440,25 @@ struct vk_device_struct { vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_rms_norm_mul_f32; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; // [src/dst 0=fp32,1=fp16] vk_pipeline pipeline_gelu[2]; + vk_pipeline pipeline_gelu_erf[2]; vk_pipeline pipeline_gelu_quick[2]; vk_pipeline pipeline_silu[2]; vk_pipeline pipeline_relu[2]; vk_pipeline pipeline_tanh[2]; vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_geglu[2]; + vk_pipeline pipeline_reglu[2]; + vk_pipeline pipeline_swiglu[2]; + vk_pipeline pipeline_geglu_erf[2]; + vk_pipeline pipeline_geglu_quick[2]; + vk_pipeline pipeline_leaky_relu_f32; vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; @@ -461,26 +484,11 @@ struct vk_device_struct { vk_pipeline pipeline_conv2d_dw_cwhn_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} - vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; - vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_split_k_reduce; @@ -493,6 +501,8 @@ struct vk_device_struct { ggml_backend_buffer_type buffer_type; + bool disable_fusion; + #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; #endif @@ -627,6 +637,8 @@ struct vk_flash_attn_push_constants { uint32_t nev2; uint32_t nev3; uint32_t nem1; + uint32_t nem2; + uint32_t nem3; uint32_t nb01; uint32_t nb02; @@ -637,14 +649,12 @@ struct vk_flash_attn_push_constants { uint32_t nb21; uint32_t nb22; uint32_t nb23; - uint32_t nb31; float scale; float max_bias; float logit_softcap; - uint32_t mask; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -652,6 +662,7 @@ struct vk_flash_attn_push_constants { uint32_t split_kv; uint32_t k_num; }; +static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128"); struct vk_op_push_constants { uint32_t KX; @@ -660,6 +671,13 @@ struct vk_op_push_constants { float param2; }; +struct vk_op_glu_push_constants { + uint32_t N; + uint32_t ne00; + uint32_t ne20; + uint32_t mode; // 0: default, 1: swapped, 2: split +}; + struct vk_op_unary_push_constants { uint32_t ne; uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; @@ -743,6 +761,14 @@ struct vk_op_rope_push_constants { struct vk_op_soft_max_push_constants { uint32_t KX; uint32_t KY; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t ne12; + uint32_t ne13; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; float scale; float max_bias; float m0; @@ -978,6 +1004,10 @@ struct ggml_backend_vk_context { vk_command_pool compute_cmd_pool; vk_command_pool transfer_cmd_pool; + + // number of additional consecutive nodes that are being fused with the + // node currently being processed + int num_additional_fused_ops {}; }; static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT @@ -1063,8 +1093,8 @@ static size_t vk_skip_checks; static size_t vk_output_tensor; static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name); -static void ggml_vk_check_results_0(ggml_tensor * tensor); -static void ggml_vk_check_results_1(ggml_tensor * tensor); +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx); #endif typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); @@ -1197,7 +1227,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } { - std::lock_guard guard(device->mutex); + std::lock_guard guard(device->mutex); device->pipelines.insert({ pipeline->name, pipeline }); } @@ -1411,7 +1441,7 @@ static uint32_t ggml_vk_find_queue_family_index(std::vector guard(device->mutex); + std::lock_guard guard(device->mutex); q.queue_family_index = queue_family_index; q.transfer_only = transfer_only; @@ -1673,6 +1703,35 @@ enum FaCodePath { FA_COOPMAT2, }; +static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) { + if (hsk != 192 && hsk != 576 && hsk != hsv) { + return FA_HEAD_SIZE_UNSUPPORTED; + } + switch (hsk) { + case 64: return FA_HEAD_SIZE_64; + case 80: return FA_HEAD_SIZE_80; + case 96: return FA_HEAD_SIZE_96; + case 112: return FA_HEAD_SIZE_112; + case 128: return FA_HEAD_SIZE_128; + case 192: + if (hsv == 192) { + return FA_HEAD_SIZE_192; + } else if (hsv == 128) { + return FA_HEAD_SIZE_192_128; + } else { + return FA_HEAD_SIZE_UNSUPPORTED; + } + case 256: return FA_HEAD_SIZE_256; + case 576: + if (hsv == 512) { + return FA_HEAD_SIZE_576_512; + } else { + return FA_HEAD_SIZE_UNSUPPORTED; + } + default: return FA_HEAD_SIZE_UNSUPPORTED; + } +} + // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; @@ -1693,8 +1752,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) { } } -static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); + GGML_UNUSED(hsv); if (path == FA_SCALAR) { if (small_rows) { @@ -1718,7 +1778,7 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_ } // small cols to reduce register count - if (ggml_is_quantized(type) || D == 256) { + if (ggml_is_quantized(type) || hsk >= 256) { return {64, 32}; } return {64, 64}; @@ -2011,19 +2071,21 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. // For scalar, use 128 (arbitrary) + // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. + const uint32_t D = (hsk|hsv); uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) ? scalar_flash_attention_workgroup_size : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. @@ -2032,26 +2094,29 @@ static void ggml_vk_load_shaders(vk_device& device) { // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); - return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; }; -#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ +#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \ - CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256) + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) @@ -2641,7 +2706,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { @@ -2655,7 +2720,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -2744,6 +2810,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); CREATE_UNARY(gelu) + CREATE_UNARY(gelu_erf) CREATE_UNARY(gelu_quick) CREATE_UNARY(silu) CREATE_UNARY(relu) @@ -2751,6 +2818,17 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(sigmoid) #undef CREATE_UNARY +#define CREATE_GLU(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); + + CREATE_GLU(geglu) + CREATE_GLU(reglu) + CREATE_GLU(swiglu) + CREATE_GLU(geglu_erf) + CREATE_GLU(geglu_quick) +#undef CREATE_GLU + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -3431,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->idx = idx; + device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; + return device; } @@ -3651,7 +3731,6 @@ static void ggml_vk_instance_init() { } - size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan @@ -4124,6 +4203,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) { return nullptr; } + std::lock_guard guard(device->mutex); device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); return buf->ptr; @@ -4134,6 +4214,8 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { return; } VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); + std::lock_guard guard(device->mutex); + vk_buffer buf; size_t index; for (size_t i = 0; i < device->pinned_memory.size(); i++) { @@ -4156,6 +4238,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { } static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { + std::lock_guard guard(device->mutex); buf = nullptr; buf_offset = 0; for (size_t i = 0; i < device->pinned_memory.size(); i++) { @@ -4457,7 +4540,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); } } else { - std::lock_guard guard(dst->device->mutex); + std::lock_guard guard(dst->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); @@ -4548,7 +4631,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ memcpy(dst, (uint8_t *) src->ptr + offset, size); } else { - std::lock_guard guard(src->device->mutex); + std::lock_guard guard(src->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); @@ -4578,7 +4661,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { if (src->device == dst->device) { - std::lock_guard guard(src->device->mutex); + std::lock_guard guard(src->device->mutex); VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); // Copy within the device vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); @@ -4613,7 +4696,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); - std::lock_guard guard(dst->device->mutex); + std::lock_guard guard(dst->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); @@ -4840,9 +4923,17 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const // type size must be exactly 2 or 4. GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4); if ((ggml_type_size(src->type) % 4) == 0) { - return ctx->device->pipeline_contig_cpy_f32_f32; + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } } else { - return ctx->device->pipeline_contig_cpy_f16_f16; + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } } } @@ -4903,7 +4994,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT const uint64_t ne00 = src0->ne[0]; @@ -5131,7 +5222,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << (dryrun ? "dryrun" : "") << "),)"); - GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT const uint64_t ne00 = src0->ne[0]; @@ -5732,7 +5823,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -5926,28 +6017,74 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); } else { - ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + // Split based on number of ids, to fit in shared memory + const uint32_t nei0 = (uint32_t)src2->ne[0]; + const uint32_t nei1 = (uint32_t)src2->ne[1]; + + GGML_ASSERT(nei0 <= 4096); + const uint32_t split_size = std::min(nei1, 4096u / nei0); + + ggml_tensor src1_copy = *src1; + ggml_tensor src2_copy = *src2; + ggml_tensor dst_copy = *dst; + + for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) { + const uint32_t n_tokens = std::min(split_size, nei1 - token_start); + + src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2]; + src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1]; + dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2]; + + src1_copy.ne[2] = n_tokens; + src2_copy.ne[1] = n_tokens; + dst_copy.ne[2] = n_tokens; + + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun); + } } } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; const uint32_t Br = scalar_flash_attention_num_large_rows; const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); + + const uint32_t masksh = Bc * Br * sizeof(float); + + const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { + // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = coopmat1_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpshv4 = wg_size * 4 * acctype; - const uint32_t Qf = Br * (D / 4 + 2) * f16vec4; + const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4; - const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; + const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const uint32_t kshstride = D / 4 + 2; + const uint32_t kshstride = hsk / 4 + 2; const uint32_t ksh = Bc * kshstride * f16vec4; const uint32_t slope = Br * sizeof(float); @@ -5955,7 +6092,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); return supported; } @@ -5977,13 +6114,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const uint32_t nem1 = mask ? mask->ne[1] : 0; - const uint32_t nbm1 = mask ? mask->nb[1] : 0; + const uint32_t nem2 = mask ? mask->ne[2] : 0; + const uint32_t nem3 = mask ? mask->ne[3] : 0; - const uint32_t D = neq0; + const uint32_t HSK = nek0; + const uint32_t HSV = nev0; uint32_t N = neq1; const uint32_t KV = nek1; - GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne0 == HSV); GGML_ASSERT(ne2 == N); // input tensor rows must be contiguous @@ -5991,12 +6130,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(nbk0 == ggml_type_size(k->type)); GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev0 == D); + GGML_ASSERT(neq0 == HSK); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nev0 == D); GGML_ASSERT(nev1 == nek1); @@ -6017,7 +6153,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); - const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32); + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); if (!coopmat_shape_supported || !coopmat_shmem_supported) { path = FA_SCALAR; @@ -6047,7 +6183,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && - qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. @@ -6070,47 +6206,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx path = FA_SCALAR; } + // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory + if (path == FA_SCALAR && + !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { + small_rows = true; + } + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]); + switch (path) { case FA_SCALAR: - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; - default: - GGML_ASSERT(!"unsupported D value"); - return; - } + pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0]; break; case FA_COOPMAT1: - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break; - default: - GGML_ASSERT(!"unsupported D value"); - return; - } + pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0]; break; case FA_COOPMAT2: - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break; - default: - GGML_ASSERT(!"unsupported D value"); - return; - } + pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0]; break; default: GGML_ASSERT(0); @@ -6138,21 +6252,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { + if (workgroups_x == 1 && shader_core_count > 0) { // Try to run two workgroups per SM. - split_k = ctx->device->shader_core_count * 2 / workgroups_y; + split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. - split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align); split_k = CEIL_DIV(KV, split_kv); workgroups_x = split_k; } } - // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) - // and the per-row m and L values (ne1 rows). - const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; if (split_k_size > ctx->device->max_memory_allocation_size) { GGML_ABORT("Requested preallocation size is too large"); } @@ -6239,18 +6353,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } + uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2; + const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, - nem1, + nem1, nem2, nem3, q_stride, (uint32_t)nbq2, (uint32_t)nbq3, k_stride, (uint32_t)nbk2, (uint32_t)nbk3, v_stride, (uint32_t)nbv2, (uint32_t)nbv3, - nbm1, scale, max_bias, logit_softcap, - mask != nullptr, n_head_log2, m0, m1, + mask_n_head_log2, m0, m1, gqa_ratio, split_kv, split_k }; ggml_vk_sync_buffers(subctx); @@ -6271,13 +6386,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); ggml_vk_sync_buffers(subctx); - const std::array pc2 = { D, (uint32_t)ne1, split_k }; + const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, - pc2, { (uint32_t)ne1, 1, 1 }); + pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { @@ -6418,7 +6533,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_RMS_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_rms_norm_f32; + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; } return nullptr; case GGML_OP_RMS_NORM_BACK: @@ -6443,6 +6558,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU_ERF: + return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU_QUICK: return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_RELU: @@ -6455,6 +6572,28 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const break; } return nullptr; + case GGML_OP_GLU: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_GEGLU: + return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_REGLU: + return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU: + return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_ERF: + return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_QUICK: + return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16]; + default: + break; + } + return nullptr; case GGML_OP_DIAG_MASK_INF: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_diag_mask_inf_f32; @@ -6915,6 +7054,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_UNARY: + case GGML_OP_GLU: case GGML_OP_CONV_2D_DW: { uint32_t ne = ggml_nelements(dst); @@ -6955,7 +7095,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - if (op == GGML_OP_SOFT_MAX) { + if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) { // Empty src1 is possible in soft_max, but the shader needs a buffer vk_subbuffer subbuf_y; if (use_src1) { @@ -7518,18 +7658,18 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); } -static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - float * op_params = (float *)dst->op_params; +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + op_params[0], 0.0f, 0, }, dryrun); } @@ -7547,6 +7687,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } +static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const bool swapped = (bool)dst->op_params[1]; + const bool split = src1 != nullptr; + + GGML_ASSERT(ggml_is_contiguous(src0)); + + if (!split) { + GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); + } else { + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(src0->ne[0] == dst->ne[0]); + GGML_ASSERT(src0->type == src1->type); + } + + const uint32_t mode = split ? 2 : (swapped ? 1 : 0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun); +} + static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { int32_t * op_params = (int32_t *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); @@ -7562,7 +7721,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); const uint32_t nrows_y = (uint32_t)src0->ne[1]; - const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u; + const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u; + const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u; + const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u; + const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u; + + const uint32_t n_head_kv = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -7571,6 +7736,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], + ne12, ne13, + nb11, nb12, nb13, scale, max_bias, m0, m1, n_head_log2, @@ -8720,11 +8888,12 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. -static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ + ggml_tensor * node = cgraph->nodes[node_idx]; if (ggml_is_empty(node) || !node->buffer) { return false; } @@ -8749,6 +8918,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: @@ -8758,6 +8928,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + break; + default: + return false; + } + break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_GET_ROWS: @@ -8850,6 +9032,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_UNARY: + case GGML_OP_GLU: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -8962,8 +9145,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod break; case GGML_OP_RMS_NORM: - ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); - + if (ctx->num_additional_fused_ops > 0) { + // fused rms_norm + mul + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0]; + ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun); + } else { + ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun); + } break; case GGML_OP_RMS_NORM_BACK: ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); @@ -8977,6 +9166,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: @@ -8987,6 +9177,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); + break; + default: + return false; + } + break; case GGML_OP_DIAG_MASK_INF: ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); @@ -9108,12 +9311,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready); + bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready); if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; - } - else { + } else if (node->op == GGML_OP_GLU) { + std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } else { std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; } } @@ -9122,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { + GGML_UNUSED(cgraph); ggml_backend_buffer * buf = nullptr; switch (tensor->op) { @@ -9182,6 +9387,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: @@ -9192,6 +9398,19 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + buf = tensor->buffer; + break; + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: case GGML_OP_FLASH_ATTN_EXT: @@ -9218,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * // Only run if ctx hasn't been submitted yet if (!subctx->seqs.empty()) { #ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_0(tensor); + ggml_vk_check_results_0(ctx, cgraph, tensor_idx); use_fence = true; #endif @@ -9238,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * ggml_vk_wait_for_fence(ctx); } #ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_1(tensor); + ggml_vk_check_results_1(ctx, cgraph, tensor_idx); #endif } @@ -9685,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) { return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; } +static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + // additional constraints specific to this fusion + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + mul->src[0]->ne[1] != rms_norm->ne[1]) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + return true; +} + static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; @@ -9698,10 +9948,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); + if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; } if (ctx->device->need_compiles) { ggml_vk_load_shaders(ctx->device); @@ -9763,14 +10018,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); } + if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) || (mul_mat_bytes >= mul_mat_bytes_per_submit) || - (i == last_node) || + (i + ctx->num_additional_fused_ops == last_node) || (almost_ready && !ctx->almost_ready_fence_pending); - bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); + bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit); if (vk_perf_logger_enabled) { if (ctx->compute_ctx.expired()) { @@ -9780,7 +10039,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } else { compute_ctx = ctx->compute_ctx.lock(); } - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1); + // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple + for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) { + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1); + } } if (enqueued) { @@ -9802,6 +10064,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } submit_count++; } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; } if (vk_perf_logger_enabled) { @@ -9963,6 +10227,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: @@ -9976,15 +10241,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); + default: + return false; + } + break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { ggml_type src0_type = op->src[0]->type; ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { - // If there's not enough shared memory for row_ids and the result tile, fallback to CPU - return false; + if (op->op == GGML_OP_MUL_MAT_ID) { + if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { + // If there's not enough shared memory for row_ids and the result tile, fallback to CPU + return false; + } + // Check against size of shared memory variable + if (op->src[2]->ne[0] > 4096) { + return false; + } } switch (src0_type) { case GGML_TYPE_F32: @@ -10042,19 +10328,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; auto device = ggml_vk_get_device(ctx->device); bool coopmat2 = device->coopmat2; - switch (op->src[0]->ne[0]) { - case 64: - case 80: - case 96: - case 112: - case 128: - case 256: - break; - default: - return false; - } - if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - // different head sizes of K and V are not supported yet + FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]); + if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) { return false; } if (op->src[0]->type != GGML_TYPE_F32) { @@ -10134,6 +10409,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } } break; + case GGML_OP_SET_ROWS: + { + // TODO: add support + // ref: https://github.com/ggml-org/llama.cpp/pull/14274 + return false; + } break; case GGML_OP_CONT: case GGML_OP_CPY: case GGML_OP_DUP: @@ -10224,6 +10505,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: + return true; case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ARGSORT: @@ -10513,11 +10795,21 @@ void * comp_result; size_t comp_size; size_t comp_nb[GGML_MAX_DIMS]; size_t check_counter = 0; -static void ggml_vk_check_results_0(ggml_tensor * tensor) { +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; if (tensor->op == GGML_OP_TRANSPOSE) { return; } + bool fused_rms_norm_mul = false; + int rms_norm_idx = -1; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + check_counter++; if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; @@ -10545,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { for (int i = 0; i < 6; i++) { ggml_tensor * srci = tensor->src[i]; + if (fused_rms_norm_mul) { + rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1; + ggml_tensor *rms_norm = tensor->src[rms_norm_idx]; + switch (i) { + case 0: srci = rms_norm->src[0]; break; + case 1: srci = tensor->src[1 - rms_norm_idx]; break; + default: continue; + } + } if (srci == nullptr) { continue; } @@ -10602,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_SUB) { tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_MUL) { - tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + if (fused_rms_norm_mul) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params); + tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]); + } else { + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + } } else if (tensor->op == GGML_OP_DIV) { tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_CONCAT) { @@ -10690,6 +10996,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { case GGML_UNARY_OP_GELU: tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_GELU_ERF: + tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]); + break; case GGML_UNARY_OP_GELU_QUICK: tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); break; @@ -10706,6 +11015,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); } + } else if (tensor->op == GGML_OP_GLU) { + if (src_clone[1] == nullptr) { + tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); + } else { + tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); + } } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { if (src1 == nullptr) { tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); @@ -10784,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { GGML_ABORT("fatal error"); } - ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); - ggml_build_forward_expand(cgraph, tensor_clone); + ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph_cpu, tensor_clone); - ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8); + ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8); if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { ggml_vk_print_tensor(tensor_clone, "tensor_clone"); @@ -10810,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); } -static void ggml_vk_check_results_1(ggml_tensor * tensor) { +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; if (tensor->op == GGML_OP_TRANSPOSE) { return; } + bool fused_rms_norm_mul = false; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index 14e9daaa01..e1f613fb4f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -19,6 +19,10 @@ if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) message(STATUS "Enabling bfloat16 glslc support") endif() +if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + message(STATUS "Enabling shader debug info") +endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ce230a8f7d..45c6e7736a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -11,7 +11,8 @@ #include "types.comp" #include "flash_attn_base.comp" -const uint32_t D_per_thread = D / D_split; +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; const uint32_t cols_per_iter = WorkGroupSize / D_split; const uint32_t cols_per_thread = Bc / cols_per_iter; @@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];}; // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { - uint32_t offset = (iq2 + r) * D + c; + uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); return elem; } @@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize]; shared vec4 tmpshv4[WorkGroupSize]; shared float masksh[Bc][Br]; -shared vec4 Qf[Br][D / 4]; +shared vec4 Qf[Br][HSK / 4]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -53,18 +54,18 @@ void main() { uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; - [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (D / 4); - uint32_t r = (idx + tid) / (D / 4); - if (r < Br && d < D / 4 && + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && i * Br + r < N) { Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; } } barrier(); - vec4 Of[Br][D_per_thread / 4]; - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + vec4 Of[Br][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] = vec4(0.0); } @@ -99,6 +100,10 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { @@ -112,7 +117,7 @@ void main() { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -144,13 +149,13 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); } } barrier(); @@ -191,14 +196,14 @@ void main() { Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] = eMf[r] * Of[r][d]; } } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -255,7 +260,7 @@ void main() { Lf[r] = tmpsh[d_tid]; barrier(); - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { Of[r][d] = eMf * Of[r][d]; tmpshv4[tid] = Of[r][d]; @@ -277,11 +282,11 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); } @@ -289,7 +294,7 @@ void main() { } } - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -305,18 +310,18 @@ void main() { Lfrcp[r] = 1.0 / Lf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] *= Lfrcp[r]; } } - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); } @@ -326,9 +331,9 @@ void main() { } else { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (i * Br + r < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index 61d90e2d8e..7defe72b40 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (constant_id = 0) const uint32_t WorkGroupSize = 128; layout (constant_id = 1) const uint32_t Br = 1; layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; -layout (constant_id = 4) const uint32_t Clamp = 0; -layout (constant_id = 5) const uint32_t D_split = 16; - +layout (constant_id = 3) const uint32_t HSK = 32; +layout (constant_id = 4) const uint32_t HSV = 32; +layout (constant_id = 5) const uint32_t Clamp = 0; +layout (constant_id = 6) const uint32_t D_split = 16; layout (push_constant) uniform parameter { uint32_t N; @@ -24,6 +24,8 @@ layout (push_constant) uniform parameter { uint32_t nev2; uint32_t nev3; uint32_t nem1; + uint32_t nem2; + uint32_t nem3; uint32_t nb01; uint32_t nb02; @@ -34,14 +36,12 @@ layout (push_constant) uniform parameter { uint32_t nb21; uint32_t nb22; uint32_t nb23; - uint32_t nb31; float scale; float max_bias; float logit_softcap; - uint32_t mask; - uint32_t n_head_log2; + uint32_t mask_n_head_log2; float m0; float m1; @@ -50,6 +50,9 @@ layout (push_constant) uniform parameter { uint32_t k_num; } p; +#define MASK_ENABLE_BIT (1<<16) +#define N_LOG2_MASK 0xFFFF + layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) @@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i { const uint32_t h = iq2 + (r % p.gqa_ratio); - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; + + const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); return ACC_TYPE(pow(base, ACC_TYPE(exph))); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index da478be24f..486735fe8b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -13,7 +13,9 @@ #include "types.comp" #include "flash_attn_base.comp" -const uint32_t D_per_thread = D / D_split; +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; + const uint32_t row_split = 4; const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; @@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];}; // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { - uint32_t offset = (iq2 + r) * D + c; + uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); return elem; } @@ -44,14 +46,14 @@ const uint32_t MatBc = 16; shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; -const uint32_t qstride = D / 4 + 2; // in units of f16vec4 +const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; -// Avoid padding for D==256 to make it fit in 48KB shmem. -const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; +// Avoid padding for hsk==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; shared ACC_TYPE sfsh[Bc * sfshstride]; -const uint32_t kshstride = D / 4 + 2; // in units of f16vec4 +const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4 shared f16vec4 ksh[Bc * kshstride]; shared float slope[Br]; @@ -74,18 +76,18 @@ void main() { uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; - [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (D / 4); - uint32_t r = (idx + tid) / (D / 4); - if (r < Br && d < D / 4 && + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && i * Br + r < N) { Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4]; - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] = ACC_TYPEV4(0.0); } @@ -123,14 +125,18 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (D / 4); - uint32_t c = (idx + tid) / (D / 4); - if (c < Bc && d < D / 4) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; @@ -145,14 +151,14 @@ void main() { } barrier(); - // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br - // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // This is written transposed in order to allow for N being 8 if implementations need it coopmat SfMat = coopmat(0); coopmat KMat; coopmat QMat; - for (uint32_t d = 0; d < D / 16; ++d) { + for (uint32_t d = 0; d < HSK / 16; ++d) { coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; @@ -176,12 +182,12 @@ void main() { barrier(); } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); } } barrier(); @@ -202,7 +208,7 @@ void main() { eMf[r] = exp(Moldf - Mf[r]); } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] = float16_t(eMf[r]) * Of[r][d]; } @@ -217,7 +223,7 @@ void main() { Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); Lf[r] += Pf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; @@ -280,7 +286,7 @@ void main() { } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { Of[r][d] = float16_t(eMf[r]) * Of[r][d]; tmpshv4[tid] = Of[r][d]; @@ -300,11 +306,11 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); } @@ -312,7 +318,7 @@ void main() { } } - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -328,18 +334,18 @@ void main() { Lfrcp[r] = 1.0 / Lf[r]; } - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] *= float16_t(Lfrcp[r]); } } - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); } @@ -349,9 +355,9 @@ void main() { } else { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (i * Br + tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 6acf67a03a..274f48fcab 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { - if (r < N && c < D) { - uint32_t offset = (iq2 + r) * D + c; + if (r < N && c < HSV) { + uint32_t offset = (iq2 + r) * HSV + c; data_o[o_offset + offset] = D_TYPE(elem); } return elem; @@ -86,9 +86,9 @@ void main() { tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); #endif - tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); - tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); - tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); + tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); + tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); + tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV); // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) @@ -104,16 +104,16 @@ void main() { tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); - coopmat Q; - coopmat Qf16; + coopmat Q; + coopmat Qf16; uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; - coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK)); - Qf16 = coopmat(Q); + Qf16 = coopmat(Q); Qf16 *= float16_t(p.scale); - coopmat O = coopmat(0); + coopmat O = coopmat(0); coopmat L, M; @@ -130,15 +130,20 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + } + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { coopmat S = coopmat(0); - coopmat K_T; + coopmat K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC); S = coopMatMulAdd(Qf16, K_T, S); if (p.logit_softcap != 0.0f) { @@ -148,14 +153,14 @@ void main() { } } - if (p.mask != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); coopmat mv; - coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); S += slopeMat*coopmat(mv); } @@ -203,42 +208,42 @@ void main() { rowsum = coopmat(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat V; + coopmat V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC); L = eM*L + rowsum; // This is the "diagonal" matrix in the paper, but since we do componentwise // multiply rather than matrix multiply it has the diagonal element smeared // across the row - coopmat eMdiag; + coopmat eMdiag; // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); // multiply with fp16 accumulation, then add to O. - coopmat PV = coopmat(0); + coopmat PV = coopmat(0); PV = coopMatMulAdd(P_A, V, PV); - O = eMdiag * O + coopmat(PV); + O = eMdiag * O + coopmat(PV); } // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); - uint32_t o_offset = D * p.ne1 * split_k_index; + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); return; } - coopmat Ldiag; + coopmat Ldiag; // resize L by using smear/reduce coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); @@ -250,18 +255,18 @@ void main() { O = Ldiag*O; - uint32_t o_offset = iq3*p.ne2*p.ne1; + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; - coopmat O_D = coopmat(O); + coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); } else { tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV); // permute dimensions tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); - coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index a7e3956854..0a17a9df23 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -2,9 +2,9 @@ #extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 32 +layout(constant_id = 0) const uint BLOCK_SIZE = 32; -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {float data_a[];}; layout (binding = 1) writeonly buffer D {float data_d[];}; @@ -12,48 +12,80 @@ layout (binding = 1) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; uint N; + uint ne3; uint k_num; } p; +shared float tmpsh[BLOCK_SIZE]; + void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint iq3 = gl_WorkGroupID.z; uint D = p.D; uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * k_num + n; - uint m_offset = D * N * k_num + N + n; + uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; + uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; uint lm_stride = N * 2; // Compute the max m value for the row float m_max = -1.0/0.0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float m = data_a[m_offset + (k + tid) * lm_stride]; m_max = max(m_max, m); } + // reduce across the workgroup + tmpsh[tid] = m_max; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + m_max = max(m_max, tmpsh[tid + s]); + tmpsh[tid] = m_max; + } + barrier(); + } + m_max = tmpsh[0]; + + barrier(); + // Compute L based on m_max float L = 0; - [[unroll]] for (uint k = 0; k < k_num; ++k) { - float l = data_a[l_offset + k * lm_stride]; - float m = data_a[m_offset + k * lm_stride]; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float l = data_a[l_offset + (k + tid) * lm_stride]; + float m = data_a[m_offset + (k + tid) * lm_stride]; L += exp(m - m_max) * l; } + // reduce across the workgroup + tmpsh[tid] = L; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + L += tmpsh[tid + s]; + tmpsh[tid] = L; + } + barrier(); + } + L = tmpsh[0]; + L = 1.0 / L; + // D dimension is split across workgroups in the y dimension + uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; // Scale and sum the O contributions based on m_max and store the result to memory - for (uint d = tid; d < D; d += BLOCK_SIZE) { + if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * k + D * n + d; + uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } O *= L; - data_d[D * n + d] = O; + data_d[iq3 * D * N + D * n + d] = O; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp new file mode 100644 index 0000000000..f4268ed24f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -0,0 +1,13 @@ +#version 450 + +#include "glu_head.comp" + +const float GELU_COEF_A = 0.044715f; +const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +float op(float a, float b) { + const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a); + return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp new file mode 100644 index 0000000000..cbd4cb36bf --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "glu_head.comp" + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +const float p_erf = 0.3275911f; +const float a1_erf = 0.254829592f; +const float a2_erf = -0.284496736f; +const float a3_erf = 1.421413741f; +const float a4_erf = -1.453152027f; +const float a5_erf = 1.061405429f; + +const float SQRT_2_INV = 0.70710678118654752440084436210484f; + +float op(float a, float b) { + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + return 0.5f * a * (1.0f + erf_approx) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp new file mode 100644 index 0000000000..3a2a6897bf --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp @@ -0,0 +1,11 @@ +#version 450 + +#include "glu_head.comp" + +const float GELU_QUICK_COEF = -1.702f; + +float op(float a, float b) { + return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp new file mode 100644 index 0000000000..5fd5a5e703 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation + // ref: https://www.johndcook.com/blog/python_erf/ + const float p_erf = 0.3275911f; + const float a1_erf = 0.254829592f; + const float a2_erf = -0.284496736f; + const float a3_erf = 1.421413741f; + const float a4_erf = -1.453152027f; + const float a5_erf = 1.061405429f; + + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float a = float(data_a[i]); + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp new file mode 100644 index 0000000000..41a2988907 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp @@ -0,0 +1,15 @@ +#extension GL_EXT_shader_16bit_storage : require + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {A_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter +{ + uint N; + uint ne00; + uint ne20; + uint mode; +} p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp new file mode 100644 index 0000000000..85cf65a9ec --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp @@ -0,0 +1,29 @@ +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.N) { + return; + } + + const uint row = i / p.ne20; + const uint col = i - row * p.ne20; + + if (p.mode == 0) { + // Default + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + } else if (p.mode == 1) { + // Swapped + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + } else { + // Split + const uint idx = row * p.ne00 + col; + + data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 26163b167c..888ce79f6e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -500,10 +500,9 @@ void main() { const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx % 128) / 4; - const int i8 = 2 * int(idx % 4); + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; const float d = float(data_a[ib].d); const uint qh = data_a[ib].qh[ib32]; @@ -512,22 +511,16 @@ void main() { const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; const uint ib16 = ib8 / 2; - const int i8 = 2 * int(idx % 4); const uint16_t[4] scales = data_a[ib].scales; const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; @@ -538,21 +531,17 @@ void main() { const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); - const ivec2 gvec = ivec2( - bitfieldExtract(grid, 2 * (i8), 2), - bitfieldExtract(grid, 2 * (i8 + 1), 2) - ); - const vec2 v = dl * (vec2(gvec) + delta); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + [[unroll]] for (int k = 0; k < 8; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta)); + } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[8 * ib32 + ib8]; @@ -562,63 +551,81 @@ void main() { data_a[ib].qs[8*ib32 + 6], data_a[ib].qs[8*ib32 + 7] )); - const float db = d * 0.25 * (0.5 + (signs >> 28)); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint ib8 = (idx / 4) % 4; // 0..3 + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 const float d = float(data_a[ib].d); const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; - const float db = d * 0.25 * (0.5 + scale); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); const uint qs = data_a[ib].qs[4 * ib32 + ib8]; const uint sign7 = qs >> 9; - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint ib8 = (idx % 128) / 4; // 0..31 - const uint ib32 = ib8 / 4; // 0..7 + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; const uint qs = data_a[ib].qs[ib8]; const uint qh = data_a[ib].qh[ib32]; const uint qhshift = 2 * (ib8 % 4); - const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; const float d = float(data_a[ib].d); - const float db = d * 0.25 * (0.5 + scale); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x); + buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z); + buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x); + buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z); + buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values const float d = float(data_a[ib].d); @@ -631,33 +638,36 @@ void main() { )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); - const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); - const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); - const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = (idx % 128) / 2; // 0..63 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 const uint iqh = iqs / 8; const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[iqs]; const uint qh = data_a[ib].qh[iqh]; - const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); const uint scale = data_a[ib].scales[iqs / 16]; const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); - const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp new file mode 100644 index 0000000000..0073d8f766 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + return max(a, 0.0f) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index deb8ee9960..6428ca7ba3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,11 +1,13 @@ #version 450 -#include "generic_unary_head.comp" +#include "generic_binary_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 +layout (constant_id = 1) const bool do_multiply = false; + layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; shared FLOAT_TYPE sum[BLOCK_SIZE]; @@ -25,6 +27,7 @@ void main() { const uint stride_sample = p.nb03; uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp @@ -46,7 +49,13 @@ void main() { const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { - data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + if (do_multiply) { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } else { + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 4f5b1a0eca..5808710ccf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -14,21 +14,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; const int sec_w = p.sections[1] + p.sections[0]; const uint sector = (i0 / 2) % sect_dims; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index db775c456c..366a7b1c47 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -13,21 +13,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0/2; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 4ad35e549d..9643bca96a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -13,21 +13,19 @@ void main() { const uint row_dst = gl_GlobalInvocationID.x; - if (i0 >= p.n_dims) { - const uint i = row_dst*ne0 + i0; - - data_d[i + 0] = data_a[i + 0]; - data_d[i + 1] = data_a[i + 1]; - - return; - } - const uint row_x = row_dst % ne1; const uint channel_x = row_dst / ne1; const uint idst = row_dst*ne0 + i0; const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; + if (i0 >= p.n_dims) { + data_d[idst + 0] = data_a[ix + 0]; + data_d[idst + 1] = data_a[ix + 1]; + + return; + } + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index 51fc2dc7ed..5bcd3b1e3d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -6,6 +6,14 @@ layout (push_constant) uniform parameter { uint KX; uint KY; + uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; float scale; float max_bias; float m0; @@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE]; void soft_max(uint num_iters) { const uint tid = gl_LocalInvocationID.x; const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } if (rowx >= p.nrows_x) { return; @@ -41,7 +57,7 @@ void soft_max(uint num_iters) { // ALiBi if (p.max_bias > 0.0f) { - const uint h = rowx/p.KY; // head index + const uint h = (rowx / p.ne01) % p.ne02; // head index const float base = h < p.n_head_log2 ? p.m0 : p.m1; const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; @@ -67,7 +83,7 @@ void soft_max(uint num_iters) { FLOAT_TYPE b = FLOAT_TYPE(0); if (p.KY > 0 && col < p.KX) { - b = data_b[rowy * p.KX + col]; + b = data_b[rowy_start + col]; } FLOAT_TYPE v = a * p.scale + slope * b; @@ -111,7 +127,7 @@ void soft_max(uint num_iters) { if (idx < DATA_CACHE_SIZE) { val = exp(data_cache[idx] - max_val); } else { - val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); } sum += val; if (idx < DATA_CACHE_SIZE) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp new file mode 100644 index 0000000000..a28e7c6cc8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.comp" + +float op(float a, float b) { + return a / (1.0f + exp(-a)) * b; +} + +#include "glu_main.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index c63345ec8b..30f78fabb3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl")) load_vec_quant = "4"; if (tname == "bf16") { @@ -497,7 +497,7 @@ void process_shaders() { // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -574,6 +574,8 @@ void process_shaders() { string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -585,6 +587,17 @@ void process_shaders() { string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e77d33fc7a..5ae1c527df 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -202,19 +202,34 @@ void ggml_print_backtrace(void) { } #endif +static ggml_abort_callback_t g_abort_callback = NULL; + +// Set the abort callback (passing null will restore original abort functionality: printing a message to stdout) +GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback) { + ggml_abort_callback_t ret_val = g_abort_callback; + g_abort_callback = callback; + return ret_val; +} + void ggml_abort(const char * file, int line, const char * fmt, ...) { fflush(stdout); - fprintf(stderr, "%s:%d: ", file, line); + char message[2048]; + int offset = snprintf(message, sizeof(message), "%s:%d: ", file, line); va_list args; va_start(args, fmt); - vfprintf(stderr, fmt, args); + vsnprintf(message + offset, sizeof(message) - offset, fmt, args); va_end(args); - fprintf(stderr, "\n"); + if (g_abort_callback) { + g_abort_callback(message); + } else { + // default: print error and backtrace to stderr + fprintf(stderr, "%s\n", message); + ggml_print_backtrace(); + } - ggml_print_backtrace(); abort(); } @@ -458,6 +473,14 @@ bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) { return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0; } +const char * ggml_version(void) { + return GGML_VERSION; +} + +const char * ggml_commit(void) { + return GGML_COMMIT; +} + // // timing // @@ -933,6 +956,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TRANSPOSE", "GET_ROWS", "GET_ROWS_BACK", + "SET_ROWS", "DIAG", "DIAG_MASK_INF", "DIAG_MASK_ZERO", @@ -944,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CONV_TRANSPOSE_1D", "IM2COL", "IM2COL_BACK", + "CONV_2D", "CONV_2D_DW", "CONV_TRANSPOSE_2D", "POOL_1D", @@ -981,9 +1006,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", "OPT_STEP_ADAMW", + + "GLU", }; -static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); +static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1029,6 +1056,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "transpose(x)", "get_rows(x)", "get_rows_back(x)", + "set_rows(x)", "diag(x)", "diag_mask_inf(x)", "diag_mask_zero(x)", @@ -1040,6 +1068,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "conv_transpose_1d(x)", "im2col(x)", "im2col_back(x)", + "conv_2d(x)", "conv_2d_dw(x)", "conv_transpose_2d(x)", "pool_1d(x)", @@ -1077,9 +1106,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", "adamw(x)", + + "glu(x)", }; -static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); +static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1105,6 +1136,17 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); +static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { + "REGLU", + "GEGLU", + "SWIGLU", + "GEGLU_ERF", + "GEGLU_QUICK", +}; + +static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5"); + + static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -1207,11 +1249,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) { return GGML_UNARY_OP_NAME[op]; } +const char * ggml_glu_op_name(enum ggml_glu_op op) { + return GGML_GLU_OP_NAME[op]; +} + const char * ggml_op_desc(const struct ggml_tensor * t) { if (t->op == GGML_OP_UNARY) { enum ggml_unary_op uop = ggml_get_unary_op(t); return ggml_unary_op_name(uop); } + if (t->op == GGML_OP_GLU) { + enum ggml_glu_op gop = ggml_get_glu_op(t); + return ggml_glu_op_name(gop); + } return ggml_op_name(t->op); } @@ -1348,6 +1398,12 @@ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) { tensor->nb[2] == ggml_type_size(tensor->type); } +bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) { + return + tensor->ne[0] == ggml_blck_size(tensor->type) || + tensor->nb[0] == ggml_type_size(tensor->type); +} + static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -1722,6 +1778,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); } +enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_GLU); + return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0); +} + const char * ggml_get_name(const struct ggml_tensor * tensor) { return tensor->name; } @@ -2601,6 +2662,156 @@ struct ggml_tensor * ggml_exp_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP); } +// ggml_glu + +static struct ggml_tensor * ggml_glu_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op, + bool swapped) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + + if (b) { + GGML_ASSERT(ggml_is_contiguous_1(b)); + GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(a->type == b->type); + } + + int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i]; + struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0); + + ggml_set_op_params_i32(result, 0, (int32_t) op); + ggml_set_op_params_i32(result, 1, (int32_t) swapped); + + result->op = GGML_OP_GLU; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_glu( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_glu_op op, + bool swapped) { + return ggml_glu_impl(ctx, a, NULL, op, swapped); +} + +struct ggml_tensor * ggml_glu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_glu_op op) { + return ggml_glu_impl(ctx, a, b, op, false); +} + +// ggml_reglu + +struct ggml_tensor * ggml_reglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false); +} + +struct ggml_tensor * ggml_reglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true); +} + +struct ggml_tensor * ggml_reglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false); +} + +// ggml_geglu + +struct ggml_tensor * ggml_geglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false); +} + +struct ggml_tensor * ggml_geglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true); +} + +struct ggml_tensor * ggml_geglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false); +} + +// ggml_swiglu + +struct ggml_tensor * ggml_swiglu( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false); +} + +struct ggml_tensor * ggml_swiglu_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true); +} + +struct ggml_tensor * ggml_swiglu_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false); +} + +// ggml_geglu_erf + +struct ggml_tensor * ggml_geglu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false); +} + +struct ggml_tensor * ggml_geglu_erf_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true); +} + +struct ggml_tensor * ggml_geglu_erf_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false); +} + +// ggml_geglu_quick + +struct ggml_tensor * ggml_geglu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false); +} + +struct ggml_tensor * ggml_geglu_quick_swapped( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true); +} + +struct ggml_tensor * ggml_geglu_quick_split( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false); +} + // ggml_norm static struct ggml_tensor * ggml_norm_impl( @@ -3402,6 +3613,35 @@ struct ggml_tensor * ggml_get_rows_back( return result; } +// ggml_set_rows + +struct ggml_tensor * ggml_set_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(a->ne[0] == b->ne[0]); + GGML_ASSERT(a->ne[2] == b->ne[2]); + GGML_ASSERT(a->ne[3] == b->ne[3]); + GGML_ASSERT(b->ne[1] == c->ne[0]); + GGML_ASSERT(b->ne[2] % c->ne[1] == 0); + GGML_ASSERT(b->ne[3] % c->ne[2] == 0); + GGML_ASSERT(c->ne[3] == 1); + GGML_ASSERT(b->type == GGML_TYPE_F32); + GGML_ASSERT(c->type == GGML_TYPE_I64); + + GGML_ASSERT(ggml_is_contiguous_rows(a)); + GGML_ASSERT(ggml_is_contiguous_rows(b)); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_SET_ROWS; + result->src[0] = b; + result->src[1] = c; + + return result; +} + // ggml_diag struct ggml_tensor * ggml_diag( @@ -3496,9 +3736,10 @@ static struct ggml_tensor * ggml_soft_max_impl( if (mask) { GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(ggml_is_matrix(mask)); GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[1] >= a->ne[1]); + GGML_ASSERT(a->ne[2]%mask->ne[2] == 0); + GGML_ASSERT(a->ne[3]%mask->ne[3] == 0); } if (max_bias > 0.0f) { @@ -4138,6 +4379,44 @@ struct ggml_tensor * ggml_conv_2d_dw_direct( return result; } +// ggml_conv_2d_direct + +struct ggml_tensor * ggml_conv_2d_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC] + struct ggml_tensor * b, // input data [W, H, C, N] + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1) {// dilation dimension 1 + + GGML_ASSERT(a->ne[2] == b->ne[2]); + //GGML_ASSERT(a->type == b->type); + + int64_t ne[4]; + ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); + ne[2] = a->ne[3]; + ne[3] = b->ne[3]; + + struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); + + ggml_set_op_params_i32(result, 0, s0); + ggml_set_op_params_i32(result, 1, s1); + ggml_set_op_params_i32(result, 2, p0); + ggml_set_op_params_i32(result, 3, p1); + ggml_set_op_params_i32(result, 4, d0); + ggml_set_op_params_i32(result, 5, d1); + + result->op = GGML_OP_CONV_2D; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_conv_transpose_2d_p0 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { @@ -4254,24 +4533,21 @@ struct ggml_tensor * ggml_pool_2d_back( return result; } -// ggml_upscale +// ggml_upscale / ggml_interpolate -static struct ggml_tensor * ggml_upscale_impl( +static struct ggml_tensor * ggml_interpolate_impl( struct ggml_context * ctx, struct ggml_tensor * a, - int ne0, - int ne1, - int ne2, - int ne3, - enum ggml_scale_mode mode) { - GGML_ASSERT(a->ne[0] <= ne0); - GGML_ASSERT(a->ne[1] <= ne1); - GGML_ASSERT(a->ne[2] <= ne2); - GGML_ASSERT(a->ne[3] <= ne3); + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + uint32_t mode) { + GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); - ggml_set_op_params_i32(result, 0, mode); + ggml_set_op_params_i32(result, 0, (int32_t)mode); result->op = GGML_OP_UPSCALE; result->src[0] = a; @@ -4284,7 +4560,8 @@ struct ggml_tensor * ggml_upscale( struct ggml_tensor * a, int scale_factor, enum ggml_scale_mode mode) { - return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode); + GGML_ASSERT(scale_factor > 1); + return ggml_interpolate_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode); } struct ggml_tensor * ggml_upscale_ext( @@ -4295,7 +4572,18 @@ struct ggml_tensor * ggml_upscale_ext( int ne2, int ne3, enum ggml_scale_mode mode) { - return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode); + return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode); +} + +struct ggml_tensor * ggml_interpolate( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + uint32_t mode) { + return ggml_interpolate_impl(ctx, a, ne0, ne1, ne2, ne3, mode); } // ggml_pad @@ -4472,13 +4760,17 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) + GGML_ASSERT(q->ne[3] == k->ne[3]); + GGML_ASSERT(q->ne[3] == v->ne[3]); + if (mask) { GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(mask->ne[2] == 1); - GGML_ASSERT(mask->ne[3] == 1); GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + + GGML_ASSERT(q->ne[2] % mask->ne[2] == 0); + GGML_ASSERT(q->ne[3] % mask->ne[3] == 0); } if (max_bias > 0.0f) { @@ -4606,7 +4898,6 @@ struct ggml_tensor * ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? - // FIXME: this is always true? GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); @@ -4629,36 +4920,49 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); - GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(ggml_is_matrix(A)); - GGML_ASSERT(ggml_is_3d(B)); - GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(x->nb[0] == ggml_type_size(x->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); - GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]); + GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); + GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); { const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_seq_tokens = x->ne[1]; - const int64_t n_seqs = x->ne[2]; + const int64_t head_dim = x->ne[0]; + const int64_t n_head = x->ne[1]; + const int64_t n_seq_tokens = x->ne[2]; + const int64_t n_seqs = x->ne[3]; - GGML_ASSERT(s->ne[2] == n_seqs); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == d_inner); + GGML_ASSERT(dt->ne[0] == n_head); + GGML_ASSERT(dt->ne[1] == n_seq_tokens); + GGML_ASSERT(dt->ne[2] == n_seqs); + GGML_ASSERT(ggml_is_3d(dt)); + GGML_ASSERT(s->ne[1] == head_dim); + GGML_ASSERT(s->ne[2] == n_head); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_seq_tokens); - GGML_ASSERT(B->ne[2] == n_seqs); + GGML_ASSERT(B->ne[2] == n_seq_tokens); + GGML_ASSERT(B->ne[3] == n_seqs); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); + + if (A->ne[0] != 1) { + // Mamba-1 has more granular decay factors + GGML_ASSERT(A->ne[0] == d_state); + } } // concatenated y + ssm_states - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]); result->op = GGML_OP_SSM_SCAN; result->src[0] = s; @@ -4667,6 +4971,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = ids; return result; } @@ -5807,13 +6112,28 @@ static void ggml_compute_backward( } GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); } break; + case GGML_OP_GLU: { + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_SWIGLU: { + if (src0_needs_grads) { + GGML_ASSERT(src1 && "backward pass only implemented for split swiglu"); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0)); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad)); + } + } break; + default: { + GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor))); + } //break; + } + } break; case GGML_OP_NONE: { // noop } break; case GGML_OP_COUNT: default: { - fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); - GGML_ABORT("fatal error"); + GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); } //break; } @@ -5822,19 +6142,32 @@ static void ggml_compute_backward( GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } -static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { +static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { // check if already visited - if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) { - return; + size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); + GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL); + if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { + // This is the first time we see this node in the current graph. + cgraph->visited_hash_set.keys[node_hash_pos] = node; + ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); + cgraph->use_counts[node_hash_pos] = 0; + } else { + // already visited + return node_hash_pos; } for (int i = 0; i < GGML_MAX_SRC; ++i) { const int k = (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) : - /* unknown order, just fall back to using i*/ i; - if (node->src[k]) { - ggml_visit_parents(cgraph, node->src[k]); + /* unknown order, just fall back to using i */ i; + + struct ggml_tensor * src = node->src[k]; + if (src) { + size_t src_hash_pos = ggml_visit_parents(cgraph, src); + + // Update the use count for this operand. + cgraph->use_counts[src_hash_pos]++; } } @@ -5858,6 +6191,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * cgraph->nodes[cgraph->n_nodes] = node; cgraph->n_nodes++; } + + return node_hash_pos; } static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { @@ -5995,6 +6330,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) { incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1); incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs + incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys if (grads) { incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads @@ -6024,11 +6360,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz void * p = cgraph + 1; - struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; - struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); + struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); @@ -6043,6 +6380,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.grads =*/ grads_ptr, /*.grad_accs =*/ grad_accs_ptr, /*.leafs =*/ leafs_ptr, + /*.use_counts =*/ use_counts_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, }; @@ -6069,7 +6407,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.grads =*/ NULL, // gradients would need visited_hash_set /*.grad_accs =*/ NULL, /*.leafs =*/ NULL, - /*.visited_hash_set =*/ { 0, NULL, NULL }, + /*.use_counts =*/ cgraph0->use_counts, + /*.visited_hash_set =*/ cgraph0->visited_hash_set, /*.order =*/ cgraph0->order, }; @@ -6096,7 +6435,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { for (size_t i = 0; i < src->visited_hash_set.size; ++i) { // copy all hashset keys (tensors) that are in use if (ggml_bitset_get(src->visited_hash_set.used, i)) { - ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); + size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); + dst->use_counts[new_hash_pos] = src->use_counts[i]; } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index fb75143b0b..e938f8fa66 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -170,6 +170,7 @@ class Keys: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class WKV: @@ -327,6 +328,7 @@ class MODEL_ARCH(IntEnum): RWKV7 = auto() ARWKV7 = auto() MAMBA = auto() + MAMBA2 = auto() XVERSE = auto() COMMAND_R = auto() COHERE2 = auto() @@ -354,6 +356,9 @@ class MODEL_ARCH(IntEnum): BAILINGMOE = auto() DOTS1 = auto() ARCEE = auto() + ERNIE4_5 = auto() + HUNYUAN_MOE = auto() + SMOLLM3 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -428,6 +433,7 @@ class MODEL_TENSOR(IntEnum): SSM_DT = auto() SSM_A = auto() SSM_D = auto() + SSM_NORM = auto() SSM_OUT = auto() TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() @@ -627,6 +633,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.RWKV7: "rwkv7", MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COHERE2: "cohere2", @@ -654,6 +661,9 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.BAILINGMOE: "bailingmoe", MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.ARCEE: "arcee", + MODEL_ARCH.ERNIE4_5: "ernie4_5", + MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", + MODEL_ARCH.SMOLLM3: "smollm3", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -728,6 +738,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", @@ -1712,6 +1723,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.MAMBA2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2177,6 +2201,57 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.ERNIE4_5: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.HUNYUAN_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], + MODEL_ARCH.SMOLLM3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } @@ -2481,6 +2556,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d32cd479ad..a7ecf3d312 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -714,8 +714,8 @@ class GGUFWriter: def add_clamp_kqv(self, value: float) -> None: self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value) - def add_shared_kv_layers(self, value: float) -> None: - self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value) + def add_shared_kv_layers(self, value: int) -> None: + self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value) def add_sliding_window_pattern(self, value: Sequence[bool]) -> None: self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value) @@ -861,6 +861,9 @@ class GGUFWriter: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_group_count(self, value: int) -> None: + self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index b30f77dbe3..7c2877f56c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -303,6 +303,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe "model.layers.{bid}.feed_forward.router", # llama4 "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe + "model.layers.{bid}.mlp.gate.wg", # hunyuan ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -362,6 +363,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 + "model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan ), # AWQ-activation gate @@ -398,6 +400,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 + "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan ), # Feed-forward down @@ -447,11 +450,13 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 "model.layers.{bid}.shared_mlp.output_linear", # granitemoe + "model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan ), MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", "model.layers.{bid}.self_attn.q_layernorm", # persimmon + "model.layers.{bid}.self_attn.query_layernorm", # hunyuan "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 @@ -461,6 +466,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", "model.layers.{bid}.self_attn.k_layernorm", # persimmon + "model.layers.{bid}.self_attn.key_layernorm", # hunyuan "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 @@ -477,7 +483,7 @@ class TensorNameMap: "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 - "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "encoder.layer.{bid}.layer_norm_2", # jina-v2-code ), MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: ( @@ -574,6 +580,10 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.D", ), + MODEL_TENSOR.SSM_NORM: ( + "backbone.layers.{bid}.mixer.norm", # mamba2 + ), + MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 3f541b0c02..635fcef35e 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -245,9 +245,18 @@ class SpecialVocab: if not tokenizer_config: return True chat_template_alt = None - chat_template_file = path / 'chat_template.json' - if chat_template_file.is_file(): - with open(chat_template_file, encoding = 'utf-8') as f: + chat_template_json = path / 'chat_template.json' + chat_template_jinja = path / 'chat_template.jinja' + if chat_template_jinja.is_file(): + with open(chat_template_jinja, encoding = 'utf-8') as f: + chat_template_alt = f.read() + if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')): + chat_template_alt = [{'name': 'default', 'template': chat_template_alt}] + for template_path in additional_templates: + with open(template_path, encoding = 'utf-8') as fp: + chat_template_alt.append({'name': template_path.stem, 'template': fp.read()}) + elif chat_template_json.is_file(): + with open(chat_template_json, encoding = 'utf-8') as f: chat_template_alt = json.load(f).get('chat_template') chat_template = tokenizer_config.get('chat_template', chat_template_alt) if chat_template is None or isinstance(chat_template, (str, list)): diff --git a/include/llama.h b/include/llama.h index 3eda9bc686..dc86aea41d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -117,6 +117,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, }; enum llama_rope_type { diff --git a/scripts/apple/validate-apps.sh b/scripts/apple/validate-apps.sh index a571aa6fcf..f0475758c3 100755 --- a/scripts/apple/validate-apps.sh +++ b/scripts/apple/validate-apps.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash ./scripts/apple/validate-ios.sh ./scripts/apple/validate-macos.sh ./scripts/apple/validate-visionos.sh diff --git a/scripts/apple/validate-ios.sh b/scripts/apple/validate-ios.sh index 7bda1b9729..50800d84a0 100755 --- a/scripts/apple/validate-ios.sh +++ b/scripts/apple/validate-ios.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # validate-ios.sh - Validate iOS Application with embedded llama.xcframework using SwiftUI # Authentication options (optional) (can be set via environment variables) diff --git a/scripts/apple/validate-macos.sh b/scripts/apple/validate-macos.sh index 6dc28e6949..fa800ee682 100755 --- a/scripts/apple/validate-macos.sh +++ b/scripts/apple/validate-macos.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # validate-macos.sh - Validate macOS Application with embedded llama.xcframework using SwiftUI # Authentication options (optional) (can be set via environment variables) diff --git a/scripts/apple/validate-tvos.sh b/scripts/apple/validate-tvos.sh index 6120189e84..b4da698749 100755 --- a/scripts/apple/validate-tvos.sh +++ b/scripts/apple/validate-tvos.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # validate-tvos.sh - Validate tvOS Application with embedded llama.xcframework using SwiftUI # Authentication options (optional) (can be set via environment variables) diff --git a/scripts/apple/validate-visionos.sh b/scripts/apple/validate-visionos.sh index a18ddcce4a..bbdec66026 100755 --- a/scripts/apple/validate-visionos.sh +++ b/scripts/apple/validate-visionos.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # validate-visionos.sh - Validate visionOS Application with embedded llama.xcframework using SwiftUI # Authentication options (optional) (can be set via environment variables) diff --git a/scripts/check-requirements.sh b/scripts/check-requirements.sh index 4c3b05f68b..da2357d76c 100755 --- a/scripts/check-requirements.sh +++ b/scripts/check-requirements.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -euo pipefail # diff --git a/scripts/ci-run.sh b/scripts/ci-run.sh index 06b5d9c6e5..5877a7edab 100755 --- a/scripts/ci-run.sh +++ b/scripts/ci-run.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -euo pipefail this=$(realpath "$0"); readonly this cd "$(dirname "$this")" diff --git a/scripts/compare-commits.sh b/scripts/compare-commits.sh index 94a8eceb30..051a7a0983 100755 --- a/scripts/compare-commits.sh +++ b/scripts/compare-commits.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash if [ $# -lt 2 ]; then echo "usage: ./scripts/compare-commits.sh [additional llama-bench arguments]" diff --git a/scripts/debug-test.sh b/scripts/debug-test.sh index c6c1e988a0..7e9e8421b0 100755 --- a/scripts/debug-test.sh +++ b/scripts/debug-test.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash PROG=${0##*/} build_dir="build-ci-debug" diff --git a/scripts/gen-authors.sh b/scripts/gen-authors.sh index 3ef8391cc9..73e7b386f9 100755 --- a/scripts/gen-authors.sh +++ b/scripts/gen-authors.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash printf "# date: $(date)\n" > AUTHORS printf "# this file is auto-generated by scripts/gen-authors.sh\n\n" >> AUTHORS diff --git a/scripts/get-hellaswag.sh b/scripts/get-hellaswag.sh index 4e1b1cc15f..484e56fd8f 100755 --- a/scripts/get-hellaswag.sh +++ b/scripts/get-hellaswag.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash wget https://raw.githubusercontent.com/klosax/hellaswag_text_data/main/hellaswag_val_full.txt diff --git a/scripts/get-pg.sh b/scripts/get-pg.sh index b027793e19..f180bf8340 100755 --- a/scripts/get-pg.sh +++ b/scripts/get-pg.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash function usage { echo "usage: $0" diff --git a/scripts/get-wikitext-103.sh b/scripts/get-wikitext-103.sh index 9c65fafbcc..244a371bad 100755 --- a/scripts/get-wikitext-103.sh +++ b/scripts/get-wikitext-103.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip diff --git a/scripts/get-wikitext-2.sh b/scripts/get-wikitext-2.sh index 5f3845ef59..67b0b0118b 100755 --- a/scripts/get-wikitext-2.sh +++ b/scripts/get-wikitext-2.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip unzip wikitext-2-raw-v1.zip diff --git a/scripts/get-winogrande.sh b/scripts/get-winogrande.sh index f1fc0e2d47..2b48b11756 100755 --- a/scripts/get-winogrande.sh +++ b/scripts/get-winogrande.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash wget https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp/raw/main/winogrande-debiased-eval.csv diff --git a/scripts/hf.sh b/scripts/hf.sh index b251925fa4..e41b9053af 100755 --- a/scripts/hf.sh +++ b/scripts/hf.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Shortcut for downloading HF models # diff --git a/scripts/qnt-all.sh b/scripts/qnt-all.sh index bc43738a2f..dc04670dff 100755 --- a/scripts/qnt-all.sh +++ b/scripts/qnt-all.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash qnt=(q8_0 q6_k q5_k q5_1 q5_0 q4_k q4_1 q4_0 q3_k q2_k) args="" diff --git a/scripts/run-all-perf.sh b/scripts/run-all-perf.sh index 6384e364d5..b7de764ff8 100755 --- a/scripts/run-all-perf.sh +++ b/scripts/run-all-perf.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash qnt=(f16 q8_0 q6_k q5_k q5_1 q5_0 q4_k q4_1 q4_0 q3_k q2_k) args="-ngl 999 -n 64 -p 512" diff --git a/scripts/run-all-ppl.sh b/scripts/run-all-ppl.sh index e15f74f1b6..918ecda279 100755 --- a/scripts/run-all-ppl.sh +++ b/scripts/run-all-ppl.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash qnt=(f16 q8_0 q6_k q5_k q5_1 q5_0 q4_k q4_1 q4_0 q3_k q2_k) args="-ngl 999 -t 8" diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index 204354209f..29d30e0a18 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Synchronize ggml changes to llama.cpp # @@ -83,7 +83,6 @@ while read c; do src/ggml-cpu/* \ src/ggml-cuda/* \ src/ggml-hip/* \ - src/ggml-kompute/* \ src/ggml-metal/* \ src/ggml-musa/* \ src/ggml-opencl/* \ @@ -141,7 +140,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # src/ggml-cpu/* -> ggml/src/ggml-cpu/* # src/ggml-cuda/* -> ggml/src/ggml-cuda/* # src/ggml-hip/* -> ggml/src/ggml-hip/* - # src/ggml-kompute/* -> ggml/src/ggml-kompute/* # src/ggml-metal/* -> ggml/src/ggml-metal/* # src/ggml-musa/* -> ggml/src/ggml-musa/* # src/ggml-opencl/* -> ggml/src/ggml-opencl/* @@ -174,7 +172,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then -e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \ - -e 's/([[:space:]]| [ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \ diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index bd1e04ed07..4157e1f53c 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -9e4bee1c5afc2d677a5b32ecb90cbdb483e81fff +0405219965324e11a29b6aadfe22a6d66131978f diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index aa1a46b4bf..9b98329e09 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt @@ -15,7 +15,6 @@ cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/ cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/ cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/ cp -rpv ../ggml/src/ggml-hip/* ./ggml/src/ggml-hip/ -cp -rpv ../ggml/src/ggml-kompute/* ./ggml/src/ggml-kompute/ cp -rpv ../ggml/src/ggml-metal/* ./ggml/src/ggml-metal/ cp -rpv ../ggml/src/ggml-musa/* ./ggml/src/ggml-musa/ cp -rpv ../ggml/src/ggml-opencl/* ./ggml/src/ggml-opencl/ diff --git a/scripts/tool_bench.sh b/scripts/tool_bench.sh index 6c7616a88f..05b41d2f1f 100755 --- a/scripts/tool_bench.sh +++ b/scripts/tool_bench.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -euo pipefail cmake --build build -j diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 435e3b9ba3..9af9c2ad60 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -45,6 +45,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COHERE2, "cohere2" }, @@ -76,6 +77,9 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_BAILINGMOE, "bailingmoe" }, { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_ARCEE, "arcee" }, + { LLM_ARCH_ERNIE4_5, "ernie4_5" }, + { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, + { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -169,6 +173,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, @@ -1003,6 +1008,22 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_MAMBA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -1658,12 +1679,69 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, } }, + { + LLM_ARCH_ERNIE4_5, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_HUNYUAN_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, }, }, + { + LLM_ARCH_SMOLLM3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, }; static const std::map LLM_TENSOR_INFOS = { @@ -1743,6 +1821,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1876,6 +1955,7 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { bool llm_arch_is_recurrent(const llm_arch & arch) { switch (arch) { case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV7: diff --git a/src/llama-arch.h b/src/llama-arch.h index 9181ad053f..ba5d03fa24 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -49,6 +49,7 @@ enum llm_arch { LLM_ARCH_GEMMA3N, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_MAMBA2, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_COHERE2, @@ -80,6 +81,9 @@ enum llm_arch { LLM_ARCH_BAILINGMOE, LLM_ARCH_DOTS1, LLM_ARCH_ARCEE, + LLM_ARCH_ERNIE4_5, + LLM_ARCH_HUNYUAN_MOE, + LLM_ARCH_SMOLLM3, LLM_ARCH_UNKNOWN, }; @@ -173,6 +177,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, LLM_KV_WKV_HEAD_SIZE, @@ -292,6 +297,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 91b1d6078a..3bc8554e51 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -166,6 +166,8 @@ bool llama_batch_allocr::init( // note: tracking the other way around is not necessary for now //seq_cpl[s0][s1] = true; + + has_cpl = true; } } } @@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const { return n_outputs; } +uint32_t llama_batch_allocr::get_n_used() const { + return n_used; +} + std::vector & llama_batch_allocr::get_out_ids() { return out_ids; } @@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const { void llama_batch_allocr::split_reset() { out_ids.clear(); + n_used = 0; + used.clear(); used.resize(get_n_tokens(), false); @@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; ++cur_idx; @@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { return ubatch_add(idxs, idxs.size(), false); } -llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { +llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) { + if (sequential && has_cpl) { + LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__); + + return {}; + } + std::vector cur_seq_set; + llama_seq_id last_seq_id = -1; + // determine the non-overlapping sequence sets participating in this ubatch for (int32_t i = 0; i < batch.n_tokens; ++i) { if (used[i]) { @@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { } } + // accept only increasing sequence ids + if (sequential) { + add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1); + } + if (add) { cur_seq_set.push_back(seq_set[i]); + last_seq_id = batch.seq_id[i][0]; + if (cur_seq_set.size() > n_ubatch) { break; } @@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { idxs_per_seq[s].push_back(idx); used[idx] = true; + ++n_used; ++cur_idx[s]; } @@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; if (idxs.size() >= n_ubatch) { break; diff --git a/src/llama-batch.h b/src/llama-batch.h index d2c5376188..3420803ff9 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -54,6 +54,7 @@ public: uint32_t get_n_tokens() const; uint32_t get_n_outputs() const; + uint32_t get_n_used() const; // the array of output indices in the order they were encountered during the ubatch splitting std::vector & get_out_ids(); @@ -69,7 +70,8 @@ public: llama_ubatch split_simple(uint32_t n_ubatch); // make ubatches of equal-length sequences sets - llama_ubatch split_equal(uint32_t n_ubatch); + // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids + llama_ubatch split_equal(uint32_t n_ubatch, bool sequential); // sequence-set-wise split - each ubatch contains a single sequence-set llama_ubatch split_seq(uint32_t n_ubatch); @@ -112,6 +114,9 @@ private: using pos_set_t = std::set; using seq_cpl_t = std::vector; + // helper flag to quickly determine if there are any coupled sequences in the batch + bool has_cpl; + std::vector seq_pos; // seq_pos[s]: the set of positions in sequence s std::vector seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 @@ -125,6 +130,8 @@ private: // batch indices of the output std::vector out_ids; + uint32_t n_used; + // used[i] indicates if token i has already been used in a previous ubatch std::vector used; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 5d317f4ee6..cbc19d3c40 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -64,6 +64,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "bailing", LLM_CHAT_TEMPLATE_BAILING }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, + { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_LLAMA4; } else if (tmpl_contains("<|endofuserprompt|>")) { return LLM_CHAT_TEMPLATE_DOTS1; + } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { + return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -665,6 +668,18 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|response|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) { + // tencent/Hunyuan-A13B-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|startoftext|>" << message->content << "<|extra_4|>"; + } else if (role == "assistant") { + ss << "<|startoftext|>" << message->content << "<|eos|>"; + } else { + ss << "<|startoftext|>" << message->content << "<|extra_0|>"; + } + } } else { // template not supported return -1; diff --git a/src/llama-chat.h b/src/llama-chat.h index 38800010ae..b621fda281 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -44,6 +44,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_LLAMA4, LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, + LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 71ee431a97..7f0e8c67f1 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { - if (self_kq_mask) { - mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - } + mctx->set_input_k_idxs(self_k_idxs, ubatch); + mctx->set_input_v_idxs(self_v_idxs, ubatch); + + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { - if (self_kq_mask) { - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - } + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - if (self_kq_mask_swa) { - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); - } + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { @@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - if (self_kq_mask) { - mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - } + mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); const int64_t n_rs = mctx->get_recr()->get_n_rs(); @@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { } } -void llm_graph_input_one::set_input(const llama_ubatch *) { +void llm_graph_input_one::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); GGML_ASSERT(one && ggml_nelements(one) == 1); float f_one = 1.0f; ggml_backend_tensor_set(one, &f_one, 0, sizeof(float)); @@ -560,12 +565,20 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: - { + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_swiglu_split(ctx0, cur, tmp); + cb(cur, "ffn_swiglu", il); + type_gate = LLM_FFN_SEQ; + } else { cur = ggml_silu(ctx0, cur); cb(cur, "ffn_silu", il); } break; case LLM_FFN_GELU: - { + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_geglu_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu", il); + type_gate = LLM_FFN_SEQ; + } else { cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_gelu", il); if (act_scales != NULL) { @@ -574,7 +587,11 @@ ggml_tensor * llm_graph_context::build_ffn( } } break; case LLM_FFN_RELU: - { + if (gate && type_gate == LLM_FFN_PAR) { + cur = ggml_reglu_split(ctx0, cur, tmp); + cb(cur, "ffn_reglu", il); + type_gate = LLM_FFN_SEQ; + } else { cur = ggml_relu(ctx0, cur); cb(cur, "ffn_relu", il); } break; @@ -588,32 +605,19 @@ ggml_tensor * llm_graph_context::build_ffn( } break; case LLM_FFN_SWIGLU: { - // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - int64_t split_point = cur->ne[0] / 2; - // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - x0 = ggml_silu(ctx0, x0); - cb(cur, "ffn_silu", il); - - cur = ggml_mul(ctx0, x0, x1); - cb(cur, "ffn_mul", il); + cur = ggml_swiglu(ctx0, cur); + cb(cur, "ffn_swiglu", il); } break; case LLM_FFN_GEGLU: { - // Split into two equal parts - int64_t split_point = cur->ne[0] / 2; - // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - x0 = ggml_gelu(ctx0, x0); - cb(x0, "ffn_gelu", il); - - cur = ggml_mul(ctx0, x0, x1); + cur = ggml_geglu(ctx0, cur); cb(cur, "ffn_geglu", il); } break; + case LLM_FFN_REGLU: + { + cur = ggml_reglu(ctx0, cur); + cb(cur, "ffn_reglu", il); + } break; } if (gate && type_gate == LLM_FFN_PAR) { @@ -743,12 +747,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn( switch (type_op) { case LLM_FFN_SILU: - { + if (gate_exps) { + cur = ggml_swiglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_swiglu", il); + } else { cur = ggml_silu(ctx0, cur); cb(cur, "ffn_moe_silu", il); } break; case LLM_FFN_GELU: - { + if (gate_exps) { + cur = ggml_geglu_split(ctx0, cur, up); + cb(cur, "ffn_moe_geglu", il); + } else { cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_moe_gelu", il); } break; @@ -756,11 +766,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn( GGML_ABORT("fatal error"); } - if (gate_exps) { - cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens] - cb(cur, "ffn_moe_gate_par", il); - } - experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); @@ -997,8 +1002,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { const auto n_kv = inp->mctx->get_attn()->get_n_kv(); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); + inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch); + inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1135,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con auto inp = std::make_unique(hparams, cparams); // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp_kq_mask, "KQ_mask", -1); + inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->kq_mask); inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; @@ -1198,8 +1204,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const auto n_kv = mctx_cur->get_n_kv(); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); + inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); + inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1230,8 +1238,11 @@ ggml_tensor * llm_graph_context::build_attn( // store to KV cache { - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); + const auto & k_idxs = inp->get_k_idxs(); + const auto & v_idxs = inp->get_v_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); } const auto & kq_mask = inp->get_kq_mask(); @@ -1290,11 +1301,15 @@ ggml_tensor * llm_graph_context::build_attn( // optionally store to KV cache if (k_cur) { - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); + const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); } if (v_cur) { - ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); + const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); } const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); @@ -1326,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->cross_kq_mask); inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; @@ -1398,8 +1413,11 @@ ggml_tensor * llm_graph_context::build_attn( // store to KV cache { - ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); - ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); + const auto & k_idxs = inp->get_k_idxs(); + const auto & v_idxs = inp->get_v_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il)); } const auto & kq_mask = inp->get_kq_mask(); @@ -1434,8 +1452,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif { const auto n_kv = mctx_cur->get_base()->get_n_kv(); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); + inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); + inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1446,8 +1466,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif const auto n_kv = mctx_cur->get_swa()->get_n_kv(); - inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); + inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); + inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); + + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; @@ -1466,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_rs( uint32_t kv_head, uint32_t kv_size, int32_t rs_zero, - bool avoid_copies) const { + const llm_graph_get_rows_fn & get_state_rows) const { ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size); @@ -1475,19 +1497,11 @@ ggml_tensor * llm_graph_context::build_rs( ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); - ggml_tensor * output_states; - - if (!avoid_copies) { - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // {state_size, kv_size} -> {state_size, n_seqs} - output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); - ggml_build_forward_expand(gf, output_states); - } else { - // FIXME: make the gathering operation happen before the copy below - // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) - output_states = states; - } + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); // copy extra states which won't be changed further (between n_seqs and n_kv) ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); @@ -1518,10 +1532,10 @@ ggml_tensor * llm_graph_context::build_rs( ggml_tensor * s, int32_t state_size, int32_t n_seqs, - bool avoid_copies) const { - const auto * mctx_cur = static_cast(mctx); + const llm_graph_get_rows_fn & get_state_rows) const { + const auto * kv_state = static_cast(mctx); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); + return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); } ggml_tensor * llm_graph_context::build_rs( @@ -1530,10 +1544,10 @@ ggml_tensor * llm_graph_context::build_rs( ggml_tensor * s, int32_t state_size, int32_t n_seqs, - bool avoid_copies) const { - const auto * mctx_cur = static_cast(mctx)->get_recr(); + const llm_graph_get_rows_fn & get_state_rows) const { + const auto * kv_state = static_cast(mctx)->get_recr(); - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies); + return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( diff --git a/src/llama-graph.h b/src/llama-graph.h index 4b1ec354df..7bdf656768 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -38,6 +38,7 @@ enum llm_ffn_op_type { LLM_FFN_RELU_SQR, LLM_FFN_SWIGLU, LLM_FFN_GEGLU, + LLM_FFN_REGLU, }; enum llm_ffn_gate_type { @@ -227,8 +228,8 @@ public: ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } - ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch] - ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch] + ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1] + ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -248,10 +249,16 @@ public: void set_input(const llama_ubatch * ubatch) override; + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + ggml_tensor * get_v_idxs() const { return self_v_idxs; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -273,13 +280,23 @@ public: void set_input(const llama_ubatch * ubatch) override; + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + ggml_tensor * get_v_idxs() const { return self_v_idxs; } + ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; } + ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -296,8 +313,8 @@ public: ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } - ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch] - ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch] + ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] const llama_cross * cross = nullptr; }; @@ -318,10 +335,16 @@ public: ggml_tensor * s_copy; // I32 [kv_size] + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + ggml_tensor * get_v_idxs() const { return self_v_idxs; } + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -335,7 +358,7 @@ public: llm_graph_input_one() {} virtual ~llm_graph_input_one() = default; - void set_input(const llama_ubatch *) override; + void set_input(const llama_ubatch * ubatch) override; ggml_tensor * one = nullptr; // F32 }; @@ -423,6 +446,9 @@ struct llm_graph_params { const llm_graph_cb & cb; }; +// used in build_rs to properly order writes and avoid unnecessary copies +using llm_graph_get_rows_fn = std::function; + struct llm_graph_context { const llm_arch arch; @@ -475,6 +501,7 @@ struct llm_graph_context { std::unique_ptr res; llm_graph_context(const llm_graph_params & params); + virtual ~llm_graph_context() = default; void cb(ggml_tensor * cur, const char * name, int il) const; @@ -661,7 +688,7 @@ struct llm_graph_context { uint32_t kv_head, uint32_t kv_size, int32_t rs_zero, - bool avoid_copies = false) const; + const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; llm_graph_input_rs * build_rs_inp() const; @@ -671,7 +698,7 @@ struct llm_graph_context { ggml_tensor * s, int32_t state_size, int32_t n_seqs, - bool avoid_copies = false) const; + const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; ggml_tensor * build_rs( llm_graph_input_mem_hybrid * inp, @@ -679,7 +706,7 @@ struct llm_graph_context { ggml_tensor * s, int32_t state_size, int32_t n_seqs, - bool avoid_copies = false) const; + const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const; ggml_tensor * build_rwkv_token_shift_load( llm_graph_input_rs * inp, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index bba7a12dc5..86c814d51b 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -73,7 +73,8 @@ uint32_t llama_hparams::n_embd_r() const { // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + // Corresponds to Mamba's conv_states size + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); } uint32_t llama_hparams::n_embd_s() const { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index e85afe145a..476d0a5ead 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -114,6 +114,7 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; // for hybrid state space models std::array recurrent_layer_arr; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index b9169299c0..fe207ad536 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -113,20 +113,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } - auto heads_base = kv_base->prepare(ubatches); - if (heads_base.empty()) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split break; } - auto heads_swa = kv_swa->prepare(ubatches); - if (heads_swa.empty()) { + auto sinfos_base = kv_base->prepare(ubatches); + if (sinfos_base.empty()) { break; } - assert(heads_base.size() == heads_swa.size()); + auto sinfos_swa = kv_swa->prepare(ubatches); + if (sinfos_swa.empty()) { + break; + } + + assert(sinfos_base.size() == sinfos_swa.size()); return std::make_unique( - this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); } while (false); // if it fails, try equal split @@ -135,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all std::vector ubatches; while (true) { - auto ubatch = balloc.split_equal(n_ubatch); + auto ubatch = balloc.split_equal(n_ubatch, false); if (ubatch.n_tokens == 0) { break; @@ -144,20 +149,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } - auto heads_base = kv_base->prepare(ubatches); - if (heads_base.empty()) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split break; } - auto heads_swa = kv_swa->prepare(ubatches); - if (heads_swa.empty()) { + auto sinfos_base = kv_base->prepare(ubatches); + if (sinfos_base.empty()) { break; } - assert(heads_base.size() == heads_swa.size()); + auto sinfos_swa = kv_swa->prepare(ubatches); + if (sinfos_swa.empty()) { + break; + } + + assert(sinfos_base.size() == sinfos_swa.size()); return std::make_unique( - this, std::move(heads_base), std::move(heads_swa), std::move(ubatches)); + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); } while (false); // TODO: if we fail again, we should attempt different splitting strategies @@ -220,13 +230,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, - std::vector heads_base, - std::vector heads_swa, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)), - ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)), + ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)), + ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)), status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) { } @@ -246,7 +256,7 @@ bool llama_kv_cache_unified_iswa_context::next() { } bool llama_kv_cache_unified_iswa_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); bool res = true; diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 46c1ed614f..23205d826b 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -74,6 +74,8 @@ private: class llama_kv_cache_unified_iswa_context : public llama_memory_context_i { public: + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + // used for errors llama_kv_cache_unified_iswa_context(llama_memory_status status); @@ -90,8 +92,8 @@ public: // used to create a batch processing context from a batch llama_kv_cache_unified_iswa_context( llama_kv_cache_unified_iswa * kv, - std::vector heads_base, - std::vector heads_swa, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, std::vector ubatches); virtual ~llama_kv_cache_unified_iswa_context(); diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 8517b722a9..d3129cc532 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified( const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; + + const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); + supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0; + + if (!supports_set_rows) { + LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__); + } } void llama_kv_cache_unified::clear(bool data) { @@ -353,13 +360,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( ubatches.push_back(std::move(ubatch)); // NOLINT } - auto heads = prepare(ubatches); - if (heads.empty()) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + auto sinfos = prepare(ubatches); + if (sinfos.empty()) { break; } return std::make_unique( - this, std::move(heads), std::move(ubatches)); + this, std::move(sinfos), std::move(ubatches)); } while (false); return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); @@ -402,12 +414,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct return std::make_unique(this, lctx, do_shift, std::move(dinfo)); } -llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector & ubatches) { - llama_kv_cache_unified::ubatch_heads res; +llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector & ubatches) { + llama_kv_cache_unified::slot_info_vec_t res; struct state { uint32_t head_old; // old position of the head, before placing the ubatch - uint32_t head_new; // new position of the head, after placing the ubatch + + slot_info sinfo; // slot info for the ubatch llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch }; @@ -418,26 +431,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std:: bool success = true; for (const auto & ubatch : ubatches) { + // non-continuous slots require support for ggml_set_rows() + const bool cont = supports_set_rows ? false : true; + // only find a suitable slot for the ubatch. don't modify the cells yet - const int32_t head_new = find_slot(ubatch); - if (head_new < 0) { + const auto sinfo_new = find_slot(ubatch, cont); + if (sinfo_new.empty()) { success = false; break; } // remeber the position that we found - res.push_back(head_new); + res.push_back(sinfo_new); // store the old state of the cells in the recovery stack - states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); + states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)}); // now emplace the ubatch - apply_ubatch(head_new, ubatch); + apply_ubatch(sinfo_new, ubatch); } // iterate backwards and restore the cells to their original state for (auto it = states.rbegin(); it != states.rend(); ++it) { - cells.set(it->head_new, it->cells); + cells.set(it->sinfo.idxs, it->cells); head = it->head_old; } @@ -539,7 +555,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d return updated; } -int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { +llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { const uint32_t n_tokens = ubatch.n_tokens; uint32_t head_cur = this->head; @@ -552,7 +568,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { if (n_tokens > cells.size()) { LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return -1; + return { }; } if (debug > 0) { @@ -615,15 +631,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { uint32_t n_tested = 0; + // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head + // for non-continuous slots, we test the tokens one by one + const uint32_t n_test = cont ? n_tokens : 1; + + slot_info res; + + auto & idxs = res.idxs; + + idxs.reserve(n_tokens); + while (true) { - if (head_cur + n_tokens > cells.size()) { + if (head_cur + n_test > cells.size()) { n_tested += cells.size() - head_cur; head_cur = 0; continue; } - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { + for (uint32_t i = 0; i < n_test; i++) { + const auto idx = head_cur; + //const llama_pos pos = ubatch.pos[i]; //const llama_seq_id seq_id = ubatch.seq_id[i][0]; @@ -633,19 +660,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { // - (disabled) mask causally, if the sequence is the same as the one we are inserting // - mask SWA, using current max pos for that sequence in the cache // always insert in the cell with minimum pos - bool can_use = cells.is_empty(head_cur + i); + bool can_use = cells.is_empty(idx); - if (!can_use && cells.seq_count(head_cur + i) == 1) { - const llama_pos pos_cell = cells.pos_get(head_cur + i); + if (!can_use && cells.seq_count(idx) == 1) { + const llama_pos pos_cell = cells.pos_get(idx); // (disabled) causal mask // note: it's better to purge any "future" tokens beforehand - //if (cells.seq_has(head_cur + i, seq_id)) { + //if (cells.seq_has(idx, seq_id)) { // can_use = pos_cell >= pos; //} if (!can_use) { - const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); + const llama_seq_id seq_id_cell = cells.seq_get(idx); // SWA mask if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { @@ -654,28 +681,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { } } - if (!can_use) { - found = false; - head_cur += i + 1; - n_tested += i + 1; + head_cur++; + n_tested++; + + if (can_use) { + idxs.push_back(idx); + } else { break; } } - if (found) { + if (idxs.size() == n_tokens) { break; } + if (cont) { + idxs.clear(); + } + if (n_tested >= cells.size()) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return -1; + return { }; } } - return head_cur; + // we didn't find a suitable slot - return empty result + if (idxs.size() < n_tokens) { + res.clear(); + } + + return res; } -void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { +void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -683,22 +721,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch seq_pos_max_rm[s] = -1; } - for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { - if (!cells.is_empty(head_cur + i)) { - assert(cells.seq_count(head_cur + i) == 1); + assert(ubatch.n_tokens == sinfo.idxs.size()); - const llama_seq_id seq_id = cells.seq_get(head_cur + i); - const llama_pos pos = cells.pos_get(head_cur + i); + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + const auto idx = sinfo.idxs.at(i); + + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); + + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); - cells.rm(head_cur + i); + cells.rm(idx); } - cells.pos_set(head_cur + i, ubatch.pos[i]); + cells.pos_set(idx, ubatch.pos[i]); for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(head_cur + i, ubatch.seq_id[i][s]); + cells.seq_add(idx, ubatch.seq_id[i][s]); } } @@ -719,7 +761,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch } // move the head at the end of the slot - head = head_cur + ubatch.n_tokens; + head = sinfo.idxs.back() + 1; } bool llama_kv_cache_unified::get_can_shift() const { @@ -772,47 +814,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint 0); } -ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; + const int64_t n_embd_k_gqa = k->ne[0]; const int64_t n_tokens = k_cur->ne[2]; + k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); + + if (k_idxs && supports_set_rows) { + return ggml_set_rows(ctx, k, k_cur, k_idxs); + } + + // TODO: fallback to old ggml_cpy() method for backwards compatibility + // will be removed when ggml_set_rows() is adopted by all backends + ggml_tensor * k_view = ggml_view_1d(ctx, k, - n_tokens*hparams.n_embd_k_gqa(il), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); + n_tokens*n_embd_k_gqa, + ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head()); return ggml_cpy(ctx, k_cur, k_view); } -ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; + const int64_t n_embd_v_gqa = v->ne[0]; const int64_t n_tokens = v_cur->ne[2]; - v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); + v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); + + if (v_idxs && supports_set_rows) { + if (!v_trans) { + return ggml_set_rows(ctx, v, v_cur, v_idxs); + } + + // the row becomes a single element + ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]); + + // note: the V cache is transposed when not using flash attention + v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3); + + // note: we can be more explicit here at the cost of extra cont + // however, above we take advantage that a row of single element is always continuous regardless of the row stride + //v_cur = ggml_transpose(ctx, v_cur); + //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]); + + // we broadcast the KV indices n_embd_v_gqa times + // v [1, n_kv, n_embd_v_gqa] + // v_cur [1, n_tokens, n_embd_v_gqa] + // v_idxs [n_tokens, 1, 1] + return ggml_set_rows(ctx, v_view, v_cur, v_idxs); + } + + // TODO: fallback to old ggml_cpy() method for backwards compatibility + // will be removed when ggml_set_rows() is adopted by all backends ggml_tensor * v_view = nullptr; if (!v_trans) { v_view = ggml_view_1d(ctx, v, - n_tokens*hparams.n_embd_v_gqa(il), - ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); + n_tokens*n_embd_v_gqa, + ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head()); } else { - // note: the V cache is transposed when not using flash attention - v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), - (v->ne[1])*ggml_element_size(v), - (head_cur)*ggml_element_size(v)); - v_cur = ggml_transpose(ctx, v_cur); + + v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa, + (v->ne[1] )*ggml_element_size(v), + (sinfo.head())*ggml_element_size(v)); } return ggml_cpy(ctx, v_cur, v_view); } +ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + const uint32_t n_tokens = ubatch.n_tokens; + + ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + + ggml_set_input(k_idxs); + + return k_idxs; +} + +ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + const uint32_t n_tokens = ubatch.n_tokens; + + ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + + ggml_set_input(v_idxs); + + return v_idxs; +} + +void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { + if (!supports_set_rows) { + return; + } + + const uint32_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + int64_t * data = (int64_t *) dst->data; + + for (int64_t i = 0; i < n_tokens; ++i) { + data[i] = sinfo.idxs.at(i); + } +} + +void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { + if (!supports_set_rows) { + return; + } + + const uint32_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + int64_t * data = (int64_t *) dst->data; + + for (int64_t i = 0; i < n_tokens; ++i) { + data[i] = sinfo.idxs.at(i); + } +} + void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const uint32_t n_tokens = ubatch->n_tokens; @@ -1552,13 +1680,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell ubatch.seq_id[i] = &dest_seq_id; } - const auto head_cur = find_slot(ubatch); - if (head_cur < 0) { + const auto sinfo = find_slot(ubatch, true); + if (sinfo.empty()) { LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } - apply_ubatch(head_cur, ubatch); + apply_ubatch(sinfo, ubatch); + + const auto head_cur = sinfo.head(); // keep the head at the old position because we will read the KV data into it in state_read_data() head = head_cur; @@ -1744,7 +1874,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); - head = 0; + + // create a dummy slot info - the actual data is irrelevant. we just need to build the graph + sinfos.resize(1); + sinfos[0].idxs.resize(1); + sinfos[0].idxs[0] = 0; } llama_kv_cache_unified_context::llama_kv_cache_unified_context( @@ -1759,8 +1893,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, - llama_kv_cache_unified::ubatch_heads heads, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) { + llama_kv_cache_unified::slot_info_vec_t sinfos, + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) { } llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; @@ -1768,7 +1902,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; bool llama_kv_cache_unified_context::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - if (++i_next >= ubatches.size()) { + if (++i_cur >= ubatches.size()) { return false; } @@ -1776,7 +1910,7 @@ bool llama_kv_cache_unified_context::next() { } bool llama_kv_cache_unified_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); // no ubatches -> this is a KV cache update if (ubatches.empty()) { @@ -1785,10 +1919,9 @@ bool llama_kv_cache_unified_context::apply() { return true; } - kv->apply_ubatch(heads[i_next], ubatches[i_next]); + kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); n_kv = kv->get_n_kv(); - head = heads[i_next]; return true; } @@ -1800,7 +1933,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const { const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - return ubatches[i_next]; + return ubatches[i_cur]; } uint32_t llama_kv_cache_unified_context::get_n_kv() const { @@ -1815,18 +1948,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t return kv->get_v(ctx, il, n_kv); } -ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { - return kv->cpy_k(ctx, k_cur, il, head); +ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { + return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]); } -ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { - return kv->cpy_v(ctx, v_cur, il, head); +ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const { + return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]); +} + +ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + return kv->build_input_k_idxs(ctx, ubatch); +} + +ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { + return kv->build_input_v_idxs(ctx, ubatch); } void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } +void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]); +} + +void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]); +} + void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { kv->set_input_kq_mask(dst, ubatch, causal_attn); } diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 4c53f1273a..b8b0356e83 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -24,8 +24,6 @@ public: // this callback is used to filter out layers that should not be included in the cache using layer_filter_cb = std::function; - using ubatch_heads = std::vector; - struct defrag_info { bool empty() const { return ids.empty(); @@ -37,6 +35,32 @@ public: std::vector ids; }; + // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the + // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] + struct slot_info { + // data for ggml_set_rows + using idx_vec_t = std::vector; + + idx_vec_t idxs; + + uint32_t head() const { + return idxs.at(0); + } + + bool empty() const { + return idxs.empty(); + } + + void clear() { + idxs.clear(); + } + + // TODO: implement + //std::vector seq_idxs; + }; + + using slot_info_vec_t = std::vector; + llama_kv_cache_unified( const llama_model & model, layer_filter_cb && filter, @@ -102,30 +126,37 @@ public: ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; // store k_cur and v_cur in the cache based on the provided head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const; // // preparation API // - // find places for the provided ubatches in the cache, returns the head locations + // find places for the provided ubatches in the cache, returns the slot infos // return empty vector on failure - ubatch_heads prepare(const std::vector & ubatches); + slot_info_vec_t prepare(const std::vector & ubatches); bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); - // return the cell position where we can insert the ubatch - // return -1 on failure to find a contiguous slot of kv cells - int32_t find_slot(const llama_ubatch & ubatch) const; + // find a slot of kv cells that can hold the ubatch + // if cont == true, then the slot must be continuous + // return empty slot_info on failure + slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; - // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) - void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); + // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] + void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); // - // set_input API + // input API // + ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_k_shift (ggml_tensor * dst) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -157,8 +188,13 @@ private: // SWA const uint32_t n_swa = 0; + // env: LLAMA_KV_CACHE_DEBUG int debug = 0; + // env: LLAMA_SET_ROWS (temporary) + // ref: https://github.com/ggml-org/llama.cpp/pull/14285 + int supports_set_rows = false; + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; std::vector ctxs; @@ -211,8 +247,8 @@ private: class llama_kv_cache_unified_context : public llama_memory_context_i { public: // some shorthands - using ubatch_heads = llama_kv_cache_unified::ubatch_heads; - using defrag_info = llama_kv_cache_unified::defrag_info; + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using defrag_info = llama_kv_cache_unified::defrag_info; // used for errors llama_kv_cache_unified_context(llama_memory_status status); @@ -231,7 +267,7 @@ public: // used to create a batch procesing context from a batch llama_kv_cache_unified_context( llama_kv_cache_unified * kv, - ubatch_heads heads, + slot_info_vec_t sinfos, std::vector ubatches); virtual ~llama_kv_cache_unified_context(); @@ -257,11 +293,16 @@ public: ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location - ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; - ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const; - void set_input_k_shift(ggml_tensor * dst) const; + ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; + + void set_input_k_shift (ggml_tensor * dst) const; void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -283,10 +324,10 @@ private: // batch processing context // - // the index of the next ubatch to process - size_t i_next = 0; + // the index of the cur ubatch to process + size_t i_cur = 0; - ubatch_heads heads; + slot_info_vec_t sinfos; std::vector ubatches; @@ -297,7 +338,4 @@ private: // a heuristic, to avoid attending the full cache if it is not yet utilized // as the cache gets filled, the benefit from this heuristic disappears int32_t n_kv; - - // the beginning of the current slot in which the ubatch will be inserted - int32_t head; }; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index c95d635948..0d0dd316fd 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -105,10 +105,30 @@ public: res.resize(n); for (uint32_t j = 0; j < n; ++j) { - res.pos[j] = pos[i + j]; - res.seq[j] = seq[i + j]; + const auto idx = i + j; - assert(shift[i + j] == 0); + res.pos[j] = pos[idx]; + res.seq[j] = seq[idx]; + + assert(shift[idx] == 0); + } + + return res; + } + + // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) + llama_kv_cells_unified cp(const std::vector & idxs) const { + llama_kv_cells_unified res; + + res.resize(idxs.size()); + + for (uint32_t j = 0; j < idxs.size(); ++j) { + const auto idx = idxs[j]; + + res.pos[j] = pos[idx]; + res.seq[j] = seq[idx]; + + assert(shift[idx] == 0); } return res; @@ -119,26 +139,58 @@ public: assert(i + other.pos.size() <= pos.size()); for (uint32_t j = 0; j < other.pos.size(); ++j) { - if (pos[i + j] == -1 && other.pos[j] != -1) { + const auto idx = i + j; + + if (pos[idx] == -1 && other.pos[j] != -1) { used.insert(i + j); } - if (pos[i + j] != -1 && other.pos[j] == -1) { + if (pos[idx] != -1 && other.pos[j] == -1) { used.erase(i + j); } - if (pos[i + j] != -1) { + if (pos[idx] != -1) { seq_pos_rm(i + j); } - pos[i + j] = other.pos[j]; - seq[i + j] = other.seq[j]; + pos[idx] = other.pos[j]; + seq[idx] = other.seq[j]; - if (pos[i + j] != -1) { + if (pos[idx] != -1) { seq_pos_add(i + j); } - assert(shift[i + j] == 0); + assert(shift[idx] == 0); + } + } + + // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1]) + void set(const std::vector & idxs, const llama_kv_cells_unified & other) { + assert(idxs.size() == other.pos.size()); + + for (uint32_t j = 0; j < other.pos.size(); ++j) { + const auto idx = idxs[j]; + + if (pos[idx] == -1 && other.pos[j] != -1) { + used.insert(idx); + } + + if (pos[idx] != -1 && other.pos[j] == -1) { + used.erase(idx); + } + + if (pos[idx] != -1) { + seq_pos_rm(idx); + } + + pos[idx] = other.pos[j]; + seq[idx] = other.seq[j]; + + if (pos[idx] != -1) { + seq_pos_add(idx); + } + + assert(shift[idx] == 0); } } diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 15cde98d13..6cd10db06b 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch); + ubatch = balloc.split_equal(n_ubatch, false); } if (ubatch.n_tokens == 0) { @@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + // prepare the recurrent batches first if (!mem_recr->prepare(ubatches)) { // TODO: will the recurrent cache be in an undefined context at this point? @@ -195,11 +200,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context( llama_memory_hybrid_context::llama_memory_hybrid_context( llama_memory_hybrid * mem, - std::vector heads_attn, + slot_info_vec_t sinfos_attn, std::vector ubatches) : ubatches(std::move(ubatches)), // note: here we copy the ubatches. not sure if this is ideal - ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)), + ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)), ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { } @@ -218,7 +223,7 @@ bool llama_memory_hybrid_context::next() { } bool llama_memory_hybrid_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); bool res = true; diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h index f0c2420e9a..4ac3181757 100644 --- a/src/llama-memory-hybrid.h +++ b/src/llama-memory-hybrid.h @@ -92,6 +92,8 @@ private: class llama_memory_hybrid_context : public llama_memory_context_i { public: + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + // init failure explicit llama_memory_hybrid_context(llama_memory_status status); @@ -107,7 +109,7 @@ public: // init success llama_memory_hybrid_context( llama_memory_hybrid * mem, - std::vector heads_attn, + slot_info_vec_t sinfos_attn, std::vector ubatches); ~llama_memory_hybrid_context() = default; diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 1b1e95d567..a1b5b1a272 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -363,30 +363,40 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { } llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { - std::vector ubatches; + do { + balloc.split_reset(); - while (true) { - llama_ubatch ubatch; + std::vector ubatches; + while (true) { + llama_ubatch ubatch; - if (embd_all) { - // if all tokens are output, split by sequence - ubatch = balloc.split_seq(n_ubatch); - } else { - ubatch = balloc.split_equal(n_ubatch); + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = balloc.split_seq(n_ubatch); + } else { + ubatch = balloc.split_equal(n_ubatch, false); + } + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT } - if (ubatch.n_tokens == 0) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split break; } - ubatches.push_back(std::move(ubatch)); // NOLINT - } + if (!prepare(ubatches)) { + break; + } - if (!prepare(ubatches)) { - return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); - } + return std::make_unique(this, std::move(ubatches)); + } while (false); - return std::make_unique(this, std::move(ubatches)); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } llama_memory_context_ptr llama_memory_recurrent::init_full() { @@ -1066,7 +1076,15 @@ bool llama_memory_recurrent_context::next() { } bool llama_memory_recurrent_context::apply() { - assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + assert(!llama_memory_status_is_fail(status)); + + // no ubatches -> this is an update + if (ubatches.empty()) { + // recurrent cache never performs updates + assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE); + + return true; + } mem->find_slot(ubatches[i_next]); diff --git a/src/llama-memory.cpp b/src/llama-memory.cpp index f1107672c6..ca6844c32a 100644 --- a/src/llama-memory.cpp +++ b/src/llama-memory.cpp @@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me // if either status has an update, then the combined status has an update return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE; } + +bool llama_memory_status_is_fail(llama_memory_status status) { + switch (status) { + case LLAMA_MEMORY_STATUS_SUCCESS: + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + return false; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return true; + } + } + + return false; +} diff --git a/src/llama-memory.h b/src/llama-memory.h index 16b7e5ee24..e8ba336e85 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -31,6 +31,9 @@ enum llama_memory_status { // useful for implementing hybrid memory types (e.g. iSWA) llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); +// helper function for checking if a memory status indicates a failure +bool llama_memory_status_is_fail(llama_memory_status status); + // the interface for managing the memory context during batch processing // this interface is implemented per memory type. see: // - llama_kv_cache_unified_context diff --git a/src/llama-model.cpp b/src/llama-model.cpp index fc39195ed5..fc4e9a5af0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_475M: return "475M"; case LLM_TYPE_770M: return "770M"; case LLM_TYPE_780M: return "780M"; + case LLM_TYPE_0_3B: return "0.3B"; case LLM_TYPE_0_5B: return "0.5B"; case LLM_TYPE_0_6B: return "0.6B"; case LLM_TYPE_1B: return "1B"; @@ -101,6 +102,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_57B_A14B: return "57B.A14B"; case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; + case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_E2B: return "E2B"; @@ -207,23 +209,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w } break; case GGML_OP_SSM_CONV: { - // FIXME - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); op_tensor = ggml_ssm_conv(ctx, conv_x, w); } break; case GGML_OP_SSM_SCAN: { - // FIXME - const int64_t d_state = w->ne[0]; - const int64_t d_inner = w->ne[1]; + // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1; const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 1; - ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); - ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); } break; case GGML_OP_RWKV_WKV6: { @@ -1080,6 +1086,38 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MAMBA2: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1504,6 +1542,35 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_ERNIE4_5: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 18: type = LLM_TYPE_0_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_HUNYUAN_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_A13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_SMOLLM3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.n_no_rope_layer_step = 4; + + switch (hparams.n_layer) { + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -3111,6 +3178,54 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } + } break; + case LLM_ARCH_MAMBA2: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + // out_proj layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); } @@ -4344,6 +4459,106 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_ERNIE4_5: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_HUNYUAN_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } break; + case LLM_ARCH_SMOLLM3: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } @@ -4587,10 +4802,14 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + } + + if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); if (!classifier_labels.empty()) { @@ -5539,12 +5758,10 @@ struct llm_build_falcon : public llm_graph_context { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode @@ -5821,12 +6038,10 @@ struct llm_build_dbrx : public llm_graph_context { cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); cb(cur, "wqkv_clamped", il); - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -6337,12 +6552,10 @@ struct llm_build_neo_bert : public llm_graph_context { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // RoPE @@ -6572,8 +6785,8 @@ struct llm_build_mpt : public llm_graph_context { cb(cur, "wqkv_clamped", il); } - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); cb(Qcur, "Qcur", il); @@ -6593,6 +6806,12 @@ struct llm_build_mpt : public llm_graph_context { model.layers[il].attn_k_norm_b, LLM_NORM, il); cb(Kcur, "Kcur", il); + } else { + Qcur = ggml_cont(ctx0, Qcur); + cb(Qcur, "Qcur", il); + + Kcur = ggml_cont(ctx0, Kcur); + cb(Kcur, "Kcur", il); } Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -6847,12 +7066,10 @@ struct llm_build_qwen : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode @@ -7617,21 +7834,21 @@ struct llm_build_phi2 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -7755,21 +7972,21 @@ struct llm_build_phi3 : public llm_graph_context { cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); cb(cur, "wqkv", il); - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa))); } else { Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -8125,12 +8342,10 @@ struct llm_build_codeshell : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -8546,8 +8761,6 @@ struct llm_build_minicpm3 : public llm_graph_context { ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); cb(k_pe, "k_pe", il); - // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); kv_compressed = build_norm(kv_compressed, model.layers[il].attn_kv_a_norm, NULL, LLM_NORM_RMS, il); @@ -8574,12 +8787,6 @@ struct llm_build_minicpm3 : public llm_graph_context { v_states = ggml_cont(ctx0, v_states); cb(v_states, "v_states", il); - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this q_pe = ggml_rope_ext( ctx0, q_pe, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -8588,7 +8795,6 @@ struct llm_build_minicpm3 : public llm_graph_context { cb(q_pe, "q_pe", il); // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this k_pe = ggml_rope_ext( ctx0, k_pe, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -9622,9 +9828,7 @@ struct llm_build_starcoder2 : public llm_graph_context { }; struct llm_build_mamba : public llm_graph_context { - const llama_model & model; - - llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -9642,7 +9846,11 @@ struct llm_build_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il); + if (model.arch == LLM_ARCH_MAMBA2) { + cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); + } else { + cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); + } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -9676,11 +9884,11 @@ struct llm_build_mamba : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - // TODO: split ggml_tensor * build_mamba_layer( llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, + const llama_model & model, const llama_ubatch & ubatch, int il) const { const auto * mctx_cur = static_cast(mctx); @@ -9691,6 +9899,8 @@ struct llm_build_mamba : public llm_graph_context { const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_head = d_inner; + const int64_t head_dim = 1; const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; @@ -9706,15 +9916,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // (ab)using the KV cache to store the states - ggml_tensor * conv = build_rs( - inp, gf, conv_states_all, - hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_rs( - inp, gf, ssm_states_all, - hparams.n_embd_s(), n_seqs); - ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9763,8 +9966,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); // split ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); - ggml_tensor * B = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); - ggml_tensor * C = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + ggml_tensor * C = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers if (ssm_dt_b_c_rms) { @@ -9777,23 +9980,36 @@ struct llm_build_mamba : public llm_graph_context { dt = build_lora_mm(model.layers[il].ssm_dt, dt); dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C); + cur = x; + x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); + + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0); // TODO: skip computing output earlier for unused tokens - // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} - y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d)); y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -9802,7 +10018,136 @@ struct llm_build_mamba : public llm_graph_context { // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - //cb(cur, "mamba_out", il); + // cb(cur, "mamba_out", il); + + return cur; + } + + ggml_tensor * build_mamba2_layer( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) const { + const auto * mctx_cur = static_cast(mctx); + + const auto kv_head = mctx_cur->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + + // split the above in three + ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); + ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx0, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size()); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // grouped RMS norm + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + // cb(cur, "mamba_out", il); return cur; } @@ -10514,10 +10859,10 @@ struct llm_build_openelm : public llm_graph_context { cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0)); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv))); @@ -10639,12 +10984,10 @@ struct llm_build_gptneox : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); - ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -11889,6 +12232,8 @@ struct llm_build_chatglm : public llm_graph_context { if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -11896,13 +12241,11 @@ struct llm_build_chatglm : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); } - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); @@ -12023,6 +12366,8 @@ struct llm_build_glm4 : public llm_graph_context { if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); } else { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); @@ -12030,13 +12375,11 @@ struct llm_build_glm4 : public llm_graph_context { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); } - Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Qcur = ggml_rope_ext( @@ -14125,6 +14468,136 @@ struct llm_build_dots1 : public llm_graph_context { } }; +struct llm_build_ernie4_5 : public llm_graph_context { + llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_arcee : public llm_graph_context { llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14260,6 +14733,304 @@ struct llm_build_arcee : public llm_graph_context { } }; +struct llm_build_hunyuan_moe : public llm_graph_context { + llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, nullptr, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, nullptr, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network (non-MoE) + ggml_tensor * cur_mlp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_mlp, "ffn_mlp", il); + + // MoE branch + ggml_tensor * cur_moe = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, + true, // norm_topk_prob + false, + 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur_moe, "ffn_moe_out", il); + + ggml_tensor * ffn_out = ggml_add(ctx0, cur_moe, cur_mlp); + cb(ffn_out, "ffn_out", il); + + cur = ggml_add(ctx0, ffn_out, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_smollm3 : public llm_graph_context { + llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -14495,6 +15266,7 @@ llm_graph_result_ptr llama_model::build_graph( llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: { llm = std::make_unique(*this, params, gf); } break; @@ -14635,6 +15407,18 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_ERNIE4_5: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_HUNYUAN_MOE: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_SMOLLM3: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -14751,6 +15535,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -14785,7 +15570,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: case LLM_ARCH_NEO_BERT: + case LLM_ARCH_SMOLLM3: case LLM_ARCH_ARCEE: + case LLM_ARCH_ERNIE4_5: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -14821,6 +15608,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_EXAONE: case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: + case LLM_ARCH_HUNYUAN_MOE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index 40063b790d..70a6dc89e1 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -39,6 +39,7 @@ enum llm_type { LLM_TYPE_475M, LLM_TYPE_770M, LLM_TYPE_780M, + LLM_TYPE_0_3B, LLM_TYPE_0_5B, LLM_TYPE_0_6B, LLM_TYPE_1B, @@ -93,6 +94,7 @@ enum llm_type { LLM_TYPE_57B_A14B, LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick + LLM_TYPE_A13B, LLM_TYPE_30B_A3B, LLM_TYPE_235B_A22B, LLM_TYPE_E2B, @@ -171,6 +173,7 @@ struct llama_layer { struct ggml_tensor * ffn_sub_norm = nullptr; struct ggml_tensor * attn_norm_cross = nullptr; struct ggml_tensor * attn_norm_enc = nullptr; + struct ggml_tensor * ssm_norm = nullptr; // attention struct ggml_tensor * wq = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 5c9eb87566..551bba171c 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -351,6 +351,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { break; case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_QWEN2: + case LLAMA_VOCAB_PRE_TYPE_HUNYUAN: regex_exprs = { // original regex from tokenizer.json // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" @@ -1656,6 +1657,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "seed-coder") { pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER; clean_spaces = false; + } else if ( + tokenizer_pre == "hunyuan") { + pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d1b2ff10d1..1d837b4322 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -24,10 +24,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -317,6 +319,538 @@ enum test_mode { MODE_GRAD, }; +// Output format support similar to llama-bench +enum output_formats { CONSOLE, SQL }; + +static const char * output_format_str(output_formats format) { + switch (format) { + case CONSOLE: + return "console"; + case SQL: + return "sql"; + default: + GGML_ABORT("invalid output format"); + } +} + +static bool output_format_from_str(const std::string & s, output_formats & format) { + if (s == "console") { + format = CONSOLE; + } else if (s == "sql") { + format = SQL; + } else { + return false; + } + return true; +} + +// Test result structure for SQL output +struct test_result { + std::string test_time; + std::string build_commit; + std::string backend_name; + std::string op_name; + std::string op_params; + std::string test_mode; + bool supported; + bool passed; + std::string error_message; + double time_us; + double flops; + double bandwidth_gb_s; + size_t memory_kb; + int n_runs; + + test_result() { + // Initialize with default values + time_us = 0.0; + flops = 0.0; + bandwidth_gb_s = 0.0; + memory_kb = 0; + n_runs = 0; + supported = false; + passed = false; + + // Set test time + time_t t = time(NULL); + char buf[32]; + std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t)); + test_time = buf; + + // Set build info + build_commit = ggml_commit(); + } + + test_result(const std::string & backend_name, const std::string & op_name, const std::string & op_params, + const std::string & test_mode, bool supported, bool passed, const std::string & error_message = "", + double time_us = 0.0, double flops = 0.0, double bandwidth_gb_s = 0.0, size_t memory_kb = 0, + int n_runs = 0) : + backend_name(backend_name), + op_name(op_name), + op_params(op_params), + test_mode(test_mode), + supported(supported), + passed(passed), + error_message(error_message), + time_us(time_us), + flops(flops), + bandwidth_gb_s(bandwidth_gb_s), + memory_kb(memory_kb), + n_runs(n_runs) { + // Set test time + time_t t = time(NULL); + char buf[32]; + std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t)); + test_time = buf; + + // Set build info + build_commit = ggml_commit(); + } + + static const std::vector & get_fields() { + static const std::vector fields = { + "test_time", "build_commit", "backend_name", "op_name", "op_params", "test_mode", "supported", + "passed", "error_message", "time_us", "flops", "bandwidth_gb_s", "memory_kb", "n_runs" + }; + return fields; + } + + enum field_type { STRING, BOOL, INT, FLOAT }; + + static field_type get_field_type(const std::string & field) { + if (field == "supported" || field == "passed") { + return BOOL; + } + if (field == "memory_kb" || field == "n_runs") { + return INT; + } + if (field == "time_us" || field == "flops" || field == "bandwidth_gb_s") { + return FLOAT; + } + return STRING; + } + + std::vector get_values() const { + return { test_time, + build_commit, + backend_name, + op_name, + op_params, + test_mode, + std::to_string(supported), + std::to_string(passed), + error_message, + std::to_string(time_us), + std::to_string(flops), + std::to_string(bandwidth_gb_s), + std::to_string(memory_kb), + std::to_string(n_runs) }; + } +}; + +// Printer classes for different output formats +enum class test_status_t { NOT_SUPPORTED, OK, FAIL }; + +struct test_operation_info { + std::string op_name; + std::string op_params; + std::string backend_name; + test_status_t status = test_status_t::OK; + std::string failure_reason; + + // Additional information fields that were previously in separate structs + std::string error_component; + std::string error_details; + + // Gradient info + int64_t gradient_index = -1; + std::string gradient_param_name; + float gradient_value = 0.0f; + + // MAA error info + double maa_error = 0.0; + double maa_threshold = 0.0; + + // Flags for different types of information + bool has_error = false; + bool has_gradient_info = false; + bool has_maa_error = false; + bool is_compare_failure = false; + bool is_large_tensor_skip = false; + + test_operation_info() = default; + + test_operation_info(const std::string & op_name, const std::string & op_params, const std::string & backend_name, + test_status_t status = test_status_t::OK, const std::string & failure_reason = "") : + op_name(op_name), + op_params(op_params), + backend_name(backend_name), + status(status), + failure_reason(failure_reason) {} + + // Set error information + void set_error(const std::string & component, const std::string & details) { + has_error = true; + error_component = component; + error_details = details; + if (status == test_status_t::OK) { + status = test_status_t::FAIL; + } + } + + // Set gradient information + void set_gradient_info(int64_t index, const std::string & param_name, float value) { + has_gradient_info = true; + gradient_index = index; + gradient_param_name = param_name; + gradient_value = value; + if (status == test_status_t::OK) { + status = test_status_t::FAIL; + } + } + + // Set MAA error information + void set_maa_error(double error, double threshold) { + has_maa_error = true; + maa_error = error; + maa_threshold = threshold; + if (status == test_status_t::OK) { + status = test_status_t::FAIL; + } + } + + // Set compare failure + void set_compare_failure() { + is_compare_failure = true; + if (status == test_status_t::OK) { + status = test_status_t::FAIL; + } + } + + // Set large tensor skip + void set_large_tensor_skip() { is_large_tensor_skip = true; } +}; + +struct test_summary_info { + size_t tests_passed; + size_t tests_total; + bool is_backend_summary = false; // true for backend summary, false for test summary + + test_summary_info() = default; + + test_summary_info(size_t tests_passed, size_t tests_total, bool is_backend_summary = false) : + tests_passed(tests_passed), + tests_total(tests_total), + is_backend_summary(is_backend_summary) {} +}; + +struct testing_start_info { + size_t device_count; + + testing_start_info() = default; + + testing_start_info(size_t device_count) : device_count(device_count) {} +}; + +struct backend_init_info { + size_t device_index; + size_t total_devices; + std::string device_name; + bool skipped = false; + std::string skip_reason; + std::string description; + size_t memory_total_mb = 0; + size_t memory_free_mb = 0; + bool has_memory_info = false; + + backend_init_info() = default; + + backend_init_info(size_t device_index, size_t total_devices, const std::string & device_name, bool skipped = false, + const std::string & skip_reason = "", const std::string & description = "", + size_t memory_total_mb = 0, size_t memory_free_mb = 0, bool has_memory_info = false) : + device_index(device_index), + total_devices(total_devices), + device_name(device_name), + skipped(skipped), + skip_reason(skip_reason), + description(description), + memory_total_mb(memory_total_mb), + memory_free_mb(memory_free_mb), + has_memory_info(has_memory_info) {} +}; + +struct backend_status_info { + std::string backend_name; + test_status_t status; + + backend_status_info() = default; + + backend_status_info(const std::string & backend_name, test_status_t status) : + backend_name(backend_name), + status(status) {} +}; + +struct overall_summary_info { + size_t backends_passed; + size_t backends_total; + bool all_passed; + + overall_summary_info() = default; + + overall_summary_info(size_t backends_passed, size_t backends_total, bool all_passed) : + backends_passed(backends_passed), + backends_total(backends_total), + all_passed(all_passed) {} +}; + +struct printer { + virtual ~printer() {} + + FILE * fout = stdout; + + virtual void print_header() {} + + virtual void print_test_result(const test_result & result) = 0; + + virtual void print_footer() {} + + virtual void print_operation(const test_operation_info & info) { (void) info; } + + virtual void print_summary(const test_summary_info & info) { (void) info; } + + virtual void print_testing_start(const testing_start_info & info) { (void) info; } + + virtual void print_backend_init(const backend_init_info & info) { (void) info; } + + virtual void print_backend_status(const backend_status_info & info) { (void) info; } + + virtual void print_overall_summary(const overall_summary_info & info) { (void) info; } +}; + +struct console_printer : public printer { + void print_test_result(const test_result & result) override { + if (result.test_mode == "test") { + print_test_console(result); + } else if (result.test_mode == "perf") { + print_perf_console(result); + } + } + + void print_operation(const test_operation_info & info) override { + printf(" %s(%s): ", info.op_name.c_str(), info.op_params.c_str()); + fflush(stdout); + + // Handle large tensor skip first + if (info.is_large_tensor_skip) { + printf("skipping large tensors for speed \n"); + return; + } + + // Handle not supported status + if (info.status == test_status_t::NOT_SUPPORTED) { + if (!info.failure_reason.empty()) { + printf("not supported [%s]\n", info.failure_reason.c_str()); + } else { + printf("not supported [%s]\n", info.backend_name.c_str()); + } + return; + } + + // Handle errors and additional information + if (info.has_error) { + if (info.error_component == "allocation") { + fprintf(stderr, "failed to allocate tensors [%s] ", info.backend_name.c_str()); + } else if (info.error_component == "backend") { + fprintf(stderr, " Failed to initialize %s backend\n", info.backend_name.c_str()); + } else { + fprintf(stderr, "Error in %s: %s\n", info.error_component.c_str(), info.error_details.c_str()); + } + } + + // Handle gradient info + if (info.has_gradient_info) { + printf("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", info.op_name.c_str(), info.gradient_index, + info.gradient_param_name.c_str(), info.gradient_value); + } + + // Handle MAA error + if (info.has_maa_error) { + printf("[%s] MAA = %.9f > %.9f ", info.op_name.c_str(), info.maa_error, info.maa_threshold); + } + + // Handle compare failure + if (info.is_compare_failure) { + printf("compare failed "); + } + + // Print final status + if (info.status == test_status_t::OK) { + printf("\033[1;32mOK\033[0m\n"); + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + } + + void print_summary(const test_summary_info & info) override { + if (info.is_backend_summary) { + printf("%zu/%zu backends passed\n", info.tests_passed, info.tests_total); + } else { + printf(" %zu/%zu tests passed\n", info.tests_passed, info.tests_total); + } + } + + void print_backend_status(const backend_status_info & info) override { + printf(" Backend %s: ", info.backend_name.c_str()); + if (info.status == test_status_t::OK) { + printf("\033[1;32mOK\033[0m\n"); + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + } + + void print_testing_start(const testing_start_info & info) override { + printf("Testing %zu devices\n\n", info.device_count); + } + + void print_backend_init(const backend_init_info & info) override { + printf("Backend %zu/%zu: %s\n", info.device_index + 1, info.total_devices, info.device_name.c_str()); + + if (info.skipped) { + printf(" %s\n", info.skip_reason.c_str()); + return; + } + + if (!info.description.empty()) { + printf(" Device description: %s\n", info.description.c_str()); + } + + if (info.has_memory_info) { + printf(" Device memory: %zu MB (%zu MB free)\n", info.memory_total_mb, info.memory_free_mb); + } + + printf("\n"); + } + + void print_overall_summary(const overall_summary_info & info) override { + printf("%zu/%zu backends passed\n", info.backends_passed, info.backends_total); + if (info.all_passed) { + printf("\033[1;32mOK\033[0m\n"); + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + } + + private: + void print_test_console(const test_result & result) { + printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str()); + fflush(stdout); + + if (!result.supported) { + printf("not supported [%s] ", result.backend_name.c_str()); + printf("\n"); + return; + } + + if (result.passed) { + printf("\033[1;32mOK\033[0m\n"); + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + } + + void print_perf_console(const test_result & result) { + int len = printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str()); + fflush(stdout); + + if (!result.supported) { + printf("not supported\n"); + return; + } + + // align while also leaving some margin for variations in parameters + int align = 8; + int last = (len + align - 1) / align * align; + if (last - len < 5) { + last += align; + } + printf("%*s", last - len, ""); + + printf(" %8d runs - %8.2f us/run - ", result.n_runs, result.time_us); + + if (result.flops > 0) { + auto format_flops = [](double flops) -> std::string { + char buf[256]; + if (flops >= 1e12) { + snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12); + } else if (flops >= 1e9) { + snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9); + } else if (flops >= 1e6) { + snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6); + } else { + snprintf(buf, sizeof(buf), "%6.2f kFLOP", flops / 1e3); + } + return buf; + }; + uint64_t op_flops_per_run = result.flops * result.time_us / 1e6; + printf("%s/run - \033[1;34m%sS\033[0m", format_flops(op_flops_per_run).c_str(), + format_flops(result.flops).c_str()); + } else { + printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m", result.memory_kb, result.bandwidth_gb_s); + } + printf("\n"); + } +}; + +struct sql_printer : public printer { + static std::string get_sql_field_type(const std::string & field) { + switch (test_result::get_field_type(field)) { + case test_result::STRING: + return "TEXT"; + case test_result::BOOL: + case test_result::INT: + return "INTEGER"; + case test_result::FLOAT: + return "REAL"; + default: + GGML_ABORT("invalid field type"); + } + } + + void print_header() override { + std::vector fields = test_result::get_fields(); + fprintf(fout, "CREATE TABLE IF NOT EXISTS test_backend_ops (\n"); + for (size_t i = 0; i < fields.size(); i++) { + fprintf(fout, " %s %s%s\n", fields[i].c_str(), get_sql_field_type(fields[i]).c_str(), + i < fields.size() - 1 ? "," : ""); + } + fprintf(fout, ");\n\n"); + } + + void print_test_result(const test_result & result) override { + fprintf(fout, "INSERT INTO test_backend_ops ("); + std::vector fields = test_result::get_fields(); + for (size_t i = 0; i < fields.size(); i++) { + fprintf(fout, "%s%s", fields[i].c_str(), i < fields.size() - 1 ? ", " : ""); + } + fprintf(fout, ") VALUES ("); + std::vector values = result.get_values(); + for (size_t i = 0; i < values.size(); i++) { + fprintf(fout, "'%s'%s", values[i].c_str(), i < values.size() - 1 ? ", " : ""); + } + fprintf(fout, ");\n"); + } +}; + +static std::unique_ptr create_printer(output_formats format) { + switch (format) { + case CONSOLE: + return std::make_unique(); + case SQL: + return std::make_unique(); + } + GGML_ABORT("invalid output format"); +} + struct test_case { virtual ~test_case() {} @@ -382,6 +916,8 @@ struct test_case { return 0; } + virtual bool run_whole_graph() { return false; } + ggml_cgraph * gf = nullptr; ggml_cgraph * gb = nullptr; @@ -432,7 +968,7 @@ struct test_case { return t; } - bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) { + bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) { mode = MODE_TEST; ggml_init_params params = { @@ -449,29 +985,33 @@ struct test_case { add_sentinel(ctx); ggml_tensor * out = build_graph(ctx); - - if (op_name != nullptr && op_desc(out) != op_name) { + std::string current_op_name = op_desc(out); + if (op_name != nullptr && current_op_name != op_name) { //printf(" %s: skipping\n", op_desc(out).c_str()); ggml_free(ctx); return true; } - printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); - fflush(stdout); - // check if the backends support the ops bool supported = true; for (ggml_backend_t backend : {backend1, backend2}) { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { if (!ggml_backend_supports_op(backend, t)) { - printf("not supported [%s] ", ggml_backend_name(backend)); supported = false; break; } } } + if (!supported) { - printf("\n"); + // Create test result for unsupported operation + test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", + false, false, "not supported"); + + if (output_printer) { + output_printer->print_test_result(result); + } + ggml_free(ctx); return true; } @@ -574,26 +1114,26 @@ struct test_case { GGML_UNUSED(index); }; - const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud); - - if (!cmp_ok) { - printf("compare failed "); - } + const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr); ggml_backend_buffer_free(buf); ggml_free(ctx); - if (ud.ok && cmp_ok) { - printf("\033[1;32mOK\033[0m\n"); - return true; + // Create test result + bool test_passed = ud.ok && cmp_ok; + std::string error_msg = test_passed ? "" : (!cmp_ok ? "compare failed" : "test failed"); + test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test", supported, test_passed, + error_msg); + + if (output_printer) { + output_printer->print_test_result(result); } - printf("\033[1;31mFAIL\033[0m\n"); - return false; + return test_passed; } - bool eval_perf(ggml_backend_t backend, const char * op_name) { + bool eval_perf(ggml_backend_t backend, const char * op_name, printer * output_printer) { mode = MODE_PERF; static const size_t graph_nodes = 8192; @@ -606,30 +1146,26 @@ struct test_case { ggml_context_ptr ctx(ggml_init(params)); // smart ptr GGML_ASSERT(ctx); - ggml_tensor * out = build_graph(ctx.get()); - - if (op_name != nullptr && op_desc(out) != op_name) { + ggml_tensor * out = build_graph(ctx.get()); + std::string current_op_name = op_desc(out); + if (op_name != nullptr && current_op_name != op_name) { //printf(" %s: skipping\n", op_desc(out).c_str()); return true; } - int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); - fflush(stdout); - // check if backends support op if (!ggml_backend_supports_op(backend, out)) { - printf("not supported\n"); + // Create test result for unsupported performance test + test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", false, false, + "not supported"); + + if (output_printer) { + output_printer->print_test_result(result); + } + return true; } - // align while also leaving some margin for variations in parameters - int align = 8; - int last = (len + align - 1) / align * align; - if (last - len < 5) { - last += align; - } - printf("%*s", last - len, ""); - // allocate ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr @@ -713,40 +1249,24 @@ struct test_case { total_runs += n_runs; } while (total_time_us < 1000*1000); // run for at least 1 second - printf(" %8d runs - %8.2f us/run - ", - total_runs, - (double)total_time_us / total_runs); + // Create test result + double avg_time_us = (double) total_time_us / total_runs; + double calculated_flops = (op_flops(out) > 0) ? (op_flops(out) * total_runs) / (total_time_us / 1e6) : 0.0; + double calculated_bandwidth = + (op_flops(out) == 0) ? total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0 : 0.0; + size_t calculated_memory_kb = op_size(out) / 1024; - if (op_flops(out) > 0) { - double flops_per_sec = (op_flops(out) * total_runs) / (total_time_us / 1e6); - auto format_flops = [](double flops) -> std::string { - char buf[256]; - if (flops >= 1e12) { - snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12); - } else if (flops >= 1e9) { - snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9); - } else if (flops >= 1e6) { - snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6); - } else { - snprintf(buf, sizeof(buf), "%6.2f KFLOP", flops / 1e3); - } - return buf; - }; - printf("%s/run - \033[1;34m%sS\033[0m", - format_flops(op_flops(out)).c_str(), - format_flops(flops_per_sec).c_str()); + test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", true, true, "", avg_time_us, + calculated_flops, calculated_bandwidth, calculated_memory_kb, total_runs); - } else { - printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m", - op_size(out) / 1024, - total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0); + if (output_printer) { + output_printer->print_test_result(result); } - printf("\n"); return true; } - bool eval_grad(ggml_backend_t backend, const char * op_name) { + bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) { mode = MODE_GRAD; const std::vector expect = grad_expect(); @@ -764,42 +1284,47 @@ struct test_case { ggml_tensor * out = build_graph(ctx.get()); if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) { - //printf(" %s: skipping\n", op_desc(out).c_str()); return true; } - printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str()); - fflush(stdout); - if (out->type != GGML_TYPE_F32) { - printf("not supported [%s->type != FP32]\n", out->name); + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), + test_status_t::NOT_SUPPORTED, + out->name + std::string("->type != FP32"))); return true; } + // Print operation info first + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend))); + // check if the backend supports the ops - bool supported = true; - bool any_params = false; + bool supported = true; + bool any_params = false; + std::string failure_reason; + for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) { if (!ggml_backend_supports_op(backend, t)) { - printf("not supported [%s] ", ggml_backend_name(backend)); - supported = false; + supported = false; + failure_reason = ggml_backend_name(backend); break; } if ((t->flags & GGML_TENSOR_FLAG_PARAM)) { any_params = true; if (t->type != GGML_TYPE_F32) { - printf("not supported [%s->type != FP32] ", t->name); - supported = false; + supported = false; + failure_reason = std::string(t->name) + "->type != FP32"; break; } } } if (!any_params) { - printf("not supported [%s] \n", op_desc(out).c_str()); - supported = false; + supported = false; + failure_reason = op_desc(out); } + if (!supported) { - printf("\n"); + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), + test_status_t::NOT_SUPPORTED, failure_reason)); return true; } @@ -810,7 +1335,9 @@ struct test_case { } } if (ngrads > grad_nmax()) { - printf("skipping large tensors for speed \n"); + test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend)); + info.set_large_tensor_skip(); + output_printer->print_operation(info); return true; } @@ -833,25 +1360,30 @@ struct test_case { for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) { if (!ggml_backend_supports_op(backend, t)) { - printf("not supported [%s] ", ggml_backend_name(backend)); + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), + test_status_t::NOT_SUPPORTED, + ggml_backend_name(backend))); supported = false; break; } if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) { - printf("not supported [%s->type != FP32] ", t->name); + output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), + test_status_t::NOT_SUPPORTED, + std::string(t->name) + "->type != FP32")); supported = false; break; } } if (!supported) { - printf("\n"); return true; } // allocate ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr if (buf == NULL) { - printf("failed to allocate tensors [%s] ", ggml_backend_name(backend)); + test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend)); + info.set_error("allocation", ""); + output_printer->print_operation(info); return false; } @@ -889,7 +1421,9 @@ struct test_case { for (int64_t i = 0; i < ne; ++i) { // gradient algebraic // check for nans if (!std::isfinite(ga[i])) { - printf("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]); + test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend)); + info.set_gradient_info(i, bn, ga[i]); + output_printer->print_operation(info); ok = false; break; } @@ -957,7 +1491,9 @@ struct test_case { const double err = mean_abs_asymm(gn.data(), ga.data(), gn.size(), expect); if (err > max_maa_err()) { - printf("[%s] MAA = %.9f > %.9f ", ggml_op_desc(t), err, max_maa_err()); + test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend)); + info.set_maa_error(err, max_maa_err()); + output_printer->print_operation(info); ok = false; break; } @@ -966,16 +1502,18 @@ struct test_case { } } + // Create final test result + test_operation_info final_info(op_desc(out), vars(), ggml_backend_name(backend)); if (!ok) { - printf("compare failed "); + final_info.set_compare_failure(); } + final_info.status = ok ? test_status_t::OK : test_status_t::FAIL; + output_printer->print_operation(final_info); if (ok) { - printf("\033[1;32mOK\033[0m\n"); return true; } - printf("\033[1;31mFAIL\033[0m\n"); return false; } }; @@ -1104,6 +1642,111 @@ struct test_unary : public test_case { }; +// GGML_OP_GLU +struct test_glu : public test_case { + const ggml_glu_op op; + const ggml_type type; + const std::array ne_a; + int v; // view (1 : non-contiguous a) + bool swapped; + + std::string vars() override { + return VARS_TO_STR4(type, ne_a, v, swapped); + } + + test_glu(ggml_glu_op op, + ggml_type type = GGML_TYPE_F32, + std::array ne_a = {128, 2, 2, 2}, + int v = 0, + bool swapped = false) + : op(op), type(type), ne_a(ne_a), v(v), swapped(swapped) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a; + if (v & 1) { + auto ne = ne_a; ne[0] *= 3; + a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view_of_a"); + } else { + a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_set_name(a, "a"); + } + + ggml_tensor * out = ggml_glu(ctx, a, op, swapped); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // test extended range of values to check for NaNs in GELU + init_tensor_uniform(t, -150.f, 150.f); + } + } +}; + +struct test_glu_split : public test_case { + const ggml_glu_op op; + const ggml_type type; + const std::array ne_a; + int v; // view (1 : non-contiguous a) + + std::string vars() override { + return VARS_TO_STR3(type, ne_a, v) + ",split"; + } + + test_glu_split(ggml_glu_op op, + ggml_type type = GGML_TYPE_F32, + std::array ne_a = {128, 2, 2, 2}, + int v = 0) + : op(op), type(type), ne_a(ne_a), v(v) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a; + ggml_tensor * b; + if (v & 1) { + auto ne = ne_a; ne[0] *= 3; + a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view_of_a"); + + b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(b); + ggml_set_name(b, "b"); + + b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0); + ggml_set_name(a, "view_of_b"); + } else { + a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + + b = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_set_param(b); + ggml_set_name(b, "b"); + } + + ggml_tensor * out = ggml_glu_split(ctx, a, b, op); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + // test extended range of values to check for NaNs in GELU + init_tensor_uniform(t, -150.f, 150.f); + } + } +}; + // GGML_OP_GET_ROWS struct test_get_rows : public test_case { const ggml_type type; @@ -1213,6 +1856,76 @@ struct test_get_rows_back : public test_case { } }; +// GGML_OP_SET_ROWS +struct test_set_rows : public test_case { + const ggml_type type; + const std::array ne; + const std::array nr23; // broadcast only dims 2 and 3 + const int r; // rows to set + const bool v; // view (non-contiguous src1) + + std::string vars() override { + return VARS_TO_STR5(type, ne, nr23, r, v); + } + + test_set_rows(ggml_type type, + std::array ne, + std::array nr23, + int r, bool v = false) + : type(type), ne(ne), nr23(nr23), r(r), v(v) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]); + ggml_set_name(dst, "dst"); + + ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r, ne[2]*nr23[0], ne[3]*nr23[1]); + ggml_set_name(src, "src"); + + ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, ne[2], ne[3]); + ggml_set_name(row_idxs, "row_idxs"); + + if (v) { + src = ggml_view_4d(ctx, src, ne[0], r/2, ne[2]*nr23[0], ne[3]*nr23[1], src->nb[1], src->nb[2], src->nb[3], 0); + row_idxs = ggml_view_3d(ctx, row_idxs, r/2, ne[2], ne[3], row_idxs->nb[1], row_idxs->nb[2], 0); + ggml_set_name(row_idxs, "view_of_rows"); + } + + ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I64) { + if (ggml_is_view_op(t->op)) { + continue; + } + + for (int i2 = 0; i2 < t->ne[2]; i2++) { + for (int i1 = 0; i1 < t->ne[1]; i1++) { + // generate a shuffled subset of row indices + std::vector data(ne[1]); + for (int i = 0; i < ne[1]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + data.resize(t->ne[0]); + + const size_t offs = i1*t->nb[1] + i2*t->nb[2]; + ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t)); + } + } + } else { + init_tensor_uniform(t); + } + } + } +}; + // GGML_OP_ARGMAX struct test_argmax : public test_case { const ggml_type type; @@ -1828,6 +2541,59 @@ struct test_rms_norm_back : public test_case { } }; +// GGML_OP_RMS_NORM + GGML_OP_MUL +struct test_rms_norm_mul : public test_case { + const ggml_type type; + const std::array ne; + const float eps; + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "RMS_NORM_MUL"; + } + + bool run_whole_graph() override { return true; } + + std::string vars() override { + return VARS_TO_STR3(type, ne, eps); + } + + test_rms_norm_mul(ggml_type type = GGML_TYPE_F32, + std::array ne = {64, 5, 4, 3}, + float eps = 1e-6f) + : type(type), ne(ne), eps(eps) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); + ggml_set_name(a, "a"); + ggml_set_param(b); + ggml_set_name(b, "b"); + + // Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul + a = ggml_add(ctx, a, b); + ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -10.f, 10.f); + } + } + + float grad_eps() override { + return 1.0f; + } + + bool grad_precise() override { + return true; + } +}; + // GGML_OP_SSM_CONV struct test_ssm_conv : public test_case { const ggml_type type; @@ -1856,28 +2622,58 @@ struct test_ssm_scan : public test_case { const ggml_type type; const int64_t d_state; - const int64_t d_inner; + const int64_t head_dim; + const int64_t n_head; + const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, - int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + int64_t d_state = 32, + int64_t head_dim = 1, // non-zero for Mamba-2 + int64_t n_head = 32, + int64_t n_group = 1, + int64_t n_seq_tokens = 32, + int64_t n_seqs = 32) + : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, n_seqs, 1 }.data()); - ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); - ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids); return out; } + + // similar to test_mul_mat_id + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } }; // GGML_OP_RWKV_WKV6 @@ -2457,11 +3253,12 @@ struct test_soft_max : public test_case { const std::array ne; const bool mask; const ggml_type m_prec; + const std::array nr23; // broadcast only dims 2 and 3 const float scale; const float max_bias; std::string vars() override { - return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias); + return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias); } // the 1024 test with bias occasionally fails: @@ -2474,18 +3271,19 @@ struct test_soft_max : public test_case { std::array ne = {10, 5, 4, 3}, bool mask = false, ggml_type m_prec = GGML_TYPE_F32, + std::array nr23 = {1, 1}, float scale = 1.0f, float max_bias = 0.0f) - : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {} + : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]); ggml_set_param(a); ggml_set_name(a, "a"); ggml_tensor * mask = nullptr; if (this->mask) { - mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]); + mask = ggml_new_tensor_4d(ctx, m_prec, ne[0], ne[1], ne[2], ne[3]); ggml_set_name(mask, "mask"); } @@ -3068,28 +3866,28 @@ struct test_upscale : public test_case { } }; -// GGML_OP_UPSCALE (ext) -struct test_upscale_ext : public test_case { +// GGML_OP_UPSCALE (via ggml_interpolate) +struct test_interpolate : public test_case { const ggml_type type; const std::array ne; const std::array ne_tgt; - const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST; + const uint32_t mode = GGML_SCALE_MODE_NEAREST; std::string vars() override { return VARS_TO_STR4(type, ne, ne_tgt, mode); } - test_upscale_ext(ggml_type type = GGML_TYPE_F32, + test_interpolate(ggml_type type = GGML_TYPE_F32, std::array ne = {2, 5, 7, 11}, std::array ne_tgt = {5, 7, 11, 13}, - ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST) + uint32_t mode = GGML_SCALE_MODE_NEAREST) : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_name(a, "a"); - ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode); + ggml_tensor * out = ggml_interpolate(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode); ggml_set_name(out, "out"); return out; @@ -3316,7 +4114,7 @@ struct test_flash_attn_ext : public test_case { const int64_t hsk; // K head size const int64_t hsv; // V head size const int64_t nh; // num heads - const int64_t nr; // repeat in Q, tests for grouped-query attention + const std::array nr23; // repeat in dim 2 and 3, tests for grouped-query attention const int64_t kv; // kv size const int64_t nb; // batch size @@ -3330,7 +4128,7 @@ struct test_flash_attn_ext : public test_case { std::array permute; std::string vars() override { - return VARS_TO_STR12(hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute); + return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute); } double max_nmse_err() override { @@ -3341,13 +4139,13 @@ struct test_flash_attn_ext : public test_case { GGML_UNUSED(t); // Just counting matmul costs: // Q*K^T is nb x hsk x kv, P*V is nb x kv x hsv, per head - return 2 * nh*nr * nb * (hsk + hsv) * kv; + return (2 * nh*nr23[0] * nb * (hsk + hsv) * kv)*nr23[1]; } - test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8, + test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32, ggml_type type_KV = GGML_TYPE_F16, std::array permute = {0, 1, 2, 3}) - : hsk(hsk), hsv(hsv), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {} + : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {} ggml_tensor * build_graph(ggml_context * ctx) override { const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV)); @@ -3366,18 +4164,18 @@ struct test_flash_attn_ext : public test_case { return t; }; - ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr, 1); + ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]); ggml_set_name(q, "q"); - ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, 1); + ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1]); ggml_set_name(k, "k"); - ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, 1); + ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1]); ggml_set_name(v, "v"); ggml_tensor * m = nullptr; if (mask) { - m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); + m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]); ggml_set_name(m, "m"); } @@ -3668,6 +4466,7 @@ struct test_llama : public test_llm { static constexpr float attn_factor = 1.0f; static constexpr float beta_fast = 32.0f; static constexpr float beta_slow = 1.0f; + bool fused; std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); @@ -3683,7 +4482,9 @@ struct test_llama : public test_llm { return 2e-3; } - test_llama(int n_tokens = 1) + bool run_whole_graph() override { return fused; } + + test_llama(int n_tokens = 1, bool fused = false) : test_llm({ /*n_vocab =*/ 32000, /*n_embd =*/ 3200, @@ -3695,7 +4496,9 @@ struct test_llama : public test_llm { /*f_norm_eps =*/ 0.f, /*f_norm_rms_eps =*/ 1e-5f, /*n_tokens =*/ n_tokens, - }) { + }) + , fused(fused) + { } ggml_tensor * build_graph(ggml_context * ctx) override { @@ -3962,6 +4765,21 @@ static std::vector> make_test_cases_eval() { } } + // glu ops + for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { + for (int v : {0, 1}) { + for (int op = 0; op < GGML_GLU_OP_COUNT; op++) { + for (bool swapped : {false, true}) { + test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped)); + test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped)); + } + + test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v)); + test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v)); + } + } + } + test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false)); for (ggml_type type : all_types) { for (int b : {1, 7}) { @@ -3986,6 +4804,23 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v)); } + test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false)); + for (ggml_type type : all_types) { + for (int b : {1, 7}) { + for (bool v : {false, true}) { + test_cases.emplace_back(new test_set_rows(type, { 256, 5, b, 3 }, { 1, 1, }, 1, v)); + test_cases.emplace_back(new test_set_rows(type, { 256, 11, 1, b }, { 2, 3, }, 7, v)); + + test_cases.emplace_back(new test_set_rows(type, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v)); + + if (ggml_blck_size(type) == 1) { + test_cases.emplace_back(new test_set_rows(type, { 31, 3, b, 1 }, { 2, 3, }, 2, v)); + test_cases.emplace_back(new test_set_rows(type, { 33, 5, 1, b }, { 2, 3, }, 1, v)); + } + } + } + } + for (ggml_type type_input : {GGML_TYPE_F32}) { for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { for (int k0 : {1, 3}) { @@ -4222,6 +5057,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } + for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) { + test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); + } test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f)); @@ -4229,7 +5067,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); - test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2 test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1)); test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1)); @@ -4341,8 +5180,10 @@ static std::vector> make_test_cases_eval() { for (auto nr : {1,4}) { for (uint32_t m = 0; m < 2; ++m) { for (uint32_t k = 0; k < 2; ++k) { - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3})); - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true)); + for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true)); + } } } } @@ -4354,6 +5195,11 @@ static std::vector> make_test_cases_eval() { // this case is verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend) // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1})); + // test large experts*tokens + for (bool b : {false, true}) { + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16)); + } + for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { for (int n_mats : {4, 8}) { @@ -4440,26 +5286,31 @@ static std::vector> make_test_cases_eval() { for (int64_t ne1 : {16, 1024}) { if (mask) { for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) { - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, scale, max_bias)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias)); + + if (ne0 <= 32 && ne1 <= 32) { + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, m_prec, {3, 1}, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias)); + } } } else { /* The precision of mask here doesn't matter as boolean mask is false */ - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias)); } } } } } } - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 8.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f)); for (float max_bias : {0.0f, 8.0f}) { for (float scale : {1.0f, 0.1f}) { @@ -4475,12 +5326,12 @@ static std::vector> make_test_cases_eval() { for (bool fw : {true, false}) { // fw == forward bool all = true; - for (float v : { 0, 1 }) { - for (float fs : { 1.0f, 1.4245f }) { - for (float ef : { 0.0f, 0.7465f }) { - for (float af : { 1.0f, 1.4245f }) { - for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - for (bool ff : {false, true}) { // freq_factors + for (float fs : { 1.0f, 1.4245f }) { + for (float ef : { 0.0f, 0.7465f }) { + for (float af : { 1.0f, 1.4245f }) { + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (bool ff : {false, true}) { // freq_factors + for (float v : { 0, 1 }) { test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B if (all) { @@ -4493,13 +5344,21 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) + + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) } if (all) { test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B) test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B) + test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_MROPE, 512, fs, ef, af, ff, v, fw)); test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) } @@ -4530,8 +5389,10 @@ static std::vector> make_test_cases_eval() { for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) { test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); - test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode)); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode)); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode)); } + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); @@ -4557,20 +5418,23 @@ static std::vector> make_test_cases_eval() { for (float logit_softcap : {0.0f, 10.0f}) { if (hsk != 128 && logit_softcap != 0.0f) continue; for (int nh : { 4, }) { - for (int nr : { 1, 4, 16 }) { - if (nr == 16 && hsk != 128) continue; - for (int kv : { 512, 1024, }) { - if (nr != 1 && kv != 512) continue; - for (int nb : { 1, 3, 32, 35, }) { - for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { - if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; - for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { - test_cases.emplace_back(new test_flash_attn_ext( - hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV)); - // run fewer test cases permuted - if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) { + for (int nr3 : { 1, 3, }) { + if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes + for (int nr2 : { 1, 4, 16 }) { + if (nr2 == 16 && hsk != 128) continue; + for (int kv : { 512, 1024, }) { + if (nr2 != 1 && kv != 512) continue; + for (int nb : { 1, 3, 32, 35, }) { + for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { + if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; + for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { test_cases.emplace_back(new test_flash_attn_ext( - hsk, hsv, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3})); + hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV)); + // run fewer test cases permuted + if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) { + test_cases.emplace_back(new test_flash_attn_ext( + hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3})); + } } } } @@ -4591,8 +5455,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); - // these tests are disabled to save execution time, but they can be handy for debugging #if 0 + // these tests are disabled to save execution time, sbut they can be handy for debugging + test_cases.emplace_back(new test_llama(2, true)); test_cases.emplace_back(new test_llama(1)); test_cases.emplace_back(new test_llama(2)); test_cases.emplace_back(new test_falcon(1)); @@ -4613,13 +5478,14 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3})); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1})); test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1})); @@ -4651,7 +5517,7 @@ static std::vector> make_test_cases_perf() { for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { for (int nr : { 1, 4, }) { - test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); } } } @@ -4666,7 +5532,8 @@ static std::vector> make_test_cases_perf() { return test_cases; } -static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) { +static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter, + printer * output_printer) { auto filter_test_cases = [](std::vector> & test_cases, const char * params_filter) { if (params_filter == nullptr) { return; @@ -4689,17 +5556,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op filter_test_cases(test_cases, params_filter); ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL); if (backend_cpu == NULL) { - printf(" Failed to initialize CPU backend\n"); + test_operation_info info("", "", "CPU"); + info.set_error("backend", "Failed to initialize CPU backend"); + output_printer->print_operation(info); return false; } size_t n_ok = 0; for (auto & test : test_cases) { - if (test->eval(backend, backend_cpu, op_name)) { + if (test->eval(backend, backend_cpu, op_name, output_printer)) { n_ok++; } } - printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); + output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false)); ggml_backend_free(backend_cpu); @@ -4711,11 +5580,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op filter_test_cases(test_cases, params_filter); size_t n_ok = 0; for (auto & test : test_cases) { - if (test->eval_grad(backend, op_name)) { + if (test->eval_grad(backend, op_name, output_printer)) { n_ok++; } } - printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); + output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false)); return n_ok == test_cases.size(); } @@ -4724,7 +5593,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op auto test_cases = make_test_cases_perf(); filter_test_cases(test_cases, params_filter); for (auto & test : test_cases) { - test->eval_perf(backend, op_name); + test->eval_perf(backend, op_name, output_printer); } return true; } @@ -4733,16 +5602,18 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } static void usage(char ** argv) { - printf("Usage: %s [mode] [-o ] [-b ] [-p ]\n", argv[0]); + printf("Usage: %s [mode] [-o ] [-b ] [-p ] [--output ]\n", argv[0]); printf(" valid modes:\n"); printf(" - test (default, compare with CPU backend for correctness)\n"); printf(" - grad (compare gradients from backpropagation with method of finite differences)\n"); printf(" - perf (performance evaluation)\n"); printf(" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n"); + printf(" --output specifies output format (default: console)\n"); } int main(int argc, char ** argv) { test_mode mode = MODE_TEST; + output_formats output_format = CONSOLE; const char * op_name_filter = nullptr; const char * backend_filter = nullptr; const char * params_filter = nullptr; @@ -4775,6 +5646,16 @@ int main(int argc, char ** argv) { usage(argv); return 1; } + } else if (strcmp(argv[i], "--output") == 0) { + if (i + 1 < argc) { + if (!output_format_from_str(argv[++i], output_format)) { + usage(argv); + return 1; + } + } else { + usage(argv); + return 1; + } } else { usage(argv); return 1; @@ -4784,23 +5665,29 @@ int main(int argc, char ** argv) { // load and enumerate backends ggml_backend_load_all(); - printf("Testing %zu devices\n\n", ggml_backend_dev_count()); + // Create printer for output format + std::unique_ptr output_printer = create_printer(output_format); + if (output_printer) { + output_printer->print_header(); + } + + output_printer->print_testing_start(testing_start_info(ggml_backend_dev_count())); size_t n_ok = 0; for (size_t i = 0; i < ggml_backend_dev_count(); i++) { ggml_backend_dev_t dev = ggml_backend_dev_get(i); - printf("Backend %zu/%zu: %s\n", i + 1, ggml_backend_dev_count(), ggml_backend_dev_name(dev)); - if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_dev_name(dev)) != 0) { - printf(" Skipping\n"); + output_printer->print_backend_init( + backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping")); n_ok++; continue; } if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) { - printf(" Skipping CPU backend\n"); + output_printer->print_backend_init(backend_init_info( + i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping CPU backend")); n_ok++; continue; } @@ -4815,36 +5702,35 @@ int main(int argc, char ** argv) { ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency()); } - printf(" Device description: %s\n", ggml_backend_dev_description(dev)); - size_t free, total; // NOLINT + size_t free, total; // NOLINT ggml_backend_dev_memory(dev, &free, &total); - printf(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024); - printf("\n"); + output_printer->print_backend_init(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), + false, "", ggml_backend_dev_description(dev), + total / 1024 / 1024, free / 1024 / 1024, true)); - bool ok = test_backend(backend, mode, op_name_filter, params_filter); + bool ok = test_backend(backend, mode, op_name_filter, params_filter, output_printer.get()); - printf(" Backend %s: ", ggml_backend_name(backend)); if (ok) { - printf("\033[1;32mOK\033[0m\n"); n_ok++; - } else { - printf("\033[1;31mFAIL\033[0m\n"); } - - printf("\n"); + output_printer->print_backend_status( + backend_status_info(ggml_backend_name(backend), ok ? test_status_t::OK : test_status_t::FAIL)); ggml_backend_free(backend); } ggml_quantize_free(); - printf("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count()); + if (output_printer) { + output_printer->print_footer(); + } + + output_printer->print_overall_summary( + overall_summary_info(n_ok, ggml_backend_dev_count(), n_ok == ggml_backend_dev_count())); if (n_ok != ggml_backend_dev_count()) { - printf("\033[1;31mFAIL\033[0m\n"); return 1; } - printf("\033[1;32mOK\033[0m\n"); return 0; } diff --git a/tests/test-c.c b/tests/test-c.c index 95ba73df39..a05071080a 100644 --- a/tests/test-c.c +++ b/tests/test-c.c @@ -1,7 +1,3 @@ #include "llama.h" -#ifdef GGML_USE_KOMPUTE -#include "ggml-kompute.h" -#endif - int main(void) {} diff --git a/tests/test-lora-conversion-inference.sh b/tests/test-lora-conversion-inference.sh index 1d1f4886ca..0255494b82 100755 --- a/tests/test-lora-conversion-inference.sh +++ b/tests/test-lora-conversion-inference.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -e # Array of models to iterate over diff --git a/tests/test-tokenizer-0.sh b/tests/test-tokenizer-0.sh index 4d2b836554..7ef009dc90 100755 --- a/tests/test-tokenizer-0.sh +++ b/tests/test-tokenizer-0.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # Usage: # diff --git a/tests/test-tokenizers-repo.sh b/tests/test-tokenizers-repo.sh index 86e839133c..1158aebae0 100755 --- a/tests/test-tokenizers-repo.sh +++ b/tests/test-tokenizers-repo.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash if [ $# -lt 2 ]; then printf "Usage: $0 []\n" diff --git a/tools/gguf-split/tests.sh b/tools/gguf-split/tests.sh index 05a9322271..c9ad85da0f 100755 --- a/tools/gguf-split/tests.sh +++ b/tools/gguf-split/tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -eu diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index a990520ed3..9146c9e9c4 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1405,8 +1405,7 @@ struct clip_graph { ggml_tensor * x = embeddings; embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); - embeddings = ggml_silu_inplace(ctx0, embeddings); - embeddings = ggml_mul(ctx0, embeddings,x); + embeddings = ggml_swiglu_split(ctx0, embeddings, x); embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); } // arrangement of BOI/EOI token embeddings @@ -1502,15 +1501,8 @@ struct clip_graph { cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); // swiglu - { - int64_t split_point = cur->ne[0] / 2; - ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); - ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); - - // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half - x1 = ggml_silu(ctx0, x1); - cur = ggml_mul(ctx0, x0, x1); - } + // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half + cur = ggml_swiglu_swapped(ctx0, cur); // mid-norm cur = ggml_rms_norm(ctx0, cur, 1e-6); @@ -1769,35 +1761,42 @@ private: cur = tmp; } + // we only support parallel ffn for now switch (type_op) { case FFN_SILU: - { + if (gate) { + cur = ggml_swiglu_split(ctx0, cur, tmp); + cb(cur, "ffn_swiglu", il); + } else { cur = ggml_silu(ctx0, cur); cb(cur, "ffn_silu", il); } break; case FFN_GELU: - { + if (gate) { + cur = ggml_geglu_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu", il); + } else { cur = ggml_gelu(ctx0, cur); cb(cur, "ffn_gelu", il); } break; case FFN_GELU_ERF: - { + if (gate) { + cur = ggml_geglu_erf_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu_erf", il); + } else { cur = ggml_gelu_erf(ctx0, cur); - cb(cur, "ggml_gelu_erf", il); + cb(cur, "ffn_gelu_erf", il); } break; case FFN_GELU_QUICK: - { + if (gate) { + cur = ggml_geglu_quick_split(ctx0, cur, tmp); + cb(cur, "ffn_geglu_quick", il); + } else { cur = ggml_gelu_quick(ctx0, cur); - cb(cur, "ffn_relu", il); + cb(cur, "ffn_gelu_quick", il); } break; } - // we only support parallel ffn for now - if (gate) { - cur = ggml_mul(ctx0, cur, tmp); - cb(cur, "ffn_gate_par", il); - } - if (down) { cur = ggml_mul_mat(ctx0, down, cur); } diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index aa00198932..b25024c2f1 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # make sure we are in the right directory SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) diff --git a/tools/quantize/tests.sh b/tools/quantize/tests.sh index 70f7610f98..ba96161484 100644 --- a/tools/quantize/tests.sh +++ b/tools/quantize/tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash set -eu diff --git a/tools/server/README.md b/tools/server/README.md index 1a624c13be..6f962664f6 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -164,6 +164,7 @@ The project is under active development, and we are [looking for feedback and co | `--api-key-file FNAME` | path to file containing API keys (default: none) | | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key
(env: LLAMA_ARG_SSL_KEY_FILE) | | `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate
(env: LLAMA_ARG_SSL_CERT_FILE) | +| `--chat-template-kwargs STRING` | JSON object containing additional params for the json template parser. Example: `--chat_template_kwargs "{\"enable_thinking\":false}`"
(env: LLAMA_CHAT_TEMPLATE_KWARGS) | | `-to, --timeout N` | server read/write timeout in seconds (default: 600)
(env: LLAMA_ARG_TIMEOUT) | | `--threads-http N` | number of threads used to process HTTP requests (default: -1)
(env: LLAMA_ARG_THREADS_HTTP) | | `--cache-reuse N` | min chunk size to attempt reusing from the cache via KV shifting (default: 0)
[(card)](https://ggml.ai/f0.png)
(env: LLAMA_ARG_CACHE_REUSE) | @@ -1118,6 +1119,8 @@ See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. `{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}` or `{"type": "json_schema", "schema": {"properties": { "name": { "title": "Name", "type": "string" }, "date": { "title": "Date", "type": "string" }, "participants": { "items": {"type: "string" }, "title": "Participants", "type": "string" } } } }`), similar to other OpenAI-inspired API providers. +`chat_template_kwargs`: Allows sending additional parameters to the json templating system. For example: `{"enable_thinking": false}` + *Examples:* You can use either Python `openai` library with appropriate checkpoints: diff --git a/tools/server/chat-llama2.sh b/tools/server/chat-llama2.sh index 1fc79b7e19..450445f17e 100755 --- a/tools/server/chat-llama2.sh +++ b/tools/server/chat-llama2.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash API_URL="${API_URL:-http://127.0.0.1:8080}" diff --git a/tools/server/chat.sh b/tools/server/chat.sh index da0a6ca68c..84cea2d56a 100755 --- a/tools/server/chat.sh +++ b/tools/server/chat.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash API_URL="${API_URL:-http://127.0.0.1:8080}" diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 0fb01665ae..53b71079c1 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 852352383b..57b917f2f9 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2110,6 +2110,7 @@ struct server_context { /* use_jinja */ params_base.use_jinja, /* prefill_assistant */ params_base.prefill_assistant, /* reasoning_format */ params_base.reasoning_format, + /* chat_template_kwargs */ params_base.default_template_kwargs, /* common_chat_templates */ chat_templates.get(), /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, @@ -4805,14 +4806,14 @@ int main(int argc, char ** argv) { // register static assets routes if (!params.public_path.empty()) { // Set the base directory for serving static files - bool is_found = svr->set_mount_point("/", params.public_path); + bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path); if (!is_found) { LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); return 1; } } else { // using embedded static index.html - svr->Get("/", [](const httplib::Request & req, httplib::Response & res) { + svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) { if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { res.set_content("Error: gzip is not supported by this browser", "text/plain"); } else { @@ -4828,37 +4829,37 @@ int main(int argc, char ** argv) { } // register API routes - svr->Get ("/health", handle_health); // public endpoint (no API key check) - svr->Get ("/metrics", handle_metrics); - svr->Get ("/props", handle_props); - svr->Post("/props", handle_props_change); - svr->Post("/api/show", handle_api_show); - svr->Get ("/models", handle_models); // public endpoint (no API key check) - svr->Get ("/v1/models", handle_models); // public endpoint (no API key check) - svr->Get ("/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check) - svr->Post("/completion", handle_completions); // legacy - svr->Post("/completions", handle_completions); - svr->Post("/v1/completions", handle_completions_oai); - svr->Post("/chat/completions", handle_chat_completions); - svr->Post("/v1/chat/completions", handle_chat_completions); - svr->Post("/api/chat", handle_chat_completions); // ollama specific endpoint - svr->Post("/infill", handle_infill); - svr->Post("/embedding", handle_embeddings); // legacy - svr->Post("/embeddings", handle_embeddings); - svr->Post("/v1/embeddings", handle_embeddings_oai); - svr->Post("/rerank", handle_rerank); - svr->Post("/reranking", handle_rerank); - svr->Post("/v1/rerank", handle_rerank); - svr->Post("/v1/reranking", handle_rerank); - svr->Post("/tokenize", handle_tokenize); - svr->Post("/detokenize", handle_detokenize); - svr->Post("/apply-template", handle_apply_template); + svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check) + svr->Get (params.api_prefix + "/metrics", handle_metrics); + svr->Get (params.api_prefix + "/props", handle_props); + svr->Post(params.api_prefix + "/props", handle_props_change); + svr->Post(params.api_prefix + "/api/show", handle_api_show); + svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check) + svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check) + svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check) + svr->Post(params.api_prefix + "/completion", handle_completions); // legacy + svr->Post(params.api_prefix + "/completions", handle_completions); + svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai); + svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions); + svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions); + svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint + svr->Post(params.api_prefix + "/infill", handle_infill); + svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy + svr->Post(params.api_prefix + "/embeddings", handle_embeddings); + svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai); + svr->Post(params.api_prefix + "/rerank", handle_rerank); + svr->Post(params.api_prefix + "/reranking", handle_rerank); + svr->Post(params.api_prefix + "/v1/rerank", handle_rerank); + svr->Post(params.api_prefix + "/v1/reranking", handle_rerank); + svr->Post(params.api_prefix + "/tokenize", handle_tokenize); + svr->Post(params.api_prefix + "/detokenize", handle_detokenize); + svr->Post(params.api_prefix + "/apply-template", handle_apply_template); // LoRA adapters hotswap - svr->Get ("/lora-adapters", handle_lora_adapters_list); - svr->Post("/lora-adapters", handle_lora_adapters_apply); + svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list); + svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply); // Save & load slots - svr->Get ("/slots", handle_slots); - svr->Post("/slots/:id_slot", handle_slots_action); + svr->Get (params.api_prefix + "/slots", handle_slots); + svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action); // // Start the server diff --git a/tools/server/tests/tests.sh b/tools/server/tests/tests.sh index 33fa8cc646..709b5841aa 100755 --- a/tools/server/tests/tests.sh +++ b/tools/server/tests/tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # make sure we are in the right directory SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 1b5205f79d..7ee9a16514 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -132,6 +132,28 @@ def test_chat_template(): assert res.body["__verbose"]["prompt"] == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" +@pytest.mark.parametrize("prefill,re_prefill", [ + ("Whill", "Whill"), + ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"), +]) +def test_chat_template_assistant_prefill(prefill, re_prefill): + global server + server.chat_template = "llama3" + server.debug = True # to get the "__verbose" object in the response + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 8, + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + {"role": "assistant", "content": prefill}, + ] + }) + assert res.status_code == 200 + assert "__verbose" in res.body + assert res.body["__verbose"]["prompt"] == f" <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}" + + def test_apply_chat_template(): global server server.chat_template = "command-r" @@ -228,6 +250,7 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re [{"role": "system", "content": 123}], # [{"content": "hello"}], # TODO: should not be a valid case [{"role": "system", "content": "test"}, {}], + [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}], ]) def test_invalid_chat_completion_req(messages): global server diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index f8fab2c866..6c2e91359a 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -579,6 +579,7 @@ struct oaicompat_parser_options { bool use_jinja; bool prefill_assistant; common_reasoning_format reasoning_format; + std::map chat_template_kwargs; common_chat_templates * tmpls; bool allow_image; bool allow_audio; @@ -756,6 +757,13 @@ static json oaicompat_chat_params_parse( llama_params["parse_tool_calls"] = true; } + // merge the template args provided from command line with the args provided in the user request + auto chat_template_kwargs_object = json_value(body, "chat_template_kwargs", json::object()); + inputs.chat_template_kwargs = opt.chat_template_kwargs; + for (const auto & item : chat_template_kwargs_object.items()) { + inputs.chat_template_kwargs[item.key()] = item.value().dump(); + } + // if the assistant message appears at the end of list, we do not add end-of-turn token // for ex. this can be useful to modify the reasoning process in reasoning models bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; @@ -771,6 +779,11 @@ static json oaicompat_chat_params_parse( /* TODO: test this properly */ inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; + + if ( (!inputs.enable_thinking) || inputs.chat_template_kwargs.find("enable_thinking") != inputs.chat_template_kwargs.end()) { + throw std::runtime_error("Assistant response prefill is incompatible with enable_thinking."); + } + inputs.add_generation_prompt = true; } @@ -779,7 +792,13 @@ static json oaicompat_chat_params_parse( /* Append assistant prefilled message */ if (prefill_assistant_message) { - chat_params.prompt += last_message.content; + if (!last_message.content_parts.empty()) { + for (auto & p : last_message.content_parts) { + chat_params.prompt += p.text; + } + } else { + chat_params.prompt += last_message.content; + } } llama_params["chat_format"] = static_cast(chat_params.format); diff --git a/tools/server/webui/src/components/Sidebar.tsx b/tools/server/webui/src/components/Sidebar.tsx index a77cb83b45..b52a8df03c 100644 --- a/tools/server/webui/src/components/Sidebar.tsx +++ b/tools/server/webui/src/components/Sidebar.tsx @@ -231,7 +231,7 @@ function ConversationItem({ > {conv.name} -
+