diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix index 651a54db4c..41748e89d5 100644 --- a/.devops/nix/package.nix +++ b/.devops/nix/package.nix @@ -128,10 +128,6 @@ effectiveStdenv.mkDerivation (finalAttrs: { }; postPatch = '' - substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \ - --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";" - substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \ - --replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";" ''; # With PR#6015 https://github.com/ggml-org/llama.cpp/pull/6015, diff --git a/.github/actions/install-exe/action.yml b/.github/actions/install-exe/action.yml new file mode 100644 index 0000000000..002bec83c7 --- /dev/null +++ b/.github/actions/install-exe/action.yml @@ -0,0 +1,36 @@ +name: "Install exe" +description: "Download and install exe" +inputs: + url: + description: "URL of the exe installer" + required: true + args: + description: "Installer arguments" + required: true + timeout: + description: "Timeout (in ms)" + required: false + default: "600000" + +runs: + using: "composite" + steps: + - name: Install EXE + shell: pwsh + run: | + $ErrorActionPreference = "Stop" + write-host "Downloading Installer EXE" + Invoke-WebRequest -Uri "${{ inputs.url }}" -OutFile "${env:RUNNER_TEMP}\temp-install.exe" + write-host "Installing" + $proc = Start-Process "${env:RUNNER_TEMP}\temp-install.exe" -ArgumentList '${{ inputs.args }}' -NoNewWindow -PassThru + $completed = $proc.WaitForExit(${{ inputs.timeout }}) + if (-not $completed) { + Write-Error "Installer timed out. Killing the process" + $proc.Kill() + exit 1 + } + if ($proc.ExitCode -ne 0) { + Write-Error "Installer failed with exit code $($proc.ExitCode)" + exit 1 + } + write-host "Completed installation" diff --git a/.github/actions/linux-setup-spacemit/action.yml b/.github/actions/linux-setup-spacemit/action.yml new file mode 100644 index 0000000000..e2193e8931 --- /dev/null +++ b/.github/actions/linux-setup-spacemit/action.yml @@ -0,0 +1,20 @@ +name: "Linux - Setup SpacemiT Toolchain" +description: "Setup SpacemiT Toolchain for Linux" +inputs: + path: + description: "Installation path" + required: true + version: + description: "SpacemiT toolchain version" + required: true + +runs: + using: "composite" + steps: + - name: Setup SpacemiT Toolchain + id: setup + uses: ./.github/actions/unarchive-tar + with: + url: https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ inputs.version }}.tar.xz + path: ${{ inputs.path }} + strip: 1 diff --git a/.github/actions/linux-setup-vulkan/action.yml b/.github/actions/linux-setup-vulkan/action.yml new file mode 100644 index 0000000000..4d29837feb --- /dev/null +++ b/.github/actions/linux-setup-vulkan/action.yml @@ -0,0 +1,20 @@ +name: "Linux - Setup Vulkan SDK" +description: "Setup Vulkan SDK for Linux" +inputs: + path: + description: "Installation path" + required: true + version: + description: "Vulkan SDK version" + required: true + +runs: + using: "composite" + steps: + - name: Setup Vulkan SDK + id: setup + uses: ./.github/actions/unarchive-tar + with: + url: https://sdk.lunarg.com/sdk/download/${{ inputs.version }}/linux/vulkan_sdk.tar.xz + path: ${{ inputs.path }} + strip: 1 diff --git a/.github/actions/unarchive-tar/action.yml b/.github/actions/unarchive-tar/action.yml new file mode 100644 index 0000000000..b97e402f46 --- /dev/null +++ b/.github/actions/unarchive-tar/action.yml @@ -0,0 +1,27 @@ +name: "Unarchive tar" +description: "Download and unarchive tar into directory" +inputs: + url: + description: "URL of the tar archive" + required: true + path: + description: "Directory to unarchive into" + required: true + type: + description: "Compression type (tar option)" + required: false + default: "J" + strip: + description: "Strip components" + required: false + default: "0" + +runs: + using: "composite" + steps: + - name: Unarchive into directory + shell: bash + run: | + mkdir -p ${{ inputs.path }} + cd ${{ inputs.path }} + curl --no-progress-meter ${{ inputs.url }} | tar -${{ inputs.type }}x --strip-components=${{ inputs.strip }} diff --git a/.github/actions/windows-setup-rocm/action.yml b/.github/actions/windows-setup-rocm/action.yml new file mode 100644 index 0000000000..b83e6e295b --- /dev/null +++ b/.github/actions/windows-setup-rocm/action.yml @@ -0,0 +1,15 @@ +name: "Windows - Setup ROCm" +description: "Setup ROCm for Windows" +inputs: + version: + description: "ROCm version" + required: true + +runs: + using: "composite" + steps: + - name: Setup ROCm + uses: ./.github/actions/install-exe + with: + url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-WinSvr2022-For-HIP.exe + args: -install diff --git a/.github/workflows/build-cache.yml b/.github/workflows/build-cache.yml new file mode 100644 index 0000000000..6a22e41c3b --- /dev/null +++ b/.github/workflows/build-cache.yml @@ -0,0 +1,89 @@ +name: Build Actions Cache + +on: + workflow_dispatch: # allows manual triggering + schedule: + - cron: '0 * * * *' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ubuntu-24-vulkan-cache: + runs-on: ubuntu-24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Get latest Vulkan SDK version + id: vulkan_sdk_version + run: | + echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" + + - name: Setup Cache + uses: actions/cache@v4 + id: cache-sdk + with: + path: ./vulkan_sdk + key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }} + + - name: Setup Vulkan SDK + if: steps.cache-sdk.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-vulkan + with: + path: ./vulkan_sdk + version: ${{ env.VULKAN_SDK_VERSION }} + + ubuntu-24-spacemit-cache: + runs-on: ubuntu-24.04 + + env: + # Make sure this is in sync with build-linux-cross.yml + SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup Cache + uses: actions/cache@v4 + id: cache-toolchain + with: + path: ./spacemit_toolchain + key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }} + + - name: Setup SpacemiT Toolchain + if: steps.cache-toolchain.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-spacemit + with: + path: ./spacemit_toolchain + version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + + windows-2022-rocm-cache: + runs-on: windows-2022 + + env: + # Make sure this is in sync with build.yml + HIPSDK_INSTALLER_VERSION: "25.Q3" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup Cache + uses: actions/cache@v4 + id: cache-rocm + with: + path: C:\Program Files\AMD\ROCm + key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }} + + - name: Setup ROCm + if: steps.cache-rocm.outputs.cache-hit != 'true' + uses: ./.github/actions/windows-setup-rocm + with: + version: ${{ env.HIPSDK_INSTALLER_VERSION }} diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 2b101876c5..937306f7af 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -258,31 +258,29 @@ jobs: runs-on: ubuntu-24.04 env: + # Make sure this is in sync with build-cache.yml SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" - SPACEMIT_IME_TOOLCHAIN_PATH: "spacemit-toolchain-linux-glibc-x86_64" steps: - uses: actions/checkout@v4 - - name: Cache Toolchain + - name: Use SpacemiT Toolchain Cache uses: actions/cache@v4 - id: cache-spacemit-ime-cross-toolchain + id: cache-toolchain with: - path: ./${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} - key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + path: ./spacemit_toolchain + key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }} - - name: Setup Toolchain - if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true' - run: | - wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz - rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} - mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} - tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} --strip-components=1 - rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz + - name: Setup SpacemiT Toolchain + if: steps.cache-toolchain.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-spacemit + with: + path: ./spacemit_toolchain + version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} - name: Build run: | - export RISCV_ROOT_PATH=${PWD}/${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + export RISCV_ROOT_PATH=${PWD}/spacemit_toolchain cmake -B build -DLLAMA_CURL=OFF \ -DCMAKE_BUILD_TYPE=Release \ -DGGML_OPENMP=OFF \ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index db885907cd..8d6ba5f9f3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -413,20 +413,19 @@ jobs: run: | echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" - - name: Cache Vulkan SDK - id: cache_vulkan_sdk + - name: Use Vulkan SDK Cache uses: actions/cache@v4 + id: cache-sdk with: path: ./vulkan_sdk key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }} - - name: Install Vulkan SDK - if: steps.cache_vulkan_sdk.outputs.cache-hit != 'true' - id: vulkan_sdk_install - run: | - mkdir -p vulkan_sdk - cd vulkan_sdk - curl --no-progress-meter https://sdk.lunarg.com/sdk/download/latest/linux/vulkan_sdk.tar.xz | tar -Jx --strip-components=1 + - name: Setup Vulkan SDK + if: steps.cache-sdk.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-vulkan + with: + path: ./vulkan_sdk + version: ${{ env.VULKAN_SDK_VERSION }} - name: Build id: cmake_build @@ -445,8 +444,8 @@ jobs: # This is using llvmpipe and runs slower than other backends ctest -L main --verbose --timeout 4200 - ubuntu-22-cmake-webgpu: - runs-on: ubuntu-22.04 + ubuntu-24-cmake-webgpu: + runs-on: ubuntu-24.04 steps: - name: Clone @@ -456,16 +455,34 @@ jobs: - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: - key: ubuntu-22-cmake-webgpu + key: ubuntu-24-cmake-webgpu evict-old-files: 1d - - name: Vulkan SDK Dependencies - id: vulkan-depends + - name: Dependencies + id: depends run: | - wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add - - sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list + sudo add-apt-repository -y ppa:kisak/kisak-mesa sudo apt-get update -y - sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev + sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev + + - name: Get latest Vulkan SDK version + id: vulkan_sdk_version + run: | + echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" + + - name: Use Vulkan SDK Cache + uses: actions/cache@v4 + id: cache-sdk + with: + path: ./vulkan_sdk + key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }} + + - name: Setup Vulkan SDK + if: steps.cache-sdk.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-vulkan + with: + path: ./vulkan_sdk + version: ${{ env.VULKAN_SDK_VERSION }} - name: Dawn Dependency id: dawn-depends @@ -1111,6 +1128,7 @@ jobs: env: # The ROCm version must correspond to the version used in the HIP SDK. ROCM_VERSION: "6.4.2" + # Make sure this is in sync with build-cache.yml HIPSDK_INSTALLER_VERSION: "25.Q3" steps: @@ -1125,33 +1143,18 @@ jobs: 7z x rocwmma.deb 7z x data.tar - - name: Cache ROCm Installation - id: cache-rocm + - name: Use ROCm Installation Cache uses: actions/cache@v4 + id: cache-rocm with: path: C:\Program Files\AMD\ROCm key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }} - - name: Install ROCm + - name: Setup ROCm if: steps.cache-rocm.outputs.cache-hit != 'true' - id: depends - run: | - $ErrorActionPreference = "Stop" - write-host "Downloading AMD HIP SDK Installer" - Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" - write-host "Installing AMD HIP SDK" - $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru - $completed = $proc.WaitForExit(600000) - if (-not $completed) { - Write-Error "ROCm installation timed out after 10 minutes. Killing the process" - $proc.Kill() - exit 1 - } - if ($proc.ExitCode -ne 0) { - Write-Error "ROCm installation failed with exit code $($proc.ExitCode)" - exit 1 - } - write-host "Completed AMD HIP SDK installation" + uses: ./.github/actions/windows-setup-rocm + with: + version: ${{ env.HIPSDK_INSTALLER_VERSION }} - name: Verify ROCm id: verify @@ -1512,3 +1515,29 @@ jobs: run: | vulkaninfo --summary GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp + + ggml-ci-arm64-cpu-kleidiai: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.16 + with: + key: ggml-ci-arm64-cpu-kleidiai + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install -y build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + run: | + GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + diff --git a/CODEOWNERS b/CODEOWNERS index 15e3559fda..3b696bf94a 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,7 +2,7 @@ # multiplie collaborators per item can be specified /.devops/*.Dockerfile @ngxson -/.github/actions/ @slaren +/.github/actions/ @slaren @CISC /.github/workflows/ @CISC /.github/workflows/release.yml @slaren /.github/workflows/winget.yml @slaren @@ -70,6 +70,7 @@ /ggml/src/ggml-rpc/ @rgerganov /ggml/src/ggml-threading.* @ggerganov @slaren /ggml/src/ggml-vulkan/ @0cc4m +/ggml/src/ggml-webgpu/ @reeselevine /ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM /ggml/src/ggml.c @ggerganov @slaren /ggml/src/ggml.cpp @ggerganov @slaren diff --git a/ci/run.sh b/ci/run.sh index c6e31fa0b8..bf0d53f20a 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -22,6 +22,9 @@ # # with MUSA support # GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt # +# # with KLEIDIAI support +# GG_BUILD_KLEIDIAI=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# if [ -z "$2" ]; then echo "usage: $0 " @@ -115,6 +118,34 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm" fi +if [ -n "${GG_BUILD_KLEIDIAI}" ]; then + echo ">>===== Enabling KleidiAI support" + + CANDIDATES=("armv9-a+dotprod+i8mm" "armv8.6-a+dotprod+i8mm" "armv8.2-a+dotprod") + CPU="" + + for cpu in "${CANDIDATES[@]}"; do + if echo 'int main(){}' | ${CXX:-c++} -march="$cpu" -x c++ - -c -o /dev/null >/dev/null 2>&1; then + CPU="$cpu" + break + fi + done + + if [ -z "$CPU" ]; then + echo "ERROR: None of the required ARM baselines (armv9/armv8.6/armv8.2 + dotprod) are supported by this compiler." + exit 1 + fi + + echo ">>===== Using ARM baseline: ${CPU}" + + CMAKE_EXTRA="${CMAKE_EXTRA:+$CMAKE_EXTRA } \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_KLEIDIAI=ON \ + -DGGML_CPU_AARCH64=ON \ + -DGGML_CPU_ARM_ARCH=${CPU} \ + -DBUILD_SHARED_LIBS=OFF" +fi + ## helpers # download a file if it does not exist or if it is outdated @@ -512,12 +543,7 @@ function gg_run_rerank_tiny { gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin - gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json - gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt - gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json - gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json - - gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.json path_models="../models-mnt/rerank-tiny" diff --git a/common/arg.cpp b/common/arg.cpp index 577048c201..d17645cf2f 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1615,18 +1615,14 @@ static void add_rpc_devices(const std::string & servers) { if (!rpc_reg) { throw std::invalid_argument("failed to find RPC backend"); } - typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); - ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); - if (!ggml_backend_rpc_add_device_fn) { - throw std::invalid_argument("failed to find RPC device add function"); + typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint); + ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server"); + if (!ggml_backend_rpc_add_server_fn) { + throw std::invalid_argument("failed to find RPC add server function"); } for (const auto & server : rpc_servers) { - ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); - if (dev) { - ggml_backend_device_register(dev); - } else { - throw std::invalid_argument("failed to register RPC device"); - } + auto reg = ggml_backend_rpc_add_server_fn(server.c_str()); + ggml_backend_register(reg); } } @@ -1939,6 +1935,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.n_ctx_checkpoints = value; } ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--cache-ram", "-cram"}, "N", + string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib), + [](common_params & params, int value) { + params.cache_ram_mib = value; + } + ).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--kv-unified", "-kvu"}, string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n" @@ -2588,6 +2592,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.no_extra_bufts = true; } ).set_env("LLAMA_ARG_NO_REPACK")); + add_opt(common_arg( + {"--no-host"}, + "bypass host buffer allowing extra buffers to be used", + [](common_params & params) { + params.no_host = true; + } + ).set_env("LLAMA_ARG_NO_HOST")); add_opt(common_arg( {"-ctk", "--cache-type-k"}, "TYPE", string_format( @@ -3429,7 +3440,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--reasoning-format"}, "FORMAT", "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" "- none: leaves thoughts unparsed in `message.content`\n" - "- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n" + "- deepseek: puts thoughts in `message.reasoning_content`\n" + "- deepseek-legacy: keeps `` tags in `message.content` while also populating `message.reasoning_content`\n" "(default: auto)", [](common_params & params, const std::string & value) { params.reasoning_format = common_reasoning_format_from_name(value); @@ -3856,7 +3868,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF"; params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf"; - params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; params.verbose_prompt = true; @@ -3870,7 +3881,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF"; params.model.hf_file = "e5-small-v2-q8_0.gguf"; - params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; params.verbose_prompt = true; @@ -3884,7 +3894,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF"; params.model.hf_file = "gte-small-q8_0.gguf"; - params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; params.verbose_prompt = true; diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index b3362519a6..7365782e7d 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -3,9 +3,12 @@ #include "log.h" #include "regex-partial.h" +#include +#include #include #include #include +#include #include using json = nlohmann::ordered_json; @@ -166,6 +169,27 @@ void common_chat_msg_parser::consume_literal(const std::string & literal) { } bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) { + std::string pending_reasoning_prefix; + + if (syntax_.reasoning_format == COMMON_REASONING_FORMAT_NONE) { + return false; + } + + auto set_reasoning_prefix = [&](size_t prefix_pos) { + if (!syntax_.thinking_forced_open || syntax_.reasoning_in_content) { + return; + } + if (prefix_pos + start_think.size() > input_.size()) { + pending_reasoning_prefix.clear(); + return; + } + // Capture the exact literal that opened the reasoning section so we can + // surface it back to callers. This ensures formats that force the + // reasoning tag open (e.g. DeepSeek R1) retain their original prefix + // instead of dropping it during parsing. + pending_reasoning_prefix = input_.substr(prefix_pos, start_think.size()); + }; + auto handle_reasoning = [&](const std::string & reasoning, bool closed) { auto stripped_reasoning = string_strip(reasoning); if (stripped_reasoning.empty()) { @@ -178,28 +202,116 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : end_think); } } else { + if (!pending_reasoning_prefix.empty()) { + add_reasoning_content(pending_reasoning_prefix); + pending_reasoning_prefix.clear(); + } add_reasoning_content(stripped_reasoning); } }; - if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) { - if (syntax_.thinking_forced_open || try_consume_literal(start_think)) { - if (auto res = try_find_literal(end_think)) { - handle_reasoning(res->prelude, /* closed */ true); - consume_spaces(); - return true; - } - auto rest = consume_rest(); + + const size_t saved_pos = pos_; + const size_t saved_content_size = result_.content.size(); + const size_t saved_reasoning_size = result_.reasoning_content.size(); + + auto restore_state = [&]() { + move_to(saved_pos); + result_.content.resize(saved_content_size); + result_.reasoning_content.resize(saved_reasoning_size); + }; + + // Allow leading whitespace to be preserved as content when reasoning is present at the start + size_t cursor = pos_; + size_t whitespace_end = cursor; + while (whitespace_end < input_.size() && std::isspace(static_cast(input_[whitespace_end]))) { + ++whitespace_end; + } + + if (whitespace_end >= input_.size()) { + restore_state(); + if (syntax_.thinking_forced_open) { + auto rest = input_.substr(saved_pos); if (!rest.empty()) { handle_reasoning(rest, /* closed */ !is_partial()); } - // Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877) - // if (!syntax_.thinking_forced_open) { - // throw common_chat_msg_partial_exception(end_think); - // } + move_to(input_.size()); return true; } + return false; + } + + cursor = whitespace_end; + const size_t remaining = input_.size() - cursor; + const size_t start_prefix = std::min(start_think.size(), remaining); + const bool has_start_tag = input_.compare(cursor, start_prefix, start_think, 0, start_prefix) == 0; + + if (has_start_tag && start_prefix < start_think.size()) { + move_to(input_.size()); + return true; + } + + if (has_start_tag) { + if (whitespace_end > pos_) { + add_content(input_.substr(pos_, whitespace_end - pos_)); + } + set_reasoning_prefix(cursor); + cursor += start_think.size(); + } else if (syntax_.thinking_forced_open) { + cursor = whitespace_end; + } else { + restore_state(); + return false; + } + while (true) { + if (cursor >= input_.size()) { + move_to(input_.size()); + return true; + } + + size_t end_pos = input_.find(end_think, cursor); + if (end_pos == std::string::npos) { + std::string_view remaining_view(input_.data() + cursor, input_.size() - cursor); + size_t partial_off = string_find_partial_stop(remaining_view, end_think); + size_t reasoning_end = partial_off == std::string::npos ? input_.size() : cursor + partial_off; + if (reasoning_end > cursor) { + handle_reasoning(input_.substr(cursor, reasoning_end - cursor), /* closed */ partial_off == std::string::npos && !is_partial()); + } + move_to(input_.size()); + return true; + } + + if (end_pos > cursor) { + handle_reasoning(input_.substr(cursor, end_pos - cursor), /* closed */ true); + } else { + handle_reasoning("", /* closed */ true); + } + + cursor = end_pos + end_think.size(); + + while (cursor < input_.size() && std::isspace(static_cast(input_[cursor]))) { + ++cursor; + } + + const size_t next_remaining = input_.size() - cursor; + if (next_remaining == 0) { + move_to(cursor); + return true; + } + + const size_t next_prefix = std::min(start_think.size(), next_remaining); + if (input_.compare(cursor, next_prefix, start_think, 0, next_prefix) == 0) { + if (next_prefix < start_think.size()) { + move_to(input_.size()); + return true; + } + set_reasoning_prefix(cursor); + cursor += start_think.size(); + continue; + } + + move_to(cursor); + return true; } - return false; } std::string common_chat_msg_parser::consume_rest() { diff --git a/common/chat.cpp b/common/chat.cpp index afbb2a2bdd..8587140e1f 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1408,6 +1408,8 @@ static common_chat_params common_chat_params_init_apertus(const common_chat_temp return data; } static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { builder.add_content(builder.consume_rest()); return; @@ -2862,6 +2864,7 @@ common_chat_params common_chat_templates_apply( } static void common_chat_parse_content_only(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); builder.add_content(builder.consume_rest()); } diff --git a/common/chat.h b/common/chat.h index a1afe574bd..f7b36ec711 100644 --- a/common/chat.h +++ b/common/chat.h @@ -33,8 +33,8 @@ struct common_chat_msg_content_part { struct common_chat_msg { std::string role; std::string content; - std::vector content_parts = {}; - std::vector tool_calls = {}; + std::vector content_parts; + std::vector tool_calls; std::string reasoning_content; std::string tool_name; std::string tool_call_id; @@ -44,7 +44,7 @@ struct common_chat_msg { bool empty() const { return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); } - void ensure_tool_call_ids_set(std::vector & ids_cache, const std::function & gen_tool_call_id) { + void set_tool_call_ids(std::vector & ids_cache, const std::function & gen_tool_call_id) { for (auto i = 0u; i < tool_calls.size(); i++) { if (ids_cache.size() <= i) { auto id = tool_calls[i].id; diff --git a/common/common.cpp b/common/common.cpp index c1e736c44c..b0591e84b0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1133,6 +1133,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.use_extra_bufts = !params.no_extra_bufts; + mparams.no_host = params.no_host; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; diff --git a/common/common.h b/common/common.h index d33788bd10..040a44ebd8 100644 --- a/common/common.h +++ b/common/common.h @@ -378,7 +378,7 @@ struct common_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool no_perf = false; // disable performance metrics - bool ctx_shift = false; // context shift on infinite text generation + bool ctx_shift = false; // context shift on infinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache @@ -392,6 +392,7 @@ struct common_params { bool check_tensors = false; // validate tensor data bool no_op_offload = false; // globally disable offload host tensor operations to device bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking) + bool no_host = false; // bypass host buffer allowing extra buffers to be used bool single_turn = false; // single turn chat conversation @@ -424,7 +425,8 @@ struct common_params { int32_t timeout_write = timeout_read; // http write timeout in seconds int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting - int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot + int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot + int32_t cache_ram_mib = 8192; // 0 = no limit, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT @@ -432,7 +434,7 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int reasoning_budget = -1; bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2fc04173aa..8c5132193e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -93,13 +93,15 @@ class ModelBase: # Mistral format specifics is_mistral_format: bool = False disable_mistral_community_chat_template: bool = False + sentence_transformers_dense_modules: bool = False def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None, - disable_mistral_community_chat_template: bool = False): + disable_mistral_community_chat_template: bool = False, + sentence_transformers_dense_modules: bool = False): if type(self) is ModelBase or \ type(self) is TextModel or \ type(self) is MmprojModel: @@ -114,6 +116,7 @@ class ModelBase: self.lazy = not eager or (remote_hf_model_id is not None) self.dry_run = dry_run self.remote_hf_model_id = remote_hf_model_id + self.sentence_transformers_dense_modules = sentence_transformers_dense_modules if remote_hf_model_id is not None: self.is_safetensors = True @@ -891,6 +894,9 @@ class TextModel(ModelBase): if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206": # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base res = "llada-moe" + if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e": + # ref: https://huggingface.co/ibm-granite/granite-docling-258M + res = "granite-docling" if res is None: logger.warning("\n") @@ -1325,6 +1331,7 @@ class MmprojModel(ModelBase): self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) # load preprocessor config + self.preprocessor_config = {} if not self.is_mistral_format: with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: self.preprocessor_config = json.load(f) @@ -1347,7 +1354,8 @@ class MmprojModel(ModelBase): self.gguf_writer.add_vision_projection_dim(self.n_embd_text) # vision config - self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"])) + self.image_size = self.find_vparam(["image_size"]) + self.gguf_writer.add_vision_image_size(self.image_size) self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) @@ -2378,6 +2386,10 @@ class SmolVLMModel(MmprojModel): self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2)) self.gguf_writer.add_vision_use_gelu(True) + # Add the preprocessor longest edge size + preproc_image_size = self.preprocessor_config.get("size", {}).get("longest_edge", self.image_size) + self.gguf_writer.add_vision_preproc_image_size(preproc_image_size) + def tensor_force_quant(self, name, new_name, bid, n_dims): if ".embeddings." in name: return gguf.GGMLQuantizationType.F32 @@ -5260,6 +5272,53 @@ class Gemma3Model(TextModel): @ModelBase.register("Gemma3TextModel") class EmbeddingGemma(Gemma3Model): model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING + module_paths = [] + dense_features_dims = {} + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.sentence_transformers_dense_modules: + # read modules.json to determine if model has Dense layers + modules_file = self.dir_model / "modules.json" + if modules_file.is_file(): + with open(modules_file, encoding="utf-8") as modules_json_file: + mods = json.load(modules_json_file) + for mod in mods: + if mod["type"] == "sentence_transformers.models.Dense": + mod_path = mod["path"] + # check if model.safetensors file for Dense layer exists + model_tensors_file = self.dir_model / mod_path / "model.safetensors" + if model_tensors_file.is_file(): + self.module_paths.append(mod_path) + # read config.json of the Dense layer to get in/out features + mod_conf_file = self.dir_model / mod_path / "config.json" + if mod_conf_file.is_file(): + with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file: + mod_conf = json.load(mod_conf_json_file) + # hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights + prefix = self._get_dense_prefix(mod_path) + if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None: + self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"]) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + from safetensors.torch import load_file + module_paths = list(self.module_paths) + for i, module_path in enumerate(module_paths): + tensors_file = self.dir_model / module_path / "model.safetensors" + local_tensors = load_file(tensors_file) + tensor_name = self._get_dense_prefix(module_path) + for name, local_tensor in local_tensors.items(): + if not name.endswith(".weight"): + continue + orig_name = name.replace("linear", tensor_name) + name = self.map_tensor_name(orig_name) + yield name, local_tensor.clone() + + @staticmethod + def _get_dense_prefix(module_path) -> str: + """Get the tensor name prefix for the Dense layer from module path.""" + tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3" + return tensor_name def set_gguf_parameters(self): super().set_gguf_parameters() @@ -5276,6 +5335,10 @@ class EmbeddingGemma(Gemma3Model): logger.info(f"Using original sliding_window from config: {orig_sliding_window} " f"instead of {self.hparams['sliding_window']}") self.gguf_writer.add_sliding_window(orig_sliding_window) + if self.sentence_transformers_dense_modules: + for dense, dims in self.dense_features_dims.items(): + logger.info(f"Setting dense layer {dense} in/out features to {dims}") + self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1]) self._try_set_pooling_type() @@ -5903,20 +5966,12 @@ class Mamba2Model(TextModel): class JambaModel(TextModel): model_arch = gguf.MODEL_ARCH.JAMBA - def get_vocab_base_pre(self, tokenizer) -> str: - del tokenizer # unused - - return "gpt-2" - def set_vocab(self): if (self.dir_model / "tokenizer.model").is_file(): - # Using Jamba's tokenizer.json causes errors on model load - # (something about "byte not found in vocab"), - # but there's a working tokenizer.model self._set_vocab_sentencepiece() else: - # Some Jamba models only have a tokenizer.json, which works. - self._set_vocab_gpt2() + self._set_vocab_llama_hf() + self.gguf_writer.add_add_space_prefix(False) def set_gguf_parameters(self): d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) @@ -8827,6 +8882,75 @@ class LFM2Model(TextModel): return [(self.map_tensor_name(name), data_torch)] +@ModelBase.register("Lfm2MoeForCausalLM") +class LFM2MoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.LFM2MOE + + def set_gguf_parameters(self): + # set num_key_value_heads only for attention layers + self.hparams["num_key_value_heads"] = [ + self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0 + for layer_type in self.hparams["layer_types"] + ] + + super().set_gguf_parameters() + + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) + self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"]) + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"]) + + # cache for experts weights for merging + _experts_cache: dict[int, dict[str, Tensor]] = {} + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # conv op requires 2d tensor + if 'conv.conv' in name: + data_torch = data_torch.squeeze(1) + + if name.endswith(".expert_bias"): + name = name.replace(".expert_bias", ".expert_bias.bias") + + # merge expert weights + if 'experts' in name: + n_experts = self.hparams["num_experts"] + assert bid is not None + + expert_cache = self._experts_cache.setdefault(bid, {}) + expert_cache[name] = data_torch + expert_weights = ["w1", "w2", "w3"] + + # not enough expert weights to merge + if len(expert_cache) < n_experts * len(expert_weights): + return [] + + tensors: list[tuple[str, Tensor]] = [] + for w_name in expert_weights: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight" + datas.append(expert_cache[ename]) + del expert_cache[ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + del self._experts_cache[bid] + return tensors + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + assert not self._experts_cache + + @ModelBase.register("Lfm2VlForConditionalGeneration") class LFM2VLModel(MmprojModel): def __init__(self, *args, **kwargs): @@ -9257,6 +9381,13 @@ def parse_args() -> argparse.Namespace: ) ) + parser.add_argument( + "--sentence-transformers-dense-modules", action="store_true", + help=("Whether to include sentence-transformers dense modules." + "It can be used for sentence-transformers models, like google/embeddinggemma-300m" + "Default these modules are not included.") + ) + args = parser.parse_args() if not args.print_supported_models and args.model is None: parser.error("the following arguments are required: model") @@ -9319,9 +9450,13 @@ def main() -> None: if args.remote: hf_repo_id = args.model from huggingface_hub import snapshot_download + allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"] + if args.sentence_transformers_dense_modules: + # include sentence-transformers dense modules safetensors files + allowed_patterns.append("*.safetensors") local_dir = snapshot_download( repo_id=hf_repo_id, - allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) + allow_patterns=allowed_patterns) dir_model = Path(local_dir) logger.info(f"Downloaded config and tokenizer to {local_dir}") else: @@ -9389,7 +9524,8 @@ def main() -> None: split_max_tensors=args.split_max_tensors, split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, small_first_shard=args.no_tensor_first_split, - remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template + remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template, + sentence_transformers_dense_modules=args.sentence_transformers_dense_modules ) if args.vocab_only: diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 21bb4a9f3e..28002f766e 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -140,6 +140,7 @@ models = [ {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", }, {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, {"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", }, + {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile index f0867cfe46..25b0514b29 100644 --- a/examples/model-conversion/Makefile +++ b/examples/model-conversion/Makefile @@ -116,20 +116,39 @@ embedding-convert-model: METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ ./scripts/embedding/convert-model.sh +embedding-convert-model-st: + $(call validate_embedding_model_path,embedding-convert-model-st) + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/embedding/convert-model.sh -st + embedding-run-original-model: $(call validate_embedding_model_path,embedding-run-original-model) @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \ + USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \ ./scripts/embedding/run-original-model.py \ - $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \ + $(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers) + +embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1 +embedding-run-original-model-st: embedding-run-original-model embedding-run-converted-model: @./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \ - $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \ + $(if $(USE_POOLING),--pooling) + +embedding-run-converted-model-st: USE_POOLING=1 +embedding-run-converted-model-st: embedding-run-converted-model embedding-verify-logits: embedding-run-original-model embedding-run-converted-model @./scripts/embedding/compare-embeddings-logits.sh \ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") +embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st + @./scripts/embedding/compare-embeddings-logits.sh \ + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") + embedding-inspect-original-model: $(call validate_embedding_model_path,embedding-inspect-original-model) @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md index e95e05cd37..05d95d588b 100644 --- a/examples/model-conversion/README.md +++ b/examples/model-conversion/README.md @@ -189,6 +189,23 @@ This command will save two files to the `data` directory, one is a binary file containing logits which will be used for comparison with the converted model, and the other is a text file which allows for manual visual inspection. +#### Using SentenceTransformer with numbered layers +For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense, +03_Dense, 04_Normalize), use the `-st` targets to apply all these layers: + +```console +# Run original model with SentenceTransformer (applies all numbered layers) +(venv) $ make embedding-run-original-model-st + +# Run converted model with pooling enabled +(venv) $ make embedding-run-converted-model-st +``` + +This will use the SentenceTransformer library to load and run the model, which +automatically applies all the numbered layers in the correct order. This is +particularly useful when comparing with models that should include these +additional transformation layers beyond just the base model output. + ### Model conversion After updates have been made to [gguf-py](../../gguf-py) to add support for the new model the model can be converted to GGUF format using the following command: @@ -208,6 +225,13 @@ was done manually in the previous steps) and compare the logits: (venv) $ make embedding-verify-logits ``` +For models with SentenceTransformer layers, use the `-st` verification target: +```console +(venv) $ make embedding-verify-logits-st +``` +This convenience target automatically runs both the original model with SentenceTransformer +and the converted model with pooling enabled, then compares the results. + ### llama-server verification To verify that the converted model works with llama-server, the following command can be used: diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp index 6dc334189f..bbd095e603 100644 --- a/examples/model-conversion/logits.cpp +++ b/examples/model-conversion/logits.cpp @@ -1,4 +1,7 @@ #include "llama.h" +#include "common.h" + + #include #include #include @@ -8,7 +11,10 @@ static void print_usage(int, char ** argv) { printf("\nexample usage:\n"); - printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]); + printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm ] [prompt]\n", argv[0]); + printf("\n"); + printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n"); + printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n"); printf("\n"); } @@ -17,6 +23,8 @@ int main(int argc, char ** argv) { std::string prompt = "Hello, my name is"; int ngl = 0; bool embedding_mode = false; + bool pooling_enabled = false; + int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) { int i = 1; @@ -41,9 +49,13 @@ int main(int argc, char ** argv) { return 1; } } else if (strcmp(argv[i], "-embd-mode") == 0) { + embedding_mode = true; + } else if (strcmp(argv[i], "-pooling") == 0) { + pooling_enabled = true; + } else if (strcmp(argv[i], "-embd-norm") == 0) { if (i + 1 < argc) { try { - embedding_mode = true; + embd_norm = std::stoi(argv[++i]); } catch (...) { print_usage(argc, argv); return 1; @@ -112,7 +124,7 @@ int main(int argc, char ** argv) { ctx_params.no_perf = false; if (embedding_mode) { ctx_params.embeddings = true; - ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE; + ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE; ctx_params.n_ubatch = ctx_params.n_batch; } @@ -143,17 +155,27 @@ int main(int argc, char ** argv) { return 1; } - float * logits; - int n_logits; + float * data_ptr; + int data_size; const char * type; + std::vector embd_out; if (embedding_mode) { - logits = llama_get_embeddings(ctx); - n_logits = llama_model_n_embd(model) * batch.n_tokens; + const int n_embd = llama_model_n_embd(model); + const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens; + const int n_embeddings = n_embd * n_embd_count; + float * embeddings; type = "-embeddings"; - const int n_embd = llama_model_n_embd(model); - const int n_embd_count = batch.n_tokens; + if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) { + embeddings = llama_get_embeddings_seq(ctx, 0); + embd_out.resize(n_embeddings); + printf("Normalizing embeddings using norm: %d\n", embd_norm); + common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm); + embeddings = embd_out.data(); + } else { + embeddings = llama_get_embeddings(ctx); + } printf("Embedding dimension: %d\n", n_embd); printf("\n"); @@ -164,7 +186,7 @@ int main(int argc, char ** argv) { // Print first 3 values for (int i = 0; i < 3 && i < n_embd; i++) { - printf("%9.6f ", logits[j * n_embd + i]); + printf("%9.6f ", embeddings[j * n_embd + i]); } printf(" ... "); @@ -172,7 +194,7 @@ int main(int argc, char ** argv) { // Print last 3 values for (int i = n_embd - 3; i < n_embd; i++) { if (i >= 0) { - printf("%9.6f ", logits[j * n_embd + i]); + printf("%9.6f ", embeddings[j * n_embd + i]); } } @@ -180,27 +202,33 @@ int main(int argc, char ** argv) { } printf("\n"); - printf("Embeddings size: %d\n", n_logits); + printf("Embeddings size: %d\n", n_embeddings); + + data_ptr = embeddings; + data_size = n_embeddings; } else { - logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - n_logits = llama_vocab_n_tokens(vocab); + float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + const int n_logits = llama_vocab_n_tokens(vocab); type = ""; printf("Vocab size: %d\n", n_logits); + + data_ptr = logits; + data_size = n_logits; } std::filesystem::create_directory("data"); - // Save logits to binary file + // Save data to binary file char bin_filename[512]; snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type); - printf("Saving logits to %s\n", bin_filename); + printf("Saving data to %s\n", bin_filename); FILE * f = fopen(bin_filename, "wb"); if (f == NULL) { fprintf(stderr, "%s: error: failed to open binary output file\n", __func__); return 1; } - fwrite(logits, sizeof(float), n_logits, f); + fwrite(data_ptr, sizeof(float), data_size, f); fclose(f); // Also save as text for debugging @@ -211,27 +239,27 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: error: failed to open text output file\n", __func__); return 1; } - for (int i = 0; i < n_logits; i++) { - fprintf(f, "%d: %.6f\n", i, logits[i]); + for (int i = 0; i < data_size; i++) { + fprintf(f, "%d: %.6f\n", i, data_ptr[i]); } fclose(f); if (!embedding_mode) { printf("First 10 logits: "); - for (int i = 0; i < 10 && i < n_logits; i++) { - printf("%.6f ", logits[i]); + for (int i = 0; i < 10 && i < data_size; i++) { + printf("%.6f ", data_ptr[i]); } printf("\n"); printf("Last 10 logits: "); - for (int i = n_logits - 10; i < n_logits; i++) { - if (i >= 0) printf("%.6f ", logits[i]); + for (int i = data_size - 10; i < data_size; i++) { + if (i >= 0) printf("%.6f ", data_ptr[i]); } printf("\n\n"); } - printf("Logits saved to %s\n", bin_filename); - printf("Logits saved to %s\n", txt_filename); + printf("Data saved to %s\n", bin_filename); + printf("Data saved to %s\n", txt_filename); llama_free(ctx); llama_model_free(model); diff --git a/examples/model-conversion/requirements.txt b/examples/model-conversion/requirements.txt index ac9f69e10b..229b2ec75b 100644 --- a/examples/model-conversion/requirements.txt +++ b/examples/model-conversion/requirements.txt @@ -4,3 +4,4 @@ torchvision transformers huggingface-hub accelerate +sentence-transformers diff --git a/examples/model-conversion/scripts/embedding/convert-model.sh b/examples/model-conversion/scripts/embedding/convert-model.sh index 0929e42413..9926350c07 100755 --- a/examples/model-conversion/scripts/embedding/convert-model.sh +++ b/examples/model-conversion/scripts/embedding/convert-model.sh @@ -2,6 +2,21 @@ set -e +# Parse command line arguments +SENTENCE_TRANSFORMERS="" +while [[ $# -gt 0 ]]; do + case $1 in + -st|--sentence-transformers) + SENTENCE_TRANSFORMERS="--sentence-transformers-dense-modules" + shift + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + MODEL_NAME="${MODEL_NAME:-$(basename "$EMBEDDING_MODEL_PATH")}" OUTPUT_DIR="${OUTPUT_DIR:-../../models}" TYPE="${OUTTYPE:-f16}" @@ -15,7 +30,8 @@ echo "Converted model path:: ${CONVERTED_MODEL}" python ../../convert_hf_to_gguf.py --verbose \ ${EMBEDDING_MODEL_PATH} \ --outfile ${CONVERTED_MODEL} \ - --outtype ${TYPE} + --outtype ${TYPE} \ + ${SENTENCE_TRANSFORMERS} echo "" echo "The environment variable CONVERTED_EMBEDDING MODEL can be set to this path using:" diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh index f3e2676632..0f490e6c3b 100755 --- a/examples/model-conversion/scripts/embedding/run-converted-model.sh +++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -5,6 +5,7 @@ set -e # Parse command line arguments CONVERTED_MODEL="" PROMPTS_FILE="" +USE_POOLING="" while [[ $# -gt 0 ]]; do case $1 in @@ -12,6 +13,10 @@ while [[ $# -gt 0 ]]; do PROMPTS_FILE="$2" shift 2 ;; + --pooling) + USE_POOLING="1" + shift + ;; *) if [ -z "$CONVERTED_MODEL" ]; then CONVERTED_MODEL="$1" @@ -47,4 +52,8 @@ echo $CONVERTED_MODEL cmake --build ../../build --target llama-logits -j8 # TODO: update logits.cpp to accept a --file/-f option for the prompt -../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT" +if [ -n "$USE_POOLING" ]; then + ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT" +else + ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT" +fi diff --git a/examples/model-conversion/scripts/embedding/run-original-model.py b/examples/model-conversion/scripts/embedding/run-original-model.py index 4a3e162413..640e200a97 100755 --- a/examples/model-conversion/scripts/embedding/run-original-model.py +++ b/examples/model-conversion/scripts/embedding/run-original-model.py @@ -14,6 +14,8 @@ unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') parser = argparse.ArgumentParser(description='Process model with specified path') parser.add_argument('--model-path', '-m', help='Path to the model') parser.add_argument('--prompts-file', '-p', help='Path to file containing prompts (one per line)') +parser.add_argument('--use-sentence-transformers', action='store_true', + help='Use SentenceTransformer to apply all numbered layers (01_Pooling, 02_Dense, 03_Dense, 04_Normalize)') args = parser.parse_args() def read_prompt_from_file(file_path): @@ -31,41 +33,52 @@ model_path = os.environ.get('EMBEDDING_MODEL_PATH', args.model_path) if model_path is None: parser.error("Model path must be specified either via --model-path argument or EMBEDDING_MODEL_PATH environment variable") -tokenizer = AutoTokenizer.from_pretrained(model_path) +# Determine if we should use SentenceTransformer +use_sentence_transformers = args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes') -config = AutoConfig.from_pretrained(model_path) - -# This can be used to override the sliding window size for manual testing. This -# can be useful to verify the sliding window attention mask in the original model -# and compare it with the converted .gguf model. -if hasattr(config, 'sliding_window'): - original_sliding_window = config.sliding_window - #original_sliding_window = 6 - print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}") - -print(f"Using unreleased model: {unreleased_model_name}") -if unreleased_model_name: - model_name_lower = unreleased_model_name.lower() - unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" - class_name = f"{unreleased_model_name}Model" - print(f"Importing unreleased model module: {unreleased_module_path}") - - try: - model_class = getattr(importlib.import_module(unreleased_module_path), class_name) - model = model_class.from_pretrained(model_path, config=config) - except (ImportError, AttributeError) as e: - print(f"Failed to import or load model: {e}") - exit(1) +if use_sentence_transformers: + from sentence_transformers import SentenceTransformer + print("Using SentenceTransformer to apply all numbered layers") + model = SentenceTransformer(model_path) + tokenizer = model.tokenizer + config = model[0].auto_model.config # type: ignore else: - model = AutoModel.from_pretrained(model_path, config=config) -print(f"Model class: {type(model)}") -print(f"Model file: {type(model).__module__}") + tokenizer = AutoTokenizer.from_pretrained(model_path) + + config = AutoConfig.from_pretrained(model_path) + + # This can be used to override the sliding window size for manual testing. This + # can be useful to verify the sliding window attention mask in the original model + # and compare it with the converted .gguf model. + if hasattr(config, 'sliding_window'): + original_sliding_window = config.sliding_window + #original_sliding_window = 6 + print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}") + + print(f"Using unreleased model: {unreleased_model_name}") + if unreleased_model_name: + model_name_lower = unreleased_model_name.lower() + unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + class_name = f"{unreleased_model_name}Model" + print(f"Importing unreleased model module: {unreleased_module_path}") + + try: + model_class = getattr(importlib.import_module(unreleased_module_path), class_name) + model = model_class.from_pretrained(model_path, config=config) + except (ImportError, AttributeError) as e: + print(f"Failed to import or load model: {e}") + exit(1) + else: + model = AutoModel.from_pretrained(model_path, config=config) + print(f"Model class: {type(model)}") + print(f"Model file: {type(model).__module__}") # Verify the model is using the correct sliding window -if hasattr(model.config, 'sliding_window'): - print(f"Model's sliding_window: {model.config.sliding_window}") -else: - print("Model config does not have sliding_window attribute") +if not use_sentence_transformers: + if hasattr(model.config, 'sliding_window'): # type: ignore + print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore + else: + print("Model config does not have sliding_window attribute") model_name = os.path.basename(model_path) @@ -75,34 +88,56 @@ if args.prompts_file: else: texts = ["Hello world today"] -encoded = tokenizer( - texts, - padding=True, - truncation=True, - return_tensors="pt" -) - -tokens = encoded['input_ids'][0] -token_strings = tokenizer.convert_ids_to_tokens(tokens) -for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): - print(f"{token_id:6d} -> '{token_str}'") - with torch.no_grad(): - outputs = model(**encoded) - hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size] + if use_sentence_transformers: + embeddings = model.encode(texts, convert_to_numpy=True) + all_embeddings = embeddings # Shape: [batch_size, hidden_size] - # Extract embeddings for each token (matching LLAMA_POOLING_TYPE_NONE behavior) - all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size] + encoded = tokenizer( + texts, + padding=True, + truncation=True, + return_tensors="pt" + ) + tokens = encoded['input_ids'][0] + token_strings = tokenizer.convert_ids_to_tokens(tokens) + for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): + print(f"{token_id:6d} -> '{token_str}'") - print(f"Hidden states shape: {hidden_states.shape}") - print(f"All embeddings shape: {all_embeddings.shape}") - print(f"Embedding dimension: {all_embeddings.shape[1]}") + print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}") + print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore + else: + # Standard approach: use base model output only + encoded = tokenizer( + texts, + padding=True, + truncation=True, + return_tensors="pt" + ) - # Print embeddings exactly like embedding.cpp does for LLAMA_POOLING_TYPE_NONE - n_embd = all_embeddings.shape[1] - n_embd_count = all_embeddings.shape[0] + tokens = encoded['input_ids'][0] + token_strings = tokenizer.convert_ids_to_tokens(tokens) + for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): + print(f"{token_id:6d} -> '{token_str}'") - print() # Empty line to match C++ output + outputs = model(**encoded) + hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size] + + all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size] + + print(f"Hidden states shape: {hidden_states.shape}") + print(f"All embeddings shape: {all_embeddings.shape}") + print(f"Embedding dimension: {all_embeddings.shape[1]}") + + if len(all_embeddings.shape) == 1: + n_embd = all_embeddings.shape[0] # type: ignore + n_embd_count = 1 + all_embeddings = all_embeddings.reshape(1, -1) + else: + n_embd = all_embeddings.shape[1] # type: ignore + n_embd_count = all_embeddings.shape[0] # type: ignore + + print() for j in range(n_embd_count): embedding = all_embeddings[j] @@ -120,29 +155,23 @@ with torch.no_grad(): print() # New line - print() # Final empty line to match C++ output + print() data_dir = Path("data") data_dir.mkdir(exist_ok=True) bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin" txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt" - # Save all embeddings flattened (matching what embedding.cpp would save if it did) flattened_embeddings = all_embeddings.flatten() flattened_embeddings.astype(np.float32).tofile(bin_filename) with open(txt_filename, "w") as f: - f.write(f"# Model class: {model_name}\n") - f.write(f"# Tokens: {token_strings}\n") - f.write(f"# Shape: {all_embeddings.shape}\n") - f.write(f"# n_embd_count: {n_embd_count}, n_embd: {n_embd}\n\n") - + idx = 0 for j in range(n_embd_count): - f.write(f"# Token {j} ({token_strings[j]}):\n") - for i, value in enumerate(all_embeddings[j]): - f.write(f"{j}_{i}: {value:.6f}\n") - f.write("\n") - print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} tokens × {n_embd} dimensions)") + for value in all_embeddings[j]: + f.write(f"{idx}: {value:.6f}\n") + idx += 1 + print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)") print("") print(f"Saved bin embeddings to: {bin_filename}") print(f"Saved txt embeddings to: {txt_filename}") diff --git a/examples/model-conversion/scripts/utils/semantic_check.py b/examples/model-conversion/scripts/utils/semantic_check.py index 7fd417bcea..2ac8b6b7b4 100644 --- a/examples/model-conversion/scripts/utils/semantic_check.py +++ b/examples/model-conversion/scripts/utils/semantic_check.py @@ -35,7 +35,11 @@ def cosine_similarity(a, b=None): def load_embeddings_from_file(filename, n_tokens, n_embd): embeddings = np.fromfile(filename, dtype=np.float32) - return embeddings.reshape(n_tokens, n_embd) + # Check if this is pooled (single embedding) or per-token embeddings + if len(embeddings) == n_embd: + return embeddings.reshape(1, n_embd) + else: + return embeddings.reshape(n_tokens, n_embd) def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt): np.set_printoptions(suppress=True, precision=6) @@ -48,58 +52,83 @@ def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt): print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}") n_tokens = len(tokens) + is_pooled = python_emb.shape[0] == 1 - # 1. Direct embedding comparison - print(f"\n1. Raw Embedding Magnitude Comparison:") - # Check if the distance of each token embedding from the origin and compare - # if the vectors are on the same "sphere". This does not tell us about - # direction (meaning of the token embedding), just magnitude. - for i in range(n_tokens): - py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings - cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings + if is_pooled: + print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]") + + # 1. Direct embedding comparison for pooled embeddings + print(f"\n1. Raw Embedding Magnitude Comparison:") + py_mag = np.linalg.norm(python_emb[0]) + cpp_mag = np.linalg.norm(cpp_emb[0]) ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf') - print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}") + print(f" Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}") - # 2. Cosine similarity between tokens within each model - # Here we check the direction of token embeddings to see if the have the - # same meaning (similarity). This is done by calculating cosine similarity - # of a pair of token embeddings within each model. - print(f"\n2. Within-Model Token Similarities:") - print(" Python model:") - for i in range(n_tokens): - for j in range(i+1, n_tokens): - sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0] - print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") + # 2. Cross-model similarity for pooled embeddings + print(f"\n2. Cross-Model Pooled Embedding Similarity:") + sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0] + print(f" Cosine similarity: {sim:.6f}") - print(" llama.cpp model:") - for i in range(n_tokens): - for j in range(i+1, n_tokens): - sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0] - print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") + return { + 'cross_model_similarities': [sim], + 'similarity_matrix_diff': np.array([[0.0]]), + 'max_diff': 0.0, + 'mean_diff': 0.0, + 'rms_diff': 0.0 + } + else: + # Original per-token comparison logic + # 1. Direct embedding comparison + print(f"\n1. Raw Embedding Magnitude Comparison:") + # Check if the distance of each token embedding from the origin and compare + # if the vectors are on the same "sphere". This does not tell us about + # direction (meaning of the token embedding), just magnitude. + for i in range(n_tokens): + py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings + cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings + ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf') + print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}") - # 3. Cross-model similarity (same token position) - print(f"\n3. Cross-Model Same-Token Similarities:") - for i in range(n_tokens): - sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] - print(f" Token {i} ({tokens[i]}): {sim:.4f}") + # 2. Cosine similarity between tokens within each model + # Here we check the direction of token embeddings to see if the have the + # same meaning (similarity). This is done by calculating cosine similarity + # of a pair of token embeddings within each model. + print(f"\n2. Within-Model Token Similarities:") + print(" Python model:") + for i in range(n_tokens): + for j in range(i+1, n_tokens): + sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0] + print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") - # 4. Similarity matrix comparison - print(f"\n4. Similarity Matrix Differences:") - py_sim_matrix = cosine_similarity(python_emb) - cpp_sim_matrix = cosine_similarity(cpp_emb) - diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix) + print(" llama.cpp model:") + for i in range(n_tokens): + for j in range(i+1, n_tokens): + sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0] + print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}") - print(f" Max difference: {np.max(diff_matrix):.4f}") - print(f" Mean difference: {np.mean(diff_matrix):.4f}") - print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}") + # 3. Cross-model similarity (same token position) + print(f"\n3. Cross-Model Same-Token Similarities:") + for i in range(n_tokens): + sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] + print(f" Token {i} ({tokens[i]}): {sim:.4f}") - return { - 'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)], - 'similarity_matrix_diff': diff_matrix, - 'max_diff': np.max(diff_matrix), - 'mean_diff': np.mean(diff_matrix), - 'rms_diff': np.sqrt(np.mean(diff_matrix**2)) - } + # 4. Similarity matrix comparison + print(f"\n4. Similarity Matrix Differences:") + py_sim_matrix = cosine_similarity(python_emb) + cpp_sim_matrix = cosine_similarity(cpp_emb) + diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix) + + print(f" Max difference: {np.max(diff_matrix):.4f}") + print(f" Mean difference: {np.mean(diff_matrix):.4f}") + print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}") + + return { + 'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)], + 'similarity_matrix_diff': diff_matrix, + 'max_diff': np.max(diff_matrix), + 'mean_diff': np.mean(diff_matrix), + 'rms_diff': np.sqrt(np.mean(diff_matrix**2)) + } def read_prompt_from_file(file_path): try: diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6ce52ffc66..73032be68e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -222,6 +222,9 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) option(GGML_WEBGPU "ggml: use WebGPU" OFF) option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) +option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF) +option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) + option(GGML_ZDNN "ggml: use zDNN" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 62b6d65e51..f1b7407859 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -215,6 +215,8 @@ extern "C" { // Backend registry // + GGML_API void ggml_backend_register(ggml_backend_reg_t reg); + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); // Backend (reg) enumeration diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 1e67411276..72eff00273 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -7,26 +7,25 @@ extern "C" { #endif -#define RPC_PROTO_MAJOR_VERSION 2 +#define RPC_PROTO_MAJOR_VERSION 3 #define RPC_PROTO_MINOR_VERSION 0 #define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 // backend API -GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint); +GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device); GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device); -GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); +GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); -GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, - const char * cache_dir, - size_t free_mem, size_t total_mem); +GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, + size_t n_threads, size_t n_devices, + ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); - -GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint); +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); #ifdef __cplusplus } diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index c8f3d85964..892c23318a 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -145,6 +145,9 @@ endif() # which was introduced in POSIX.1-2008, forcing us to go higher if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") add_compile_definitions(_XOPEN_SOURCE=700) +elseif (CMAKE_SYSTEM_NAME MATCHES "AIX") + # Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default, + # in order to define _SC_PHYS_PAGES. else() add_compile_definitions(_XOPEN_SOURCE=600) endif() diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 07784d6f66..6792ba986e 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -209,9 +209,6 @@ extern "C" { void * context; }; - // Internal backend registry API - GGML_API void ggml_backend_register(ggml_backend_reg_t reg); - // Add backend dynamic loading support to the backend // Initialize the backend diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index b707b84359..debbcadc1e 100755 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -341,11 +341,18 @@ private: #ifdef USE_ACL_GRAPH struct ggml_graph_node_properties { + // dst tensor void * node_address; - ggml_op node_op; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; + + // src tensor void * src_address[GGML_MAX_SRC]; + int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; + + // op + ggml_op node_op; int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; }; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index b51b554e75..ad1adba6b3 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties( std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); for (int src = 0; src < GGML_MAX_SRC; ++src) { - prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr; + if (node->src[src]) { + prop.src_address[src] = node->src[src]->data; + std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]); + std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]); + } else { + prop.src_address[src] = nullptr; + std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0); + std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0); + } } memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS); @@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties( * @param graph_node_properties The stored properties of a CANN graph node. * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise. */ -static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { +static bool ggml_graph_node_has_matching_properties( + ggml_tensor * node, + ggml_graph_node_properties * graph_node_properties) { if (node->data != graph_node_properties->node_address && - node->op != GGML_OP_VIEW) { + node->op != GGML_OP_VIEW) { return false; } + if (node->op != graph_node_properties->node_op) { return false; } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != graph_node_properties->ne[i]) { return false; @@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return false; } } + for (int i = 0; i < GGML_MAX_SRC; i++) { - if (node->src[i] && - node->src[i]->data != graph_node_properties->src_address[i] && - node->op != GGML_OP_VIEW - ) { - return false; + if (node->src[i]) { + if (node->src[i]->data != graph_node_properties->src_address[i] && + node->op != GGML_OP_VIEW) { + return false; + } + + for (int d = 0; d < GGML_MAX_DIMS; d++) { + if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) { + return false; + } + if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) { + return false; + } + } + } else { + if (graph_node_properties->src_address[i] != nullptr) { + return false; + } } } - if (node->op == GGML_OP_SCALE && - memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { - return false; + + if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { + return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; } return true; } diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 867e158dca..895a571375 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous is_contiguous_2d(op->src[1]) && // src1 must be contiguous op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && + op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { // src1 must be host buffer diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index 7ba659124c..3eaa5e3f41 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -29,6 +29,108 @@ #define NELEMS(x) sizeof(x) / sizeof(*x) +template +static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) { + return Fn(a, b, c); +} + +template +static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) { + return Fn(a, b); +} + +template +static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl, + const void* lhs, const void* rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max) { + Fn(m, n, k, bl, lhs, rhs, static_cast(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max); +} + +template +static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/, + const void* lhs, const void* rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max) { + Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max); +} + +template +static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + return Fn(m, k, bl, mr, kr, sr); +} + +template +static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { + return Fn(m, k, mr, kr, sr); +} + +template +static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { + return Fn(m_idx, k, bl, mr, kr, sr); +} + +template +static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { + return Fn(m_idx, k, mr, kr, sr); +} + +template +static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { + Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast(lhs), lhs_stride, lhs_packed); +} + +template +static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { + Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); +} + +template +static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { + Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); +} + +template +static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) { + return Fn(n, k, nr, kr, bl); +} + +template +static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { + return Fn(n, k); +} + +template +static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) { + return Fn(k, nr, kr, bl); +} + +template +static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { + return Fn(k); +} + +template +static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, + size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/, + void* rhs_packed, size_t extra_bytes, const void* params) { + Fn(num_groups, n, k, nr, kr, sr, bl, + static_cast(rhs), + static_cast(bias), + rhs_packed, extra_bytes, + static_cast(params)); +} + +template +static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/, + size_t rhs_stride, const void* rhs, const void* bias, const void* scale, + void* rhs_packed, size_t extra_bytes, const void* params) { + Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params); +} + static const size_t INT4_PER_BYTE = 2; static const size_t INT4_BITS = 4; static const int Q4_0_ZERO_POINT = 8; @@ -122,17 +224,18 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, + /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* SME GEMV */ /* .kern_info = */ { @@ -142,23 +245,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, @@ -174,17 +278,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset_ex = */ &kernel_offs_fn2, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2, + /* .run_kernel_ex = */ &kernel_run_fn10, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, - /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset_ex = */ &lhs_offs_fn5, + /* .packed_size_ex = */ &lhs_ps_fn5, + /* .pack_func_ex = */ &lhs_pack_void_fn9, }, /* SME GEMV */ /* .kern_info = */ { @@ -194,23 +298,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, - /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset_ex = */ nullptr, + /* .get_rhs_packed_offset_ex = */ nullptr, + /* .run_kernel_ex = */ nullptr, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, - /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset_ex = */ &lhs_offs_fn5, + /* .packed_size_ex = */ &lhs_ps_fn5, + /* .pack_func_ex = */ &lhs_pack_void_fn9, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - /* .packed_stride = */ NULL, - /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - /* .to_float = */ NULL, + /* .packed_stride = */ nullptr, + /* .to_float = */ nullptr, + /* .packed_size_ex = */ &rhs_ps_fn2, + /* .packed_stride_ex = */ &rhs_stride_fn1, + /* .pack_func_ex = */ &rhs_pack_fn13, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, @@ -229,17 +334,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* DOTPROD GEMV */ /* .kern_info = */ { @@ -249,23 +354,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, @@ -283,17 +389,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* i8mm GEMV */ /* .kern_info = */ { @@ -303,23 +409,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, @@ -338,17 +445,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* i8mm GEMV */ /* .kern_info = */ { @@ -358,23 +465,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, @@ -392,17 +500,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* DOTPROD GEMV */ /* .kern_info = */ { @@ -412,23 +520,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, - /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, - /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, - /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, - /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, + /* .packed_size_ex = */ &rhs_ps_fn5, + /* .packed_stride_ex = */ &rhs_stride_fn4, + /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, @@ -443,6 +552,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c ggml_kleidiai_kernels * kernel = nullptr; if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { +#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && @@ -452,6 +562,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c break; } } +#endif } return kernel; @@ -460,12 +571,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { ggml_kleidiai_kernels * kernels = nullptr; +#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { kernels = &gemm_gemv_kernels[i]; break; } } +#endif return kernels; } diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index 2ad6ad6fd0..a84795a6b2 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -4,8 +4,6 @@ #pragma once -#include -#include #include "ggml.h" enum cpu_feature { @@ -15,6 +13,7 @@ enum cpu_feature { CPU_FEATURE_SVE = 4, CPU_FEATURE_SME = 8 }; + inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) { lhs = static_cast(lhs | rhs); return lhs; @@ -30,63 +29,52 @@ struct kernel_info { size_t (*get_nr)(void); size_t (*get_kr)(void); size_t (*get_sr)(void); - std::variant< - std::function, - std::function - > get_lhs_offset; - std::variant< - std::function, - std::function - > get_rhs_packed_offset; + size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); size_t (*get_dst_size)(size_t m, size_t n); - std::variant< - std::function, - std::function - > run_kernel; + + size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl); + + size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl); + + void (*run_kernel_ex)( + size_t m, size_t n, size_t k, size_t bl, + const void* lhs_packed, const void* rhs_packed, + void* dst, size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max); }; struct lhs_packing_info { size_t (*get_offset)(size_t m_idx, size_t lhs_stride); - std::variant< - std::function, - std::function - > get_packed_offset; - std::variant< - std::function, - std::function - > packed_size; - std::variant< - std::function, - std::function - > pack_func; + + size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + + size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); + + void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, + size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed); }; struct rhs_packing_info { - std::variant< - std::function, - std::function - > packed_size; size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl); - std::variant< - std::function, - std::function - > pack_func; - void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride, - size_t kr, size_t bl, size_t num_bytes_multiplier); + + void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, + size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl, + size_t num_bytes_multiplier); + + size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); + + size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl); + + void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, + size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params); }; struct ggml_kleidiai_kernels { - kernel_info gemm; + kernel_info gemm; lhs_packing_info gemm_lhs_info; - kernel_info gemv; + kernel_info gemv; lhs_packing_info gemv_lhs_info; rhs_packing_info rhs_info; diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 44691e5dfd..8b3df7d780 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #if defined(__linux__) #include #include @@ -87,40 +88,6 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { return tensor->ne[dim]; } -template -constexpr bool variant_any_invocable_impl(std::index_sequence) { - using V = std::remove_reference_t; - return (std::is_invocable_r_v< - Ret, - std::variant_alternative_t, - Args...> || ...); -} - -template -constexpr bool variant_any_invocable_v = - variant_any_invocable_impl( - std::make_index_sequence< - std::variant_size_v>>{}); - -template -static inline Ret variant_call(Variant && var, Args&&... args) { - static_assert(variant_any_invocable_v, Ret, Args...>, - "No alternative in Variant is invocable with the provided arguments and return type."); - - return std::visit( - [&](auto && f) -> Ret { - using F = std::decay_t; - if constexpr (std::is_invocable_r_v) { - return std::invoke(std::forward(f), std::forward(args)...); - } else { - GGML_ABORT("Invalid function type in variant_call"); - GGML_UNREACHABLE(); - } - }, - std::forward(var) - ); -} - namespace ggml::cpu::kleidiai { static size_t round_down(size_t x, size_t y) { @@ -145,7 +112,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { return false; } ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); - GGML_ASSERT(kernels); + if (!kernels) { + return false; + } bool is_gemv = op->src[1]->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; @@ -159,16 +128,18 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t sr = kernel->get_sr(); if (kernels->rhs_type == GGML_TYPE_Q4_0) { - size = variant_call(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr); + if (!lhs_info->packed_size_ex) return false; + size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr); } else if (kernels->rhs_type == GGML_TYPE_F16) { + if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false; const int64_t lhs_batch_size0 = op->src[1]->ne[2]; const int64_t rhs_batch_size0 = op->src[0]->ne[2]; const int64_t r = lhs_batch_size0 / rhs_batch_size0; - size = variant_call(lhs_info->packed_size, m * r, k, mr, kr, sr) + - variant_call(kernels->rhs_info.packed_size, n, k) + + size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) + + kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) + k * n * sizeof(float) + n * sizeof(float); } else { - GGML_ASSERT(false); + return false; } return true; @@ -196,12 +167,18 @@ class tensor_traits : public ggml::cpu::tensor_traits { GGML_TENSOR_BINARY_OP_LOCALS ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - GGML_ASSERT(kernels); + if (!kernels) { + return false; + } const bool is_gemv = src1->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); + if (!kernels->rhs_info.pack_func_ex || + !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) { + return false; + } const int nth = params->nth; const int ith = params->ith; @@ -228,10 +205,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t kr = (int64_t) kernel->get_kr(); const int64_t sr = (int64_t) kernel->get_sr(); - const size_t lhs_packed_size = variant_call(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); - const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, (size_t)n, (size_t)k); - const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float); - const size_t bias_size = (size_t)n * sizeof(float); + const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr); + const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0); + const size_t kxn_size = k * n * sizeof(float); + const size_t bias_size = n * sizeof(float); const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; GGML_ASSERT(wsize_required <= params->wsize); @@ -259,10 +236,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; // Base packed offset (aligned) and per-row stride in bytes - const size_t base_packed_off = variant_call( - lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); - const size_t next_block_off = variant_call( - lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr); + const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr); const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr; int64_t remaining = m_count; @@ -278,9 +253,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; void * dst_ptr = lhs_packed + dst_off; - variant_call(lhs_info->pack_func, - (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr, - /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr); + lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); cur += take; remaining -= take; @@ -296,10 +269,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { reinterpret_cast(rhs_batch_base), rhs_stride); - variant_call(kernels->rhs_info.pack_func, - /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr, - /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)), - rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr); + kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float), + rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); } ggml_barrier(params->threadpool); @@ -320,20 +291,15 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0; // LHS packed base at row 0 (consistent with packing above) - const size_t lhs_packed_offset0 = variant_call( - lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k); - const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride); + const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr); + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0); + const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride); const void * lhs_ptr = lhs_packed + lhs_packed_offset0; const void * rhs_ptr = rhs_packed + rhs_packed_offset; float * dst_ptr = reinterpret_cast(dst_batch_base + dst_offset); - variant_call(kernel->run_kernel, - (size_t)m, (size_t)n_to_process, (size_t)k, - lhs_ptr, rhs_ptr, - dst_ptr, dst_stride, sizeof(float), - -FLT_MAX, FLT_MAX); + kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); } } @@ -354,13 +320,19 @@ class tensor_traits : public ggml::cpu::tensor_traits { GGML_TENSOR_BINARY_OP_LOCALS ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - GGML_ASSERT(kernels); + if (!kernels) { + return false; + } bool is_gemv = src1->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); + if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex || + !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) { + return false; + } const int ith = params->ith; const int nth_raw = params->nth; @@ -402,25 +374,26 @@ class tensor_traits : public ggml::cpu::tensor_traits { // Transform LHS const size_t src_stride = src1->nb[1]; const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr); + const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr); void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); - variant_call(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer + lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); } ggml_barrier(params->threadpool); // Perform the operation const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k, QK4_0); + const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr); + const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0); const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); if (n_to_process > 0) { - variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, + kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); } @@ -429,7 +402,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0); - GGML_ASSERT(ctx.kernels); + if (!ctx.kernels) { + return false; + } const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -438,6 +413,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { rhs_packing_info * rhs_info = &ctx.kernels->rhs_info; kernel_info * kernel = &ctx.kernels->gemm; + if (!rhs_info->to_float || !kernel->get_nr) { + return false; + } const int64_t nc = ne00; const int64_t nr = ggml_nelements(src1); @@ -480,7 +458,7 @@ public: struct kai_rhs_pack_qs4cxs1s0_param params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; - variant_call(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms); + ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, ¶ms); return 0; GGML_UNUSED(data_size); @@ -548,7 +526,7 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_ const size_t nr = ctx.kernels->gemm.get_nr(); const size_t kr = ctx.kernels->gemm.get_kr(); - return variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); + return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0); GGML_UNUSED(buft); } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6275c8305a..1c43865ff6 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(eps >= 0.0f); - // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)x[i00]; - } - + float sum = 0.0; + ggml_vec_sum_f32(ne00, &sum, x); float mean = sum/ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + float variance = 0; - ggml_float sum2 = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (ggml_float)(v*v); - } +#ifdef GGML_USE_ACCELERATE + mean = -mean; + vDSP_vsadd(x, 1, &mean, y, 1, ne00); + vDSP_measqv(y, 1, &variance, ne00); +#else + variance = ggml_vec_cvar_f32(ne00, y, x, mean); +#endif //GGML_USE_ACCELERATE - float variance = sum2/ne00; const float scale = 1.0f/sqrtf(variance + eps); - ggml_vec_scale_f32(ne00, y, scale); } } @@ -8135,7 +8131,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } // V /= S - const float S_inv = 1.0f/S; + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); // dst indices diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 437192d525..b8e37052d3 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * } } +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) { + int i = 0; + ggml_float sum = 0; +// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE +// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344 +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i), + _mm512_set1_ps(mean)); + _mm512_storeu_ps(y + i, val); + sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val)); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i), + _mm256_set1_ps(mean)); + _mm256_storeu_ps(y + i, val); + val = _mm256_mul_ps(val,val); + __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), + _mm256_castps256_ps128(val)); + val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); + val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); + sum += (ggml_float)_mm_cvtss_f32(val2); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i), + _mm_set1_ps(mean)); + _mm_storeu_ps(y + i, val); + val = _mm_mul_ps(val, val); +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + val = _mm_add_ps(val, _mm_movehl_ps(val, val)); + val = _mm_add_ss(val, _mm_movehdup_ps(val)); +#else + __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); + val = _mm_add_ps(val, tmp); + tmp = _mm_movehl_ps(tmp, val); + val = _mm_add_ss(val, tmp); +#endif // __AVX__ || __AVX2__ || __AVX512F__ + sum += (ggml_float)_mm_cvtss_f32(val); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vsubq_f32(vld1q_f32(x + i), + vdupq_n_f32(mean)); + vst1q_f32(y + i, val); + val = vmulq_f32(val, val); + sum += (ggml_float)vaddvq_f32(val); + } +#elif defined(__VXE__) || defined(__VXE2__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean)); + vec_xst(val, 0, y + i); + val = vec_mul(val, val); + sum += (ggml_float)vec_hsum_f32x4(val); + } +#endif + for (; i < n; ++i) { + float val = x[i] - mean; + val *= val; + sum += (ggml_float)val; + y[i] = val; + } + return sum/n; +} + ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 341e64e64f..2751359ce4 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); void ggml_vec_silu_f32(const int n, float * y, const float * x); +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean ) ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); @@ -654,11 +655,11 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { } // leftovers // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only - if (np < n) { - svbool_t pg = svwhilelt_b32(np, n); - ay1 = svld1_f32(pg, y + np); + for (int i = np; i < n; i += ggml_f32_epr) { + svbool_t pg = svwhilelt_b32(i, n); + ay1 = svld1_f32(pg, y + i); ay1 = svmul_f32_m(pg, ay1, vx); - svst1_f32(pg, y + np, ay1); + svst1_f32(pg, y + i, ay1); } #elif defined(__riscv_v_intrinsic) for (int i = 0, avl; i < n; i += avl) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index d7736d3610..0c8e7b3e41 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const int cc = ggml_cuda_info().devices[device].cc; + // TODO: temporary until support is extended + // https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206 + if (K->ne[1] % FATTN_KQ_STRIDE != 0) { + return BEST_FATTN_KERNEL_NONE; + } + switch (K->ne[0]) { case 64: case 128: diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 26e72bbc2b..fb691528b7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -231,7 +231,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; - info.devices[id].integrated = prop.integrated; + info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034) info.devices[id].nsm = prop.multiProcessorCount; info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].warp_size = prop.warpSize; diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 819f31c8a3..e23abdda97 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar char base[256]; char name[256]; - snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + const char * suffix = ""; + + if (op->src[1]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); snprintf(name, 256, "%s", base); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); @@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar } ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + char base[256]; char name[256]; - if (op->src[3]->ne[0] == 1) { - snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type)); - } else { - snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); - } - snprintf(name, 256, "%s", base); + const int nsg = (ne00 + 31)/32; + + snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d", base, nsg); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); if (res) { @@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg); return res; } @@ -918,6 +924,96 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_pad"); + + snprintf(name, 256, "%s_mask=%d_ncpsg=%d", + base, + has_mask, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); + //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_blk"); + + snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d", + base, + nqptg, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); + ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const ggml_tensor * op, @@ -925,6 +1021,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -937,18 +1034,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + // do bounds checks for the mask? + const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0); + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", "flash_attn_ext", ggml_type_name(op->src[1]->type), dk, dv); - snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d", + snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d", base, has_mask, has_sinks, has_bias, has_scap, + has_kvpad, + bc_mask, ns10, ns20, nsg); @@ -964,6 +1066,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); + + ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); @@ -983,6 +1088,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg, int32_t nwg) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -1002,12 +1108,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( dk, dv); - snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", + snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", base, has_mask, has_sinks, has_bias, has_scap, + has_kvpad, ns10, ns20, nsg, nwg); @@ -1023,6 +1130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index f6ebf90a00..1034e4bbf6 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -135,6 +135,18 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg); + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const struct ggml_tensor * op, @@ -142,6 +154,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( @@ -151,6 +164,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg, int32_t nwg); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 523f9d71ba..9527973015 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; } case GGML_OP_GET_ROWS: - { - return op->ne[3] == 1; - } + return true; case GGML_OP_SET_ROWS: { if (op->src[0]->type != GGML_TYPE_F32) { diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 88c98423eb..c9dff87305 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -69,11 +69,20 @@ #define N_SG_IQ4_XS 2 // function constants offsets -#define FC_FLASH_ATTN_EXT 100 -#define FC_FLASH_ATTN_EXT_VEC 200 -#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300 -#define FC_MUL_MV 400 -#define FC_MUL_MM 500 +#define FC_FLASH_ATTN_EXT_PAD 100 +#define FC_FLASH_ATTN_EXT_BLK 200 +#define FC_FLASH_ATTN_EXT 300 +#define FC_FLASH_ATTN_EXT_VEC 400 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500 +#define FC_MUL_MV 600 +#define FC_MUL_MM 700 + +// op-specific constants +#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NCPSG 64 + +#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 // kernel argument structs // @@ -178,6 +187,7 @@ typedef struct { } ggml_metal_kargs_clamp; typedef struct { + int64_t nk0; int64_t ne00; int64_t ne01; int64_t ne02; @@ -243,6 +253,35 @@ typedef struct { int32_t sect_3; } ggml_metal_kargs_rope; +typedef struct { + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_pad; + +typedef struct { + int32_t ne01; + int32_t ne30; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_blk; + typedef struct { int32_t ne01; int32_t ne02; @@ -261,6 +300,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -295,6 +335,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -572,32 +613,45 @@ typedef struct { int64_t n_seq_tokens; int64_t n_seqs; uint64_t s_off; + uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + uint64_t nb10; uint64_t nb11; uint64_t nb12; + uint64_t ns12; uint64_t nb13; + uint64_t nb20; uint64_t nb21; + uint64_t ns21; uint64_t nb22; + int64_t ne30; uint64_t nb31; uint64_t nb41; uint64_t nb42; + uint64_t ns42; uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t ns52; uint64_t nb53; + uint64_t nb0; } ggml_metal_kargs_ssm_scan; typedef struct { - int64_t ne00; + int32_t ne00t; + int32_t ne00; uint64_t nb01; uint64_t nb02; - int64_t ne10; + uint64_t nb03; + int32_t ne10; uint64_t nb10; uint64_t nb11; + uint64_t nb12; uint64_t nb1; uint64_t nb2; + uint64_t nb3; } ggml_metal_kargs_get_rows; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e85a223c01..1137e21077 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb); GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne); GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb); + GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb); + GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb); GGML_TENSOR_LOCALS( int64_t, ne, node, ne); GGML_TENSOR_LOCALS(uint64_t, nb, node, nb); @@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ggml_is_contiguous(node->src[1]), node->src[1]->name); } + if (node->src[2]) { + GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23, + ggml_is_contiguous(node->src[2]), node->src[2]->name); + } + if (node->src[3]) { + GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33, + ggml_is_contiguous(node->src[3]), node->src[3]->name); + } if (node) { GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, node->name); @@ -577,6 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, @@ -906,23 +919,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); ggml_metal_kargs_get_rows args = { - /*.ne00 =*/ ne00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, + /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, }; + const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const int nw0 = (args.ne00t + nth - 1)/nth; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1); return 1; } @@ -1117,7 +1138,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); @@ -1172,25 +1193,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float), + /*.nb00 =*/ nb00, /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, /*.nb11 =*/ nb11, /*.nb12 =*/ nb12, + /*.ns12 =*/ nb12/nb10, /*.nb13 =*/ nb13, + /*.nb20 =*/ nb20, /*.nb21 =*/ nb21, + /*.ns21 =*/ nb21/nb20, /*.nb22 =*/ nb22, + /*.ne30 =*/ ne30, /*.nb31 =*/ nb31, /*.nb41 =*/ nb41, /*.nb42 =*/ nb42, + /*.ns42 =*/ nb42/nb40, /*.nb43 =*/ nb43, /*.nb51 =*/ nb51, /*.nb52 =*/ nb52, + /*.ns52 =*/ nb52/nb50, /*.nb53 =*/ nb53, + /*.nb0 =*/ nb0, }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const size_t sms = ggml_metal_pipeline_get_smem(pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline); @@ -1206,13 +1238,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); - if (ne30 == 1) { - // Mamba-2 - ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); - } else { - GGML_ASSERT(d_inner == 1); - ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1); - } + ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); return 1; } @@ -1273,26 +1299,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); - // TODO: support - //const int32_t nk00 = ne00/ggml_blck_size(op->type); - const int32_t nk00 = ne00; - - int nth = 32; // SIMD width - - while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { - nth *= 2; + int64_t nk0 = ne00; + if (ggml_is_quantized(op->src[0]->type)) { + nk0 = ne00/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne00/ggml_blck_size(op->type); } - nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; // TODO: relax this constraint in the future if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) { - if (nth > nk00) { - nrptg = (nth + nk00 - 1)/nk00; - nth = nk00; + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nrptg--; @@ -1300,10 +1323,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { } } - nth = std::min(nth, nk00); + nth = std::min(nth, nk0); ggml_metal_kargs_cpy args = { - /*.ne00 =*/ nk00, + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, /*.ne03 =*/ ne03, @@ -1321,12 +1345,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3, }; + const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); return 1; } @@ -1875,20 +1901,107 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) { return (ne01 < 20) && (ne00 % 32 == 0); } +size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; + + if (has_kvpad) { + res += OP_FLASH_ATTN_EXT_VEC_NCPSG*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } else { + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0; + + if (has_kvpad) { + res += OP_FLASH_ATTN_EXT_NCPSG*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } + + return res; +} + +size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (!has_mask) { + return res; + } + + const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op); + + // this optimization is not useful for the vector kernels + if (is_vec) { + return res; + } + + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; + const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; + + const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; + const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; + + res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32); + + return res; +} + size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); - const int64_t nwg = 32; + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); - const int64_t ne01 = op->src[0]->ne[1]; - const int64_t ne02 = op->src[0]->ne[2]; - const int64_t ne03 = op->src[0]->ne[3]; - const int64_t ne20 = op->src[2]->ne[0]; + size_t res = 0; - // temp buffer for writing the results from each workgroup - // - ne20: the size of the Value head - // - + 2: the S and M values for each intermediate result - return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const int64_t nwg = 32; + + // temp buffer for writing the results from each workgroup + // - ne20: the size of the Value head + // - + 2: the S and M values for each intermediate result + res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + } + + return res; } int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { @@ -1910,8 +2023,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, nb, op, nb); - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == op->src[2]->type); @@ -1921,8 +2033,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne12 == ne22); GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16); - GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] && + "the Flash-Attention Metal kernel requires the mask to be at least n_queries big"); float scale; float max_bias; @@ -1949,15 +2061,111 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne01 < 65536); + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0; + ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0; + + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_pad = bid_dst; + bid_pad.offs += ggml_nbytes(op); + + ggml_metal_buffer_id bid_blk = bid_pad; + bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + + ggml_metal_buffer_id bid_tmp = bid_blk; + bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op); + if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! + const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + bool need_sync = false; + + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + + if (has_mask) { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0); + + ggml_metal_kargs_flash_attn_ext_blk args0 = { + /*.ne01 =*/ ne01, + /*.ne30 =*/ ne30, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src3, 1); + ggml_metal_encoder_set_buffer (enc, bid_blk, 2); + + const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg); + const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg); + + ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0); + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; // 2*(2*ncpsg) @@ -2007,6 +2215,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, /*.ne32 =*/ ne32, /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, @@ -2023,24 +2232,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg); + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); + ggml_metal_encoder_set_buffer (enc, bid_pad, 6); + ggml_metal_encoder_set_buffer (enc, bid_blk, 7); + ggml_metal_encoder_set_buffer (enc, bid_dst, 8); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -2048,14 +2251,62 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { #undef FATTN_SMEM } else { // half4x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - const int64_t nkpsg = 1*ncpsg; + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! + const int nkpsg = 1*ncpsg; GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); GGML_ASSERT(ncpsg % 32 == 0); + bool need_sync = false; + + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask @@ -2120,6 +2371,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, /*.ne32 =*/ ne32, /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, @@ -2136,25 +2388,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg); + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); const size_t smem = FATTN_SMEM(nsg); @@ -2162,23 +2406,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); if (nwg == 1) { + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0); + // using 1 workgroup -> write the result directly into dst - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6); + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_dst, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); } else { // sanity checks + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); + GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31)); - ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); - // write the results from each workgroup into a temp buffer - ggml_metal_buffer_id bid_tmp = bid_dst; - bid_tmp.offs += ggml_nbytes(op); - ggml_metal_encoder_set_buffer(enc, bid_tmp, 6); + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 8df4c72e7c..d4cb944621 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -39,6 +39,8 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op); // return true if we should use the FA vector kernel for this op bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op); size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index e11555a78f..7afc881fa7 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -193,9 +193,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ } break; case GGML_OP_FLASH_ATTN_EXT: { - if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) { - res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); - } + res += ggml_metal_op_flash_attn_ext_extra_pad(tensor); + res += ggml_metal_op_flash_attn_ext_extra_blk(tensor); + res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); } break; default: break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 96df6f0ce6..45d91def88 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2032,7 +2032,38 @@ kernel void kernel_ssm_conv_f32_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part +kernel void kernel_ssm_conv_f32_f32_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; + + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -2044,219 +2075,88 @@ kernel void kernel_ssm_scan_f32( device const void * src6, device float * dst, threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + constexpr short NW = N_SIMDWIDTH; - const int64_t i0 = tpitg.x; - const int64_t i1 = 0; - const int64_t ir = tgpig.x; // current head - const int64_t i3 = tgpig.y; // current seq + shared[tpitg.x] = 0.0f; - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); + const int32_t i0 = tpitg.x; + const int32_t i1 = tgpig.x; + const int32_t ir = tgpig.y; // current head + const int32_t i3 = tgpig.z; // current seq - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; + const int32_t nc = args.d_state; + const int32_t nr = args.d_inner; + const int32_t nh = args.n_head; + const int32_t ng = args.n_group; + const int32_t n_t = args.n_seq_tokens; - const int64_t s_off = args.s_off; + const int32_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave + + const int32_t i = i0 + i1*nc; + const int32_t g = ir / (nh / ng); // repeat_interleave + float s0 = s0_buff[i]; - float s = s_buff[i]; + float s = 0.0f; - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh} - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} + const float A0 = A[i0%args.ne30]; - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; + device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns} + device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns} - const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - s = state; + device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns} - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - if (sgptg > 1) { - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. - threadgroup_barrier(mem_flags::mem_threadgroup); - - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); - if (tiisg == 0) { - y[0] = sumf; - } - } - } else if (tiisg == 0) { - y[0] = sumf; - } - - // recurse - s0 = s; - } - - // Assign the final state to the output buffer - s_buff[i] = s; -} - -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -kernel void kernel_ssm_scan_group_f32( - constant ggml_metal_kargs_ssm_scan & args, - device const void * src0, - device const void * src1, - device const void * src2, - device const void * src3, - device const void * src4, - device const void * src5, - device const void * src6, - device float * dst, - threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { - - const int64_t i0 = tpitg.x; - const int64_t i1 = tgpig.x; - const int64_t ir = tgpig.y; // current head - const int64_t i3 = tgpig.z; // current seq - - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); - - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; - - const int64_t s_off = args.s_off; - - device const int32_t * ids = (device const int32_t *) src6; - - device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave - float s0 = s0_buff[i]; - float s = s_buff[i]; - - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); - - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} - - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; - const float dA = exp(dt_soft_plus * A[0]); - - const float state = (s0 * dA) + (B[i0] * x_dt); - s = state; - - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. + for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float dt0 = dt[0]; + const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; + const float x_dt = x[0] * dtsp; + const float dA = exp(dtsp * A0); + + s = (s0 * dA) + (B[i0] * x_dt); + + const float sumf = simd_sum(s * C[i0]); + if (tiisg == 0) { - y[0] = sumf; + shared[t*NW + sgitg] = sumf; } + + // recurse + s0 = s; + + x += args.ns12; + dt += args.ns21; + B += args.ns42; + C += args.ns52; } - // recurse - s0 = s; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + + if (tiisg == 0 && i2 + sgitg < n_t) { + y[sgitg*nh*nr] = sumf; + } + + y += sgptg*nh*nr; } - // Assign the final state to the output buffer s_buff[i] = s; } @@ -4449,10 +4349,142 @@ kernel void kernel_leaky_relu_f32_4( dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); } +constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; + +constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; + +// pad the last chunk of C elements of k and v into a an extra pad buffer +kernel void kernel_flash_attn_ext_pad( + constant ggml_metal_kargs_flash_attn_ext_pad & args, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int32_t C = FC_flash_attn_ext_pad_ncpsg; + + device char * k_pad = dst; + device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3; + device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const int32_t icp = args.ne11 % C; + const int32_t ic0 = args.ne11 - icp; + + const int32_t i1 = tgpig[0]; + const int32_t i2 = tgpig[1]; + const int32_t i3 = tgpig[2]; + + if (i2 < args.ne_12_2 && i3 < args.ne_12_3) { + device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3; + device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3; + + device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3; + device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3; + + if (i1 >= icp) { + // here it is not important the exact value that will be used as we rely on masking out the scores in the attention + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = 0; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = 0; + } + } else { + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = k_src[i]; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = v_src[i]; + } + } + } + + if (FC_flash_attn_ext_pad_has_mask) { + if (i2 < args.ne32 && i3 < args.ne33) { + for (int ib = i1; ib < args.ne31; ib += C) { + device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0; + device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3; + + for (int i = tiitg; i < C; i += ntg.x) { + if (i >= icp) { + mask_dst[i] = -MAXHALF; + } else { + mask_dst[i] = mask_src[i]; + } + } + } + } + } +} + +constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]]; +constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]]; + +// scan the blocks of the mask that are not masked +// 0 - masked (i.e. full of -INF, skip) +// 1 - not masked (i.e. at least one element of the mask is not -INF) +kernel void kernel_flash_attn_ext_blk( + constant ggml_metal_kargs_flash_attn_ext_blk & args, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + // block size C x Q + const int32_t Q = FC_flash_attn_ext_blk_nqptg; + const int32_t C = FC_flash_attn_ext_blk_ncpsg; + + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig[2]/args.ne32; + const int32_t i2 = tgpig[2]%args.ne32; + const int32_t i1 = tgpig[1]; + const int32_t i0 = tgpig[0]; + + char res = i0*C + C > args.ne30 ? 1 : 0; + + device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; + + // fast route + if (res == 0) { + if (simd_max(*mask_src) > -MAXHALF/2) { + res = 1; + } + } + + // detailed check of the elements of the block + if ((C > NW || Q > 1) && res == 0) { + half m = -MAXHALF; + + FOR_UNROLL (short j = 0; j < Q; ++j) { + FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { + m = max(m, mask_src[ii*NW]); + } + + mask_src += args.nb31/2; + } + + if (simd_max(m) > -MAXHALF/2) { + res = 1; + } + } + + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne30 + C - 1)/C); + + if (tiisg == 0) { + dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res; + } +} + constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; +constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]]; + +constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; @@ -4499,6 +4531,8 @@ void kernel_flash_attn_ext_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16, uint3 tgpig, @@ -4564,6 +4598,13 @@ void kernel_flash_attn_ext_impl( pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); } + { + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne11 + C - 1)/C); + + blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0; + } + { q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; @@ -4623,16 +4664,75 @@ void kernel_flash_attn_ext_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = 0; ic < args.ne11; ic += C) { + for (int ic0 = 0; ; ++ic0) { + int ic = ic0*C; + if (ic >= args.ne11) { + break; + } + + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_has_mask) { + threadgroup half * sm = (threadgroup half *) (sm2); + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < C; i += NW) { + if (ic + i >= args.ne11) { + sm[2*j*SH + i] = -MAXHALF; + } + } + } + } else { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const half *) mask + + (iq1 + j)*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32)); + } + } + + ic = 0; + } + // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { + if (blk[ic0] == 0) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } + + continue; + } + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { const short j = jj*NSG + sgitg; - sm2[j*SH + tiisg] = pm2[jj][tiisg]; + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } + pm2[jj] += NW; } +#if 0 + // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks + threadgroup_barrier(mem_flags::mem_threadgroup); // used to detect blocks full of -INF @@ -4651,13 +4751,14 @@ void kernel_flash_attn_ext_impl( continue; } +#endif } // Q*K^T // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + device const k_t * pk = (device const k_t *) (k + ic*args.nb11); threadgroup const q_t * pq = sq; threadgroup s_t * ps = ss; @@ -4668,26 +4769,24 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // TODO: not good to unroll for large contexts - not sure why? + // note: do not unroll for large heads + #pragma unroll (DK <= 64 ? NC : 1) for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - if (DK8 % 16 != 0) { + if (DK % 16 != 0) { k8x8_t mk; q8x8_t mq; FOR_UNROLL (short i = 0; i < DK8; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, pk, NS10, 0, true); - simdgroup_load(mq, pq, DK); + simdgroup_load(mk, pk + 8*i, NS10, 0, true); + simdgroup_load(mq, pq + 8*i, DK); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - - pk += 8; - pq += 8; } } else { k8x8_t mk[2]; @@ -4696,26 +4795,22 @@ void kernel_flash_attn_ext_impl( FOR_UNROLL (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); - simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); + simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); + simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); - simdgroup_load(mq[0], pq + 0*8, DK); - simdgroup_load(mq[1], pq + 1*8, DK); + simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - - pk += 16; - pq += 16; } } simdgroup_store(mqk, ps, SH, 0, false); - pk += 8*(NSG*NS10 - DK8); - pq += 8*(NSG*0 - DK8); + pk += 8*(NSG*NS10); ps += 8*(NSG); } } else { @@ -4729,7 +4824,7 @@ void kernel_flash_attn_ext_impl( qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); if (DK16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -4849,27 +4944,50 @@ void kernel_flash_attn_ext_impl( } { - auto sst = ss; - - device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + device const v_t * pv = (device const v_t *) (v + ic*args.nb21); pv += 8*sgitg; - FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, sst, SH, 0, false); + if (DV <= 64) { + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); - FOR_UNROLL (short ii = 0; ii < NO; ++ii) { - v8x8_t mv; + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; - simdgroup_load(mv, pv, NS20, 0, false); - simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); - pv += 8*NSG; + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } + + pv += 8*NS20; } + } else { + FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + s8x8_t vs[2]; - pv += 8*(NS20 - NO*NSG); - sst += 8; + simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); + simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[4]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + } + + pv += 2*8*NS20; + } } } @@ -4893,7 +5011,7 @@ void kernel_flash_attn_ext_impl( simdgroup_load(vs, ss + 8*cc, SH, 0, false); for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -4983,7 +5101,7 @@ void kernel_flash_attn_ext_impl( device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - const float scale = 1.0f/S[jj]; + const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj]; if (DV4 % NW == 0) { FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { @@ -5028,8 +5146,8 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = 8, // queries per threadgroup - short C = 64> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, @@ -5037,13 +5155,15 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_nsg) { // note: disabled cases to reduce library load time //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; @@ -5163,6 +5283,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_ constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; +constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]]; //constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; @@ -5189,9 +5310,9 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32, // cache items per threadgroup + short NE, // head elements per thread + short Q, // queries per threadgroup + short C, // cache items per threadgroup short NSG> // number of simd groups void kernel_flash_attn_ext_vec_impl( constant ggml_metal_kargs_flash_attn_ext_vec & args, @@ -5200,6 +5321,7 @@ void kernel_flash_attn_ext_vec_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -5305,12 +5427,38 @@ void kernel_flash_attn_ext_vec_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { - const int ic = ic0 + C*sgitg; + for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) { + int ic = ic0*C; if (ic >= args.ne11) { break; } + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_vec_has_mask) { + if (ic + tiisg >= args.ne11) { + sm[tiisg] = -MAXHALF; + } + } else { + pm = (device const half *) (mask) + + iq1*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32); + } + + ic = 0; + } + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -5322,7 +5470,7 @@ void kernel_flash_attn_ext_vec_impl( // Q*K^T { - device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); threadgroup const q4_t * pq4 = sq4; pk4 += ty*NS10/4 + tx; @@ -5337,7 +5485,7 @@ void kernel_flash_attn_ext_vec_impl( mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); } } else { - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); + device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; @@ -5435,7 +5583,7 @@ void kernel_flash_attn_ext_vec_impl( } if (is_same::value) { - device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); + device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); pv4 += ty*NS20/4 + tx; @@ -5448,7 +5596,7 @@ void kernel_flash_attn_ext_vec_impl( } } else { FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21)); FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { const short i = ii*NL + tx; @@ -5573,7 +5721,7 @@ void kernel_flash_attn_ext_vec_impl( device float4 * dst4 = (device float4 *) dst; device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; + const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f; // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { @@ -5611,8 +5759,8 @@ template< short DK, // K head size short DV, // V head size short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, @@ -5620,13 +5768,14 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_vec_nsg) { // note: disabled cases to reduce library load time case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; @@ -5750,7 +5899,8 @@ kernel void kernel_flash_attn_ext_vec_reduce( const float m = simd_max(M); const float ms = exp(M - m); - S = 1.0f/simd_sum(S*ms); + S = simd_sum(S*ms); + S = S == 0.0f ? 0.0f : 1.0f/S; const short DV4 = DV/4; @@ -5770,21 +5920,17 @@ kernel void kernel_flash_attn_ext_vec_reduce( } template -kernel void kernel_cpy( +kernel void kernel_cpy_t_t( constant ggml_metal_kargs_cpy & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 tptg[[threads_per_threadgroup]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x; - - if (i01 >= args.ne01) { - return; - } + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -5795,190 +5941,70 @@ kernel void kernel_cpy( device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; + break; } } -typedef decltype(kernel_cpy) kernel_cpy_t; +typedef decltype(kernel_cpy_t_t) kernel_cpy_t; -template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -// TODO: templetify these kernels -kernel void kernel_cpy_f32_q8_0( +template +kernel void kernel_cpy_f32_q( constant ggml_metal_kargs_cpy & args, device const char * src0, - device char * dst, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; - device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); - quantize_q8_0(src, dst_data[i00/QK8_0]); + quantize_func(src, dst_data[i00]); + + break; } } -kernel void kernel_cpy_f32_q4_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; +typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_0(src, dst_data[i00/QK4_0]); - } -} - -kernel void kernel_cpy_f32_q4_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_1(src, dst_data[i00/QK4_1]); - } -} - -kernel void kernel_cpy_f32_q5_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_0(src, dst_data[i00/QK5_0]); - } -} - -kernel void kernel_cpy_f32_q5_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_1(src, dst_data[i00/QK5_1]); - } -} - -kernel void kernel_cpy_f32_iq4_nl( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_iq4_nl(src, dst_data[i00/QK4_NL]); - } -} +template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q; template kernel void kernel_cpy_q_f32( @@ -5986,11 +6012,12 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -6002,10 +6029,12 @@ kernel void kernel_cpy_q_f32( device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; + + break; } } @@ -7765,66 +7794,60 @@ kernel void kernel_mul_mv_mxfp4_f32( template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { + auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; + dequantize_func(psrc + ind/nl, ind%nl, temp); + pdst[ind] = temp; + + break; } } -template +template kernel void kernel_get_rows_f( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; - } -} + auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); -kernel void kernel_get_rows_i32( - constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device int32_t * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { + pdst[ind] = psrc[ind]; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + break; } } @@ -8310,12 +8333,13 @@ kernel void kernel_mul_mm_id( // get rows // -typedef decltype(kernel_get_rows_f) get_rows_f_t; +typedef decltype(kernel_get_rows_f) get_rows_f_t; -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; #endif typedef decltype(kernel_get_rows_q) get_rows_q_t; diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index f99681c84c..aad48d62a8 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -105,9 +105,12 @@ enum rpc_cmd { RPC_CMD_INIT_TENSOR, RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_HELLO, + RPC_CMD_DEVICE_COUNT, RPC_CMD_COUNT, }; +static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14"); + // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; @@ -117,7 +120,12 @@ struct rpc_msg_hello_rsp { uint8_t patch; }; +struct rpc_msg_device_count_rsp { + uint32_t device_count; +}; + struct rpc_msg_get_alloc_size_req { + uint32_t device; rpc_tensor tensor; }; @@ -130,6 +138,7 @@ struct rpc_msg_init_tensor_req { }; struct rpc_msg_alloc_buffer_req { + uint32_t device; uint64_t size; }; @@ -138,10 +147,18 @@ struct rpc_msg_alloc_buffer_rsp { uint64_t remote_size; }; +struct rpc_msg_get_alignment_req { + uint32_t device; +}; + struct rpc_msg_get_alignment_rsp { uint64_t alignment; }; +struct rpc_msg_get_max_size_req { + uint32_t device; +}; + struct rpc_msg_get_max_size_rsp { uint64_t max_size; }; @@ -192,6 +209,10 @@ struct rpc_msg_graph_compute_rsp { uint8_t result; }; +struct rpc_msg_get_device_memory_req { + uint32_t device; +}; + struct rpc_msg_get_device_memory_rsp { uint64_t free_mem; uint64_t total_mem; @@ -207,13 +228,15 @@ static ggml_guid_t ggml_backend_rpc_guid() { struct ggml_backend_rpc_buffer_type_context { std::string endpoint; + uint32_t device; std::string name; - size_t alignment; - size_t max_size; + size_t alignment; + size_t max_size; }; struct ggml_backend_rpc_context { std::string endpoint; + uint32_t device; std::string name; }; @@ -608,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con RPC_STATUS_ASSERT(status); } +static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer; +} + static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { - // check if src and dst are on the same server - ggml_backend_buffer_t src_buffer = src->buffer; - ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; - ggml_backend_buffer_t dst_buffer = dst->buffer; - ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; - if (src_ctx->sock != dst_ctx->sock) { - return false; + if (ggml_backend_buffer_is_rpc(src->buffer)) { + // check if src and dst are on the same server + ggml_backend_buffer_t src_buffer = src->buffer; + ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; + ggml_backend_buffer_t dst_buffer = dst->buffer; + ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; + if (src_ctx->sock != dst_ctx->sock) { + return false; + } + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_copy_tensor_req request; + request.src = serialize_tensor(src); + request.dst = serialize_tensor(dst); + rpc_msg_copy_tensor_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.result; } - ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - rpc_msg_copy_tensor_req request; - request.src = serialize_tensor(src); - request.dst = serialize_tensor(dst); - rpc_msg_copy_tensor_rsp response; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); - RPC_STATUS_ASSERT(status); - return response.result; + return false; } static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { @@ -653,7 +683,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; - rpc_msg_alloc_buffer_req request = {size}; + rpc_msg_alloc_buffer_req request = {buft_ctx->device, size}; rpc_msg_alloc_buffer_rsp response; auto sock = get_socket(buft_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); @@ -669,9 +699,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back } } -static size_t get_alignment(const std::shared_ptr & sock) { +static size_t get_alignment(const std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_alignment_req request = {device}; rpc_msg_get_alignment_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response)); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); return response.alignment; } @@ -681,9 +712,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ return buft_ctx->alignment; } -static size_t get_max_size(const std::shared_ptr & sock) { +static size_t get_max_size(const std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_max_size_req request = {device}; rpc_msg_get_max_size_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response)); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); return response.max_size; } @@ -700,7 +732,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty auto sock = get_socket(buft_ctx->endpoint); rpc_msg_get_alloc_size_req request; - + request.device = buft_ctx->device; request.tensor = serialize_tensor(tensor); rpc_msg_get_alloc_size_rsp response; @@ -754,7 +786,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector & tensors, tensors.push_back(serialize_tensor(tensor)); } -static void serialize_graph(const ggml_cgraph * cgraph, std::vector & output) { +static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector & output) { uint32_t n_nodes = cgraph->n_nodes; std::vector tensors; std::unordered_set visited; @@ -762,24 +794,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector & o add_tensor(cgraph->nodes[i], tensors, visited); } // serialization format: - // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | uint32_t n_tensors = tensors.size(); - int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); + int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); output.resize(output_size, 0); - memcpy(output.data(), &n_nodes, sizeof(n_nodes)); + uint8_t * dest = output.data(); + memcpy(dest, &device, sizeof(device)); + dest += sizeof(device); + memcpy(dest, &n_nodes, sizeof(n_nodes)); + dest += sizeof(n_nodes); for (uint32_t i = 0; i < n_nodes; i++) { - memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); + memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); } - uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); - *out_ntensors = n_tensors; - rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); + dest += n_nodes * sizeof(uint64_t); + memcpy(dest, &n_tensors, sizeof(n_tensors)); + dest += sizeof(n_tensors); + rpc_tensor * out_tensors = (rpc_tensor *)dest; memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); } static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; std::vector input; - serialize_graph(cgraph, input); + serialize_graph(rpc_ctx->device, cgraph, input); rpc_msg_graph_compute_rsp response; auto sock = get_socket(rpc_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); @@ -804,12 +841,13 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .graph_optimize = */ NULL, }; -ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { +ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) { static std::mutex mutex; std::lock_guard lock(mutex); + std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; // NOTE: buffer types are allocated and never freed; this is by design static std::unordered_map buft_map; - auto it = buft_map.find(endpoint); + auto it = buft_map.find(buft_name); if (it != buft_map.end()) { return it->second; } @@ -818,34 +856,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { GGML_LOG_ERROR("Failed to connect to %s\n", endpoint); return nullptr; } - size_t alignment = get_alignment(sock); - size_t max_size = get_max_size(sock); + size_t alignment = get_alignment(sock, device); + size_t max_size = get_max_size(sock, device); ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { /* .endpoint = */ endpoint, - /* .name = */ "RPC[" + std::string(endpoint) + "]", + /* .device = */ device, + /* .name = */ buft_name, /* .alignment = */ alignment, /* .max_size = */ max_size }; - + auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { /* .iface = */ ggml_backend_rpc_buffer_type_interface, - /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .device = */ ggml_backend_reg_dev_get(reg, device), /* .context = */ buft_ctx }; - buft_map[endpoint] = buft; + buft_map[buft_name] = buft; return buft; } -ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { +ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { + std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { - /* .endpoint = */ endpoint, - /* .name = */ "RPC[" + std::string(endpoint) + "]", + /* .endpoint = */ endpoint, + /* .device = */ device, + /* .name = */ dev_name }; - + auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_rpc_guid(), /* .iface = */ ggml_backend_rpc_interface, - /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .device = */ ggml_backend_reg_dev_get(reg, device), /* .context = */ ctx }; return backend; @@ -855,37 +896,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); } -static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { +static void get_device_memory(const std::shared_ptr & sock, uint32_t device, size_t * free, size_t * total) { + rpc_msg_get_device_memory_req request; + request.device = device; rpc_msg_get_device_memory_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response)); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); *free = response.free_mem; *total = response.total_mem; } -void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { +void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) { auto sock = get_socket(endpoint); if (sock == nullptr) { *free = 0; *total = 0; return; } - get_device_memory(sock, free, total); + get_device_memory(sock, device, free, total); } // RPC server-side implementation class rpc_server { public: - rpc_server(ggml_backend_t backend, const char * cache_dir) - : backend(backend), cache_dir(cache_dir) { + rpc_server(std::vector backends, const char * cache_dir) + : backends(std::move(backends)), cache_dir(cache_dir) { } ~rpc_server(); void hello(rpc_msg_hello_rsp & response); - void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); - void get_alignment(rpc_msg_get_alignment_rsp & response); - void get_max_size(rpc_msg_get_max_size_rsp & response); + bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); + bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response); + bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response); bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response); bool free_buffer(const rpc_msg_free_buffer_req & request); bool buffer_clear(const rpc_msg_buffer_clear_req & request); @@ -906,7 +949,7 @@ private: std::unordered_map & tensor_map); - ggml_backend_t backend; + std::vector backends; const char * cache_dir; std::unordered_set buffers; }; @@ -919,6 +962,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) { } bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } ggml_backend_buffer_type_t buft; struct ggml_init_params params { /*.mem_size =*/ ggml_tensor_overhead(), @@ -935,10 +982,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); return false; } - LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data); + LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data); if (tensor->buffer == nullptr) { //No buffer allocated. - buft = ggml_backend_get_default_buffer_type(backend); + buft = ggml_backend_get_default_buffer_type(backends[dev_id]); } else { buft = tensor->buffer->buft; } @@ -948,33 +995,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ return true; } -void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); +bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size); response.remote_ptr = 0; response.remote_size = 0; if (buffer != nullptr) { response.remote_ptr = reinterpret_cast(buffer); response.remote_size = buffer->size; - LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size); + LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", + __func__, dev_id, request.size, response.remote_ptr, response.remote_size); buffers.insert(buffer); } else { - LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); + LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size); } + return true; } -void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) { - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); +bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); size_t alignment = ggml_backend_buft_get_alignment(buft); - LOG_DBG("[%s] alignment: %lu\n", __func__, alignment); + LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment); response.alignment = alignment; + return true; } -void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) { - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); +bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); size_t max_size = ggml_backend_buft_get_max_size(buft); - LOG_DBG("[%s] max_size: %lu\n", __func__, max_size); + LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size); response.max_size = max_size; + return true; } bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) { @@ -1332,23 +1395,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id, bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response) { // serialization format: - // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | - if (input.size() < sizeof(uint32_t)) { + // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + if (input.size() < 2*sizeof(uint32_t)) { + return false; + } + const uint8_t * src = input.data(); + uint32_t device; + memcpy(&device, src, sizeof(device)); + src += sizeof(device); + if (device >= backends.size()) { return false; } uint32_t n_nodes; - memcpy(&n_nodes, input.data(), sizeof(n_nodes)); - if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { + memcpy(&n_nodes, src, sizeof(n_nodes)); + src += sizeof(n_nodes); + if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { return false; } - const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes)); + const uint64_t * nodes = (const uint64_t *)src; + src += n_nodes*sizeof(uint64_t); uint32_t n_tensors; - memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors)); - if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { + memcpy(&n_tensors, src, sizeof(n_tensors)); + src += sizeof(n_tensors); + if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { return false; } - const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors)); - LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors); + const rpc_tensor * tensors = (const rpc_tensor *)src; + LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors); size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); @@ -1380,7 +1453,7 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph return false; } } - ggml_status status = ggml_backend_graph_compute(backend, graph); + ggml_status status = ggml_backend_graph_compute(backends[device], graph); response.result = status; return true; } @@ -1391,9 +1464,9 @@ rpc_server::~rpc_server() { } } -static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, - sockfd_t sockfd, size_t free_mem, size_t total_mem) { - rpc_server server(backend, cache_dir); +static void rpc_serve_client(const std::vector & backends, const char * cache_dir, + sockfd_t sockfd, const std::vector & free_mem, const std::vector & total_mem) { + rpc_server server(backends, cache_dir); uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { return; @@ -1425,13 +1498,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, // HELLO command is handled above return; } + case RPC_CMD_DEVICE_COUNT: { + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_device_count_rsp response; + response.device_count = backends.size(); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_alloc_buffer_rsp response; - server.alloc_buffer(request, response); + if (!server.alloc_buffer(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1452,22 +1538,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, break; } case RPC_CMD_GET_ALIGNMENT: { - if (!recv_msg(sockfd, nullptr, 0)) { + rpc_msg_get_alignment_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_get_alignment_rsp response; - server.get_alignment(response); + if (!server.get_alignment(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_MAX_SIZE: { - if (!recv_msg(sockfd, nullptr, 0)) { + rpc_msg_get_max_size_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_get_max_size_rsp response; - server.get_max_size(response); + if (!server.get_max_size(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1593,12 +1685,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, break; } case RPC_CMD_GET_DEVICE_MEMORY: { - if (!recv_msg(sockfd, nullptr, 0)) { + rpc_msg_get_device_memory_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + auto dev_id = request.device; + if (dev_id >= backends.size()) { return; } rpc_msg_get_device_memory_rsp response; - response.free_mem = free_mem; - response.total_mem = total_mem; + response.free_mem = free_mem[dev_id]; + response.total_mem = total_mem[dev_id]; + LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id, + response.free_mem, response.total_mem); if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1612,16 +1711,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, } } -void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, - const char * cache_dir, - size_t free_mem, size_t total_mem) { +void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, + size_t n_threads, size_t n_devices, + ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) { + if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) { + fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n"); + return; + } + std::vector backends; + std::vector free_mem_vec(free_mem, free_mem + n_devices); + std::vector total_mem_vec(total_mem, total_mem + n_devices); printf("Starting RPC server v%d.%d.%d\n", RPC_PROTO_MAJOR_VERSION, RPC_PROTO_MINOR_VERSION, RPC_PROTO_PATCH_VERSION); printf(" endpoint : %s\n", endpoint); printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a"); - printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024)); + printf("Devices:\n"); + for (size_t i = 0; i < n_devices; i++) { + auto dev = devices[i]; + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024); + auto backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev)); + return; + } + backends.push_back(backend); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + if (reg) { + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend, n_threads); + } + } + } std::string host; int port; @@ -1649,22 +1773,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint fprintf(stderr, "Failed to accept client connection\n"); return; } - printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); + printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem); + rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec); printf("Client connection closed\n"); fflush(stdout); } #ifdef _WIN32 WSACleanup(); #endif + for (auto backend : backends) { + ggml_backend_free(backend); + } } // device interface struct ggml_backend_rpc_device_context { std::string endpoint; + uint32_t device; std::string name; + std::string description; }; static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { @@ -1676,15 +1805,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - return ctx->name.c_str(); + return ctx->description.c_str(); } static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total); - - GGML_UNUSED(dev); + ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total); } static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) { @@ -1710,7 +1837,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - return ggml_backend_rpc_init(ctx->endpoint.c_str()); + return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device); GGML_UNUSED(params); } @@ -1718,7 +1845,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); + return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device); GGML_UNUSED(dev); } @@ -1736,7 +1863,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b } ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context; - return buft_ctx->endpoint == dev_ctx->endpoint; + return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device; } static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { @@ -1759,28 +1886,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { // backend reg interface -static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { - return "RPC"; +struct ggml_backend_rpc_reg_context { + std::string name; + std::vector devices; +}; - GGML_UNUSED(reg); +static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + return ctx ? ctx->name.c_str() : "RPC"; } static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) { - return 0; - - GGML_UNUSED(reg); + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + return ctx ? ctx->devices.size() : 0; } static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) { - GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead"); - - GGML_UNUSED(reg); - GGML_UNUSED(index); + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + if (ctx == nullptr) { + GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead"); + } else { + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; + } } static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) { - if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) { - return (void *)ggml_backend_rpc_add_device; + if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) { + return (void *)ggml_backend_rpc_add_server; } if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) { return (void *)ggml_backend_rpc_start_server; @@ -1807,30 +1940,61 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) { return &ggml_backend_rpc_reg; } -ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) { - static std::unordered_map dev_map; - - static std::mutex mutex; - std::lock_guard lock(mutex); - - if (dev_map.find(endpoint) != dev_map.end()) { - return dev_map[endpoint]; - } - - ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context { - /* .endpoint = */ endpoint, - /* .name = */ "RPC[" + std::string(endpoint) + "]", - }; - - ggml_backend_dev_t dev = new ggml_backend_device { - /* .iface = */ ggml_backend_rpc_device_i, - /* .reg = */ ggml_backend_rpc_reg(), - /* .context = */ ctx, - }; - - dev_map[endpoint] = dev; - - return dev; +static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) { + auto sock = get_socket(endpoint); + rpc_msg_device_count_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.device_count; } +static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = { + /* .get_name = */ ggml_backend_rpc_reg_get_name, + /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count, + /* .get_device = */ ggml_backend_rpc_reg_get_device, + /* .get_proc_address = */ ggml_backend_rpc_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { + static std::unordered_map reg_map; + static std::mutex mutex; + static uint32_t dev_id = 0; + std::lock_guard lock(mutex); + if (reg_map.find(endpoint) != reg_map.end()) { + return reg_map[endpoint]; + } + uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint); + if (dev_count == 0) { + return nullptr; + } + ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context; + ctx->name = "RPC[" + std::string(endpoint) + "]"; + for (uint32_t ind = 0; ind < dev_count; ind++) { + std::string dev_name = "RPC" + std::to_string(dev_id); + std::string dev_desc = std::string(endpoint); + ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc + }; + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_rpc_device_i, + /* .reg = */ ggml_backend_rpc_reg(), + /* .context = */ dev_ctx, + }; + ctx->devices.push_back(dev); + dev_id++; + } + ggml_backend_reg_t reg = new ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_rpc_reg_interface, + /* .context = */ ctx + }; + reg_map[endpoint] = reg; + return reg; +} + + GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 4e7449d06e..d66d7ade90 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -197,6 +197,7 @@ struct sycl_device_info { int cc; // compute capability // int nsm; // number of streaming multiprocessors // size_t smpb; // max. shared memory per block + size_t smpbo; // max. shared memory per block (with opt-in) bool vmm; // virtual memory support size_t total_vram; //sycl_hw_info hw_info; \\ device id and aarch, currently not used @@ -416,13 +417,6 @@ static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) { #pragma unroll for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { - /* - DPCT1096:98: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the - CPU device. Modify the size of the work-group to ensure that the value - of the right-most dimension is a multiple of "32". - */ x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask); } return x; @@ -440,17 +434,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) { return a; } +template +static __dpct_inline__ int warp_reduce_sum(int x) { + return sycl::reduce_over_group( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>()); +} + +template +static __dpct_inline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x += dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width); + } + return x; +} + +template +static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + a.x() += dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset, + width); + a.y() += dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset, + width); + } + return a; +} + +template +static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + a = a + dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset, + width); + } + return a; +} + +static constexpr int ggml_sycl_get_physical_warp_size() { + // todo: for old iGPU + dGPU case, need to be changed. + return WARP_SIZE; +} + +template +static __dpct_inline__ float warp_reduce_max(float x) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x = sycl::fmax(x, dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, + offset, width)); + } + return x; +} + static __dpct_inline__ float warp_reduce_max(float x, const sycl::nd_item<3>& item_ct1) { #pragma unroll for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { - /* - DPCT1096:97: The right-most dimension of the work-group used in the SYCL - kernel that calls this function may be less than "32". The function - "dpct::permute_sub_group_by_xor" may return an unexpected result on the - CPU device. Modify the size of the work-group to ensure that the value - of the right-most dimension is a multiple of "32". - */ x = sycl::fmax(x, dpct::permute_sub_group_by_xor( item_ct1.get_sub_group(), x, mask)); } @@ -558,4 +602,18 @@ struct scope_op_debug_print { std::string_view func_suffix; }; +static __dpct_inline__ float get_alibi_slope(const float max_bias, + const uint32_t h, + const uint32_t n_head_log2, + const float m0, + const float m1) { + if (max_bias <= 0.0f) { + return 1.0f; + } + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + return dpct::pow(base, exph); +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index d538965b09..f93cfa701f 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -277,6 +277,26 @@ namespace dpct } // namespace detail + // COPY from DPCT head files + /// dim3 is used to store 3 component dimensions. + class dim3 { + public: + unsigned x, y, z; + + constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1) + : x(x), y(y), z(z) {} + + dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {} + + operator sycl::range<3>() const { return sycl::range<3>(z, y, x); } + }; // namespace dim3 + + inline dim3 operator*(const dim3 &a, const dim3 &b) { + return dim3{a.x * b.x, a.y * b.y, a.z * b.z}; + } + // COPY from DPCT head files + + /// Pitched 2D/3D memory data. class pitched_data { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 4ac919ea2d..e4cc3c8ed8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -87,6 +87,7 @@ static ggml_sycl_device_info ggml_sycl_init() { 100 * prop.get_major_version() + 10 * prop.get_minor_version(); info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); info.max_work_group_sizes[i] = prop.get_max_work_group_size(); + info.devices[i].smpbo = prop.get_local_mem_size(); } for (int id = 0; id < info.device_count; ++id) { @@ -3741,6 +3742,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_SOFT_MAX: ggml_sycl_op_soft_max(ctx, dst); break; + case GGML_OP_SOFT_MAX_BACK: + ggml_sycl_op_soft_max_back(ctx, dst); + break; case GGML_OP_ROPE: ggml_sycl_rope(ctx, dst); break; @@ -3778,6 +3782,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg return true; } catch (sycl::exception & e) { std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::cerr << "Error OP "<op)<< std::endl; std::exit(1); } @@ -4386,19 +4391,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; - case GGML_OP_SOFT_MAX: - // TODO: support batching - if (op->src[0]->ne[3] != 1) { - return false; - } - // TODO: support attention sinks [TAG_ATTN_SINKS] - if (op->src[2]) { - return false; - } - // TODO: support broadcast - // ref: https://github.com/ggml-org/llama.cpp/pull/14435 - return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); case GGML_OP_DIAG_MASK_INF: + return true; + case GGML_OP_SOFT_MAX: + return true; + case GGML_OP_SOFT_MAX_BACK: { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; + } case GGML_OP_ROPE: case GGML_OP_IM2COL: return true; diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 52fcf4b3db..83b7c71b66 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -1,37 +1,94 @@ #include "softmax.hpp" +#include +#include +#include -template -static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, - const int nrows_y, const float scale, const float max_bias, const float m0, - const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { - const int ncols = ncols_template == 0 ? ncols_par : ncols_template; - const int tid = item_ct1.get_local_id(2); - const int rowx = item_ct1.get_group(2); - const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension +template static __dpct_inline__ float t2f32(T val) { + return (float) val; +} - const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template; +template <> float __dpct_inline__ t2f32(sycl::half val) { + return sycl::vec(val) + .convert()[0]; +} - const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; - const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; +struct soft_max_params { + + int64_t nheads; + uint32_t n_head_log2; + int64_t ncols; + int64_t nrows_x; + int64_t nrows_y; + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + int64_t nb11; + int64_t nb12; + int64_t nb13; + + int64_t ne12; + int64_t ne13; + float scale; + float max_bias; + float m0; + float m1; +}; + +// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. +// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template +static void soft_max_f32(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params p, + uint8_t * dpct_local) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int ncols = ncols_template == 0 ? p.ncols : ncols_template; + const int block_size = block_size_template == 0 + ? item_ct1.get_local_range(2) + : block_size_template; const int nthreads = block_size; const int nwarps = nthreads / WARP_SIZE; size_t nreduce = nwarps / WARP_SIZE; - float slope = 1.0f; - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = rowx/nrows_y; // head index + const int tid = item_ct1.get_local_id(2); - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const int64_t i03 = item_ct1.get_group(0); + const int64_t i02 = item_ct1.get_group(1); + const int64_t i01 = item_ct1.get_group(2); - slope = sycl::pow(base, float(exp)); - } + //TODO: noncontigous inputs/outputs + const int rowx = item_ct1.get_group(2) + + item_ct1.get_group(1) * item_ct1.get_group_range(2) + + item_ct1.get_group(0) * item_ct1.get_group_range(2) * + item_ct1.get_group_range(1); - float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols; - float max_val = -INFINITY; + const int64_t i11 = i01; + const int64_t i12 = i02 % p.ne12; + const int64_t i13 = i03 % p.ne13; + x += int64_t(rowx)*ncols; + mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); + dst += int64_t(rowx)*ncols; + + const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + + const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); + + float * buf_iw = (float *) dpct_local; + + // shared memory buffer to cache values between iterations: + float *vals = use_shared ? buf_iw + sycl::max(nwarps, WARP_SIZE) : dst; + float max_val = sinks ? sinks[i02] : -INFINITY; +#pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; @@ -39,42 +96,35 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int break; } - const int ix = rowx*ncols + col; - const int iy = rowy*ncols + col; - - const float val = x[ix]*scale + (mask ? slope*static_cast(mask[iy]) : 0.0f); + const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; - max_val = sycl::max(max_val, val); + max_val = sycl::max(max_val, val); } - // find the max value in the block - max_val = warp_reduce_max(max_val, item_ct1); + max_val = warp_reduce_max(max_val); + if (block_size > WARP_SIZE) { if (warp_id == 0) { - buf[lane_id] = -INFINITY; - for (size_t i = 1; i < nreduce; i += 1) { - buf[lane_id + i * WARP_SIZE] = -INFINITY; - } + buf_iw[lane_id] = -INFINITY; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); if (lane_id == 0) { - buf[warp_id] = max_val; + buf_iw[warp_id] = max_val; } - item_ct1.barrier(sycl::access::fence_space::local_space); - max_val = buf[lane_id]; - for (size_t i = 1; i < nreduce; i += 1) { - max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]); - } - max_val = warp_reduce_max(max_val, item_ct1); - } + item_ct1.barrier(); + + max_val = buf_iw[lane_id]; + max_val = warp_reduce_max(max_val); + } + float tmp = 0.0f; // partial sum - float tmp = 0.f; #pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { const int col = col0 + tid; - if (ncols_template == 0 && col >= ncols) { + + if (ncols_template == 0 && col >= ncols) { break; } @@ -82,32 +132,33 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int tmp += val; vals[col] = val; } - // find the sum of exps in the block - tmp = warp_reduce_sum(tmp, item_ct1); + tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); if (warp_id == 0) { - buf[lane_id] = 0.f; + buf_iw[lane_id] = 0.0f; for (size_t i = 1; i < nreduce; i += 1) { - buf[lane_id + i * WARP_SIZE] = 0.f; + buf_iw[lane_id + i * WARP_SIZE] = 0.f; } } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); if (lane_id == 0) { - buf[warp_id] = tmp; + buf_iw[warp_id] = tmp; } - item_ct1.barrier(sycl::access::fence_space::local_space); + item_ct1.barrier(); - tmp = buf[lane_id]; + tmp = buf_iw[lane_id]; for (size_t i = 1; i < nreduce; i += 1) { - tmp += buf[lane_id + i * WARP_SIZE]; + tmp += buf_iw[lane_id + i * WARP_SIZE]; } - tmp = warp_reduce_sum(tmp, item_ct1); + tmp = warp_reduce_sum(tmp); } - - const float inv_sum = 1.f / tmp; + if (sinks) { + tmp += sycl::native::exp(sinks[i02] - max_val); + } + const float inv_sum = 1.0f / tmp; #pragma unroll for (int col0 = 0; col0 < ncols; col0 += block_size) { @@ -117,145 +168,259 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int return; } - const int idst = rowx*ncols + col; - dst[idst] = vals[col] * inv_sum; + dst[col] = vals[col] * inv_sum; + } +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +static void soft_max_back_f32(const float *grad, const float *dstf, float *dst, + const int ncols, const float scale) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int tid = item_ct1.get_local_id(2); + const int rowx = item_ct1.get_group(2); + + grad += int64_t(rowx)*ncols; + dstf += int64_t(rowx)*ncols; + dst += int64_t(rowx)*ncols; + + float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dgf_dot += dstf[col]*grad[col]; + } + + dgf_dot = warp_reduce_sum(dgf_dot); + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; } } -template -static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par, - const int nrows_y, const float scale, const float max_bias, const float m0, - const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, - const size_t n_local_scratch, queue_ptr stream) { +template +static void launch_soft_max_kernels(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params & p, + dpct::queue_ptr stream, + dpct::dim3 block_dims, + dpct::dim3 block_nums, + size_t nbytes_shared) +{ + auto launch_kernel = [=](auto I) -> bool { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 1024 : ncols); + if (p.ncols == ncols) { + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(nbytes_shared), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size( + WARP_SIZE)]] { + soft_max_f32( + x, mask, sinks, dst, p, + dpct_local_acc_ct1 + .get_multi_ptr() + .get()); + GGML_UNUSED(item_ct1); + }); + }); + return true; + } + return false; + }; + + // unary fold over launch_kernel + if ((launch_kernel(std::integral_constant{}) || ...)) { + return; + } + stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor local_buf_acc(n_local_scratch, cgh); + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(nbytes_shared), cgh); cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - soft_max_f32(x, mask, dst, ncols_par, - nrows_y, scale, max_bias, m0, - m1, n_head_log2, item_ct1, - get_pointer(local_buf_acc)); - }); + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + soft_max_f32( + x, mask, sinks, dst, p, + dpct_local_acc_ct1 + .get_multi_ptr() + .get()); + GGML_UNUSED(item_ct1); + }); }); } -template -static void soft_max_f32_sycl(const float * x, const T * mask, - float * dst, const int ncols_x, const int nrows_x, - const int nrows_y, const float scale, const float max_bias, - queue_ptr stream, int device) { +template +static void soft_max_f32_sycl(const float *x, const T *mask, + const float *sinks, float *dst, + const soft_max_params ¶ms, + dpct::queue_ptr stream, int device) { int nth = WARP_SIZE; int max_block_size = ggml_sycl_info().max_work_group_sizes[device]; + const int64_t ncols_x = params.ncols; + while (nth < ncols_x && nth < max_block_size) nth *= 2; if (nth>max_block_size) nth = max_block_size; - const sycl::range<3> block_dims(1, 1, nth); - const sycl::range<3> block_nums(1, 1, nrows_x); - const size_t n_val_tmp = nth / WARP_SIZE; - const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp); + const dpct::dim3 block_dims(nth, 1, 1); + const dpct::dim3 block_nums(params.ne01, params.ne02, params.ne03); + const size_t nbytes_shared = + (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE) * sizeof(float); - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const int id = get_current_device_id(); + const size_t smpbo = ggml_sycl_info().devices[id].smpbo; - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - const size_t local_mem_size = stream->get_device().get_info(); - if (n_local_scratch*sizeof(float) < local_mem_size) { - if (ncols_x > max_block_size) { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - return; - } - switch (ncols_x) { - case 32: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 64: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 128: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 256: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 512: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 1024: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 2048: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - case 4096: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - default: - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, n_local_scratch, stream); - break; - } + if (nbytes_shared <= smpbo) { + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>( + x, mask, sinks, dst, params, stream, block_dims, block_nums, + nbytes_shared); } else { - soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, - max_bias, m0, m1, n_head_log2, block_nums, - block_dims, WARP_SIZE, stream); + const size_t nbytes_shared_low = WARP_SIZE * sizeof(float); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor dpct_local_acc_ct1( + sycl::range<1>(nbytes_shared_low), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + soft_max_f32( + x, mask, sinks, dst, params, + dpct_local_acc_ct1 + .get_multi_ptr() + .get()); + GGML_UNUSED(item_ct1); + }); + }); } } +static void soft_max_back_f32_sycl(const float * grad, + const float * dstf, + float * dst, + const int ncols, + const int nrows, + const float scale, + dpct::queue_ptr stream) { + const dpct::dim3 block_dims(WARP_SIZE, 1, 1); + const dpct::dim3 block_nums(nrows, 1, 1); + + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + soft_max_back_f32(grad, dstf, dst, ncols, scale); + GGML_UNUSED(item_ct1); + }); +} + void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + const void * src2_d = src2 ? (const void *) src2->data : nullptr; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional + // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t nrows_x = ggml_nrows(dst->src[0]); - const int64_t nrows_y = dst->src[0]->ne[1]; + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; - float scale = 1.0f; + const int64_t ne00 = src0->ne[0]; + + float scale = 1.0f; float max_bias = 0.0f; - memcpy(&scale, dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, dst->op_params + 1, sizeof(float)); + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - ggml_sycl_set_device(ctx.device); - dpct::queue_ptr main_stream = ctx.stream(); + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; - if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { - const sycl::half * src1_dd = static_cast(dst->src[1]->data); - soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, - main_stream, ctx.device); - } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { - const float * src1_dd = static_cast(dst->src[1]->data); - soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + + soft_max_params params = {}; + params.nheads = src0->ne[2]; + params.n_head_log2 = n_head_log2; + params.ncols = ne00; + params.nrows_x = nrows_x; + params.nrows_y = nrows_y; + params.ne00 = src0->ne[0]; + params.ne01 = src0->ne[1]; + params.ne02 = src0->ne[2]; + params.ne03 = src0->ne[3]; + params.nb11 = nb11; + params.nb12 = nb12; + params.nb13 = nb13; + params.ne12 = ne12; + params.ne13 = ne13; + params.scale = scale; + params.max_bias = max_bias; + params.m0 = m0; + params.m1 = m1; + + if (use_f16) { + soft_max_f32_sycl(src0_d, (const sycl::half *)src1_d, + (const float *)src2_d, dst_d, params, stream, + ctx.device); } else { - /* mask unavailable */ - soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + soft_max_f32_sycl(src0_d, (const float *)src1_d, (const float *)src2_d, + dst_d, params, stream, ctx.device); } } + +void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const ggml_tensor * src0 = dst->src[0]; // grad + const ggml_tensor * src1 = dst->src[1]; // forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + + soft_max_back_f32_sycl(src0_d, src1_d, dst_d, ncols, nrows, scale, stream); +} diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp index 2cf8582ec9..23f1e5a9d6 100644 --- a/ggml/src/ggml-sycl/softmax.hpp +++ b/ggml/src/ggml-sycl/softmax.hpp @@ -15,6 +15,10 @@ #include "common.hpp" +#define SYCL_SOFT_MAX_BLOCK_SIZE 1024 + void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst); +void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + #endif // GGML_SYCL_SOFTMAX_HPP diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index b97e7bf995..83a83887b5 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -1,5 +1,6 @@ cmake_minimum_required(VERSION 3.19) cmake_policy(SET CMP0114 NEW) +cmake_policy(SET CMP0116 NEW) find_package(Vulkan COMPONENTS glslc REQUIRED) @@ -54,25 +55,25 @@ if (Vulkan_FOUND) # Test all shader extensions test_shader_extension_support( "GL_KHR_cooperative_matrix" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp" "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" ) test_shader_extension_support( "GL_NV_cooperative_matrix2" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp" "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" ) test_shader_extension_support( "GL_EXT_integer_dot_product" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp" "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" ) test_shader_extension_support( "GL_EXT_bfloat16" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp" "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" ) @@ -160,7 +161,6 @@ if (Vulkan_FOUND) set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$") set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}") set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp") - set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp") set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders") set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv") @@ -176,24 +176,35 @@ if (Vulkan_FOUND) add_custom_command( OUTPUT ${_ggml_vk_header} - ${_ggml_vk_source} - COMMAND ${_ggml_vk_genshaders_cmd} - --glslc ${Vulkan_GLSLC_EXECUTABLE} - --input-dir ${_ggml_vk_input_dir} --output-dir ${_ggml_vk_output_dir} --target-hpp ${_ggml_vk_header} - --target-cpp ${_ggml_vk_source} - --no-clean - - DEPENDS ${_ggml_vk_shader_files} - ${_ggml_vk_shaders_gen_sources} + DEPENDS ${_ggml_vk_shaders_gen_sources} vulkan-shaders-gen - - COMMENT "Generate vulkan shaders" + COMMENT "Generate vulkan shaders header" ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header}) - target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header}) + foreach (file_full ${_ggml_vk_shader_files}) + get_filename_component(file ${file_full} NAME) + set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp") + + add_custom_command( + OUTPUT ${_ggml_vk_target_cpp} + DEPFILE ${_ggml_vk_target_cpp}.d + COMMAND ${_ggml_vk_genshaders_cmd} + --glslc ${Vulkan_GLSLC_EXECUTABLE} + --source ${file_full} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + --target-cpp ${_ggml_vk_target_cpp} + DEPENDS ${file_full} + ${_ggml_vk_shaders_gen_sources} + vulkan-shaders-gen + COMMENT "Generate vulkan shaders for ${file}" + ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp}) + endforeach() else() message(WARNING "Vulkan not found") diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp index d896f1ef0b..5084a70ed4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp index 00cf2dd62f..3bcfe6908e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -6,8 +6,8 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp index 3ae8f0116c..495249d5f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp @@ -2,7 +2,7 @@ #extension GL_EXT_control_flow_attributes : require -#include "types.comp" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp index a1d4c240dd..7c12877671 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index dc53a401e0..c81b84452e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_control_flow_attributes : enable -#include "types.comp" +#include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp index 1e5cb8dae4..653431895e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp index 9ee2f1fae2..e404698382 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index 6567a8c54c..ca1a3ac25b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" #extension GL_EXT_control_flow_attributes : require diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp index 938c74da50..70a301488e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 44a64ddc80..0367e80bbf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -11,7 +11,7 @@ # extension GL_KHR_shader_subgroup_shuffle : enable #endif -#include "types.comp" +#include "types.glsl" // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j layout(binding = 0) readonly buffer A { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp index b17b4e83ee..5217e18bdd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin] layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin] diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index f476a2e3dd..9f8bfd3c18 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 978d430030..06df509525 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -1,8 +1,8 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" -#include "dequant_funcs.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" +#include "dequant_funcs.glsl" #if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) // 16 invocations needed for init_iq_shmem diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index bc2e1f2df3..b8c40eec10 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,7 +1,7 @@ #version 450 -#include "rte.comp" -#include "types.comp" +#include "rte.glsl" +#include "types.glsl" #if defined(SET_ROWS) && QUANT_K == 1 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -14,7 +14,7 @@ const uint BLOCK_SIZE = 32; layout (binding = 0) readonly buffer S {float data_s[];}; #if defined(SET_ROWS) -#include "generic_binary_head.comp" +#include "generic_binary_head.glsl" layout (binding = 1) readonly buffer C {B_TYPE data_i[];}; layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; @@ -25,7 +25,7 @@ layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; #endif #else -#include "generic_unary_head.comp" +#include "generic_unary_head.glsl" layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp index 0b8d02f58f..db6865db98 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp index d9345497c7..e75df66756 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp @@ -2,8 +2,8 @@ #extension GL_EXT_control_flow_attributes : enable -#include "types.comp" -#include "generic_head.comp" +#include "types.glsl" +#include "generic_head.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp index a4d3fca556..765afffa80 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 73fef4fa65..0d98f5a9d6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require #endif -#include "types.comp" +#include "types.glsl" #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 706540fd85..6a5bb4574d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,5 +1,5 @@ -#include "types.comp" +#include "types.glsl" layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl similarity index 91% rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl index 8d806435b7..addceafade 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl @@ -10,4 +10,4 @@ layout (push_constant) uniform parameter uint nel; } p; -#include "types.comp" +#include "types.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp index b604c1881a..637c95fa35 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp index fd1e4e30d2..d1cbc5e9d0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp index 127c7b6424..78490162cd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp index a08331c40d..9b8ce0a7f8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp index 0ae9acd02a..aacf07d0f8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp index e4f42be94c..f2c20b1d2c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp index 19c7fdeefc..671c1f4a0d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp index 46d9ad15eb..8f7833eab2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp index f930852a48..a313699775 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp index ee496e9d56..ffba5a77dd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index d4e4e6bae6..58dc2e5dfd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp index 3661f771c7..0c90be8b4e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp index 4081853272..b92b292135 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp index 2f27eee686..6b63cbe583 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 1370db3654..8b7be557e9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp index b20b805292..f1b0bac872 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp index dc59fe3b77..c495b31f17 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 3f3b839e11..6bc04670fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp index 9cf34256e8..c8d6fcb49f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp index bd1344a88d..10844ddf78 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp index 26d8bc22ad..9cef8a8ec3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -10,7 +10,7 @@ layout (push_constant) uniform parameter uint n_past; } p; -#include "types.comp" +#include "types.glsl" layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp index 9fb69c6c15..572472f8a9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp index a3941372a7..b69d4ddb09 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -1,8 +1,8 @@ #version 450 -#include "rte.comp" -#include "generic_head.comp" -#include "types.comp" +#include "rte.glsl" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/bfloat16.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/bfloat16.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/integer_dot.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/integer_dot.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e42475026a..62acbf107a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -8,8 +8,8 @@ #extension GL_KHR_shader_subgroup_shuffle : enable -#include "types.comp" -#include "flash_attn_base.comp" +#include "types.glsl" +#include "flash_attn_base.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 0507df2d89..2066a05b34 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -10,8 +10,8 @@ #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable -#include "types.comp" -#include "flash_attn_base.comp" +#include "types.glsl" +#include "flash_attn_base.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index a65553a481..910da1ab0c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -16,9 +16,9 @@ #extension GL_KHR_shader_subgroup_vote : enable #extension GL_EXT_null_initializer : enable -#include "types.comp" -#include "dequant_funcs_cm2.comp" -#include "flash_attn_base.comp" +#include "types.glsl" +#include "dequant_funcs_cm2.glsl" +#include "flash_attn_base.glsl" layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp index f4268ed24f..e017b50368 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -10,4 +10,4 @@ float op(float a, float b) { return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp index cbd4cb36bf..759a1848fa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation // ref: https://www.johndcook.com/blog/python_erf/ @@ -24,4 +24,4 @@ float op(float a, float b) { return 0.5f * a * (1.0f + erf_approx) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp index 3a2a6897bf..c4032ab21d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" const float GELU_QUICK_COEF = -1.702f; @@ -8,4 +8,4 @@ float op(float a, float b) { return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp index 4cc7a68ca1..a95c2525c8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp index 5fd5a5e703..58375aba09 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp index e6e6fcfd20..bfdfe2182d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl similarity index 97% rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index 750e785753..99595fc688 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -1,8 +1,8 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.comp" -#include "utils.comp" +#include "rte.glsl" +#include "utils.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index 7ef75cd7a4..76d83041ce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index 339f905fc7..9dba437edb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -2,9 +2,9 @@ #extension GL_EXT_control_flow_attributes : enable -#include "types.comp" -#include "generic_binary_head.comp" -#include "dequant_funcs.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" +#include "dequant_funcs.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl similarity index 95% rename from ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 51d70869d9..2168989340 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -1,6 +1,6 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.comp" +#include "rte.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp index b6a0d56454..bdf97dbb5d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp index 1da252cc66..b4dbdf3141 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp index 3afc588274..1ec315915e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index f0f19a019c..1827d647a2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -3,9 +3,8 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.comp" - -#include "types.comp" +#include "rte.glsl" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp index 9faa636ac2..4bf8b4ca04 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -4,9 +4,8 @@ #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "rte.comp" - -#include "types.comp" +#include "rte.glsl" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index deba8c3985..83ef2f8795 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp index d90a99aea5..b281e855cb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp index 43de19df8e..02ef1eace1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index bb429dd594..9a03925cfd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index f761391eae..450dee0408 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -11,7 +11,7 @@ #define EXPERT_COUNT 8 #endif -#include "types.comp" +#include "types.glsl" #ifndef MMQ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -32,7 +32,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 3) readonly buffer IDS {int data_ids[];}; #endif -#include "dequant_funcs.comp" +#include "dequant_funcs.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp index e4acbd4f96..4cb292380c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp index 309da0991a..0b74b33212 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp index 8d01536fa6..e424af12c5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp index c496043241..0cd906dbbf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp index 94d4b92e1e..71bd72d17e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp index f021e40476..a4b9ab1f94 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp index 3fe9dc3a41..40849c691f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 423ceb8a3d..03ed25d3bf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index e91724a28d..528f224d86 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index f9cde06488..21d07d2e50 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 6c84ef3cde..9e46c89a11 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index d53d9ee0a2..d7a7f6426e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 8fb314fa0a..64293f6eca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -6,13 +6,13 @@ #define MMQ #define B_TYPE block_q8_1_x4 -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #define K_PER_ITER 8 -#include "mul_mmq_funcs.comp" +#include "mul_mmq_funcs.glsl" uint a_offset, b_offset, d_offset; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 3cb24412d5..85400ac5fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -28,7 +28,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif -#include "types.comp" +#include "types.glsl" #ifndef LOAD_VEC_A #define LOAD_VEC_A 1 @@ -195,7 +195,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif -#include "mul_mm_funcs.comp" +#include "mul_mm_funcs.glsl" void main() { #ifdef NEEDS_INIT_IQ_SHMEM diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 0e3065e014..2e04baa44e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -18,8 +18,8 @@ #extension GL_EXT_bfloat16 : enable #endif -#include "types.comp" -#include "utils.comp" +#include "types.glsl" +#include "utils.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -71,7 +71,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #if QUANT_K > 1 #define DECODEFUNCA , dequantFuncA -#include "dequant_funcs_cm2.comp" +#include "dequant_funcs_cm2.glsl" #else #define DECODEFUNCA diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index f36add62a9..b5d761c0ba 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -20,7 +20,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif -#include "types.comp" +#include "types.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -110,7 +110,7 @@ shared u16vec2 row_ids[4096]; shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif -#include "mul_mmq_funcs.comp" +#include "mul_mmq_funcs.glsl" void main() { #ifdef NEEDS_INIT_IQ_SHMEM diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index cdfb230f4e..fe71eb131c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require -#include "types.comp" +#include "types.glsl" // Each iqs value maps to a 32-bit integer diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 854a2ad818..1e8f694a72 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -8,9 +8,9 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "rte.comp" -#include "types.comp" -#include "utils.comp" +#include "rte.glsl" +#include "types.glsl" +#include "utils.glsl" layout (push_constant) uniform parameter2 { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp index 6627a50bd9..cc3ea0b760 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp index e0214fe764..1f05f922cc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp index 6426dedee5..1251f9cc64 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.comp" +#include "generic_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index 0d81220c71..f3c8176872 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp index b6124411a0..d9d7166e36 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" #extension GL_EXT_shader_16bit_storage : require diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp index 145c9fbdc9..0f3c6ca871 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -17,7 +17,7 @@ layout (push_constant) uniform parameter uint ne; } p; -#include "types.comp" +#include "types.glsl" layout(constant_id = 0) const uint GROUP_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp index 0073d8f766..86be2669a1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -1,9 +1,9 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" float op(float a, float b) { return max(a, 0.0f) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp index 4f806270c7..5725cef236 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp index 1568b141de..8f4b9a8684 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp index d86279934f..87df782944 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index 41197e9301..d5b211ffaa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_binary_head.comp" -#include "types.comp" +#include "generic_binary_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp index 76009f3df6..87707fc149 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp index ba4677c293..4618b2c7e8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_binary_head.comp" -#include "types.comp" +#include "generic_binary_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #extension GL_KHR_shader_subgroup_arithmetic : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp b/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp index b9abe8dedc..68fbd0c7be 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl similarity index 97% rename from ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index 00e203e73b..50fc1f1e2d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -1,8 +1,8 @@ -#include "types.comp" +#include "types.glsl" #extension GL_EXT_shader_16bit_storage : require -#include "rte.comp" +#include "rte.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 5808710ccf..111286b498 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 366a7b1c47..06e095bef9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 9643bca96a..6ba9575409 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index cedacc4d14..d37d1c1043 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/rte.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp index f10b0a02b5..35ec726a01 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" const uint num_threads = 128; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp index 5c9e5c3503..32298d43c6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp index 4d36f88e08..7d1cc6f45a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp index f9afa9b13c..e5d949ff18 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp index d7c15a1695..61f17b2f00 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index 5f20a1ee7d..dca0d896bc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -23,7 +23,7 @@ layout (push_constant) uniform parameter uint has_sinks; } p; -#include "types.comp" +#include "types.glsl" layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp index 144ea58e6f..d873332eeb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -2,8 +2,8 @@ #extension GL_EXT_control_flow_attributes : enable -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp index 4bc697b9b9..70daad6c5d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp index ef43598baf..4eb56afcb1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp index 72353cc329..bc924b520a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp @@ -2,8 +2,8 @@ #extension GL_EXT_shader_16bit_storage : require -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp index 759204afaf..bc22aa7bd7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp index a28e7c6cc8..4fee433a12 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -1,9 +1,9 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" float op(float a, float b) { return a / (1.0f + exp(-a)) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp index 970750eec0..bda9dea21c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" float op(float a, float b) { float xi = min(a, p.limit); @@ -11,4 +11,4 @@ float op(float a, float b) { return out_glu; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp index 8a6f868f58..7b5eb413bf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp index ce8e09442d..1605565457 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -9,7 +9,7 @@ layout (push_constant) uniform parameter uint max_period; } p; -#include "types.comp" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 256 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/types.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/types.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 74771def0f..154a2172d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -9,7 +9,7 @@ layout (push_constant) uniform parameter float sf0; float sf1; float sf2; float sf3; } p; -#include "types.comp" +#include "types.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp b/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/utils.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 84bb9df9a0..f0cc24ff31 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -1,5 +1,3 @@ - - #include #include #include @@ -22,6 +20,7 @@ #include #ifdef _WIN32 + #define NOMINMAX #include #include // For _mkdir on Windows #else @@ -34,13 +33,13 @@ std::mutex lock; std::vector> shader_fnames; +std::locale c_locale("C"); std::string GLSLC = "glslc"; -std::string input_dir = "vulkan-shaders"; +std::string input_filepath = ""; std::string output_dir = "/tmp"; -std::string target_hpp = "ggml-vulkan-shaders.hpp"; -std::string target_cpp = "ggml-vulkan-shaders.cpp"; -bool no_clean = false; +std::string target_hpp = ""; +std::string target_cpp = ""; const std::vector type_names = { "f32", @@ -75,6 +74,7 @@ enum MatMulIdType { }; namespace { + void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { #ifdef _WIN32 HANDLE stdout_read, stdout_write; @@ -232,16 +232,87 @@ std::string basename(const std::string &path) { return path.substr(path.find_last_of("/\\") + 1); } +std::stringstream make_generic_stringstream() { + std::stringstream ss; + ss.imbue(c_locale); + return ss; +} + +std::string read_binary_file(const std::string& path, bool may_not_exist = false) { + FILE* f = fopen(path.c_str(), "rb"); + if (!f) { + if (!may_not_exist) { + std::cerr << "Error opening file: " << path << " (" << strerror(errno) << ")\n"; + } + return {}; + } + + fseek(f, 0, SEEK_END); + size_t size = ftell(f); + fseek(f, 0, SEEK_SET); + + std::string data(size, '\0'); + size_t read_size = fread(data.data(), 1, size, f); + fclose(f); + if (read_size != size) { + std::cerr << "Error reading file: " << path << " (" << strerror(errno) << ")\n"; + return {}; + } + + return data; +} + +void write_binary_file(const std::string& path, const std::string& content) { + FILE* f = fopen(path.c_str(), "wb"); + if (!f) { + std::cerr << "Error opening file for writing: " << path << " (" << strerror(errno) << ")\n"; + return; + } + + size_t write_size = fwrite(content.data(), 1, content.size(), f); + fclose(f); + if (write_size != content.size()) { + std::cerr << "Error writing file: " << path << " (" << strerror(errno) << ")\n"; + return; + } +} + +void write_file_if_changed(const std::string& path, const std::string& content) { + std::string existing = read_binary_file(path, true); + if (existing != content) { + write_binary_file(path, content); + } +} + + // variables to track number of compiles in progress static uint32_t compile_count = 0; static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; +static bool generate_dep_file = true; -void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); - std::string out_fname = join_paths(output_dir, name + ".spv"); - std::string in_path = join_paths(input_dir, in_fname); +void decrement_compile_count(uint32_t * count) { + if (count) { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + compile_count_cond.notify_all(); + } +} +using compile_count_guard = std::unique_ptr; + +compile_count_guard acquire_compile_slot() { + // wait until fewer than N compiles are in progress. + // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. + uint32_t N = std::max(1u, std::min(16u, std::thread::hardware_concurrency())); + std::unique_lock guard(compile_count_mutex); + compile_count_cond.wait(guard, [N] { return compile_count < N; }); + compile_count++; + return compile_count_guard(&compile_count, &decrement_compile_count); +} + +void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map defines, bool coopmat, bool dep_file, compile_count_guard slot) { std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 @@ -249,11 +320,17 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O"; #ifdef _WIN32 - std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""}; #else - std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname}; + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_path}; #endif + if (dep_file) { + cmd.push_back("-MD"); + cmd.push_back("-MF"); + cmd.push_back("\"" + target_cpp + ".d\""); + } + #ifdef GGML_VULKAN_SHADER_DEBUG_INFO cmd.push_back("-g"); #endif @@ -281,17 +358,23 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c return; } + if (dep_file) { + // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt + std::string dep = read_binary_file(target_cpp + ".d", true); + if (!dep.empty()) { + size_t pos = dep.find(out_path); + if (pos != std::string::npos) { + dep.replace(pos, out_path.length(), target_cpp); + } + write_binary_file(target_cpp + ".d", dep); + } + } + std::lock_guard guard(lock); - shader_fnames.push_back(std::make_pair(name, out_fname)); + shader_fnames.push_back(std::make_pair(name, out_path)); } catch (const std::exception& e) { std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; } - { - std::lock_guard guard(compile_count_mutex); - assert(compile_count > 0); - compile_count--; - } - compile_count_cond.notify_all(); } std::map merge_maps(const std::map& a, const std::map& b) { @@ -301,18 +384,24 @@ std::map merge_maps(const std::map> compiles; -void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - { - // wait until fewer than N compiles are in progress. - // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. - uint32_t N = 16; - std::unique_lock guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); - } - compile_count++; +void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string out_path = join_paths(output_dir, name + ".spv"); + + if (input_filepath == "") { + // No input source to compile, only generate header for all shaders + shader_fnames.push_back(std::pair(name, out_path)); + return; + } else if (basename(input_filepath) != source) { + // Only compile shader variants matching the input filename + return; } - compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); + + compile_count_guard slot = acquire_compile_slot(); + compiles.push_back(std::async( + string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot))); + // Don't write the same dep file from multiple processes + generate_dep_file = false; } void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { @@ -485,7 +574,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } void process_shaders() { - std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; std::map base_dict = {{"FLOAT_TYPE", "float"}}; // matmul @@ -837,11 +925,11 @@ void process_shaders() { } void write_output_files() { - FILE* hdr = fopen(target_hpp.c_str(), "w"); - FILE* src = fopen(target_cpp.c_str(), "w"); + std::stringstream hdr = make_generic_stringstream(); + std::stringstream src = make_generic_stringstream(); - fprintf(hdr, "#include \n\n"); - fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + hdr << "#include \n\n"; + src << "#include \"" << basename(target_hpp) << "\"\n\n"; std::sort(shader_fnames.begin(), shader_fnames.end()); for (const auto& pair : shader_fnames) { @@ -853,91 +941,85 @@ void write_output_files() { const std::string& path = pair.second; #endif - FILE* spv = fopen(path.c_str(), "rb"); - if (!spv) { - std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; - continue; - } + hdr << "extern const uint64_t " << name << "_len;\n"; + hdr << "extern const unsigned char " << name << "_data[];\n\n"; - fseek(spv, 0, SEEK_END); - size_t size = ftell(spv); - fseek(spv, 0, SEEK_SET); + if (input_filepath != "") { + std::string data = read_binary_file(path); + if (data.empty()) { + continue; + } - std::vector data(size); - size_t read_size = fread(data.data(), 1, size, spv); - fclose(spv); - if (read_size != size) { - std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; - continue; - } - - fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); - fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); - - fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); - for (size_t i = 0; i < size; ++i) { - fprintf(src, "0x%02x,", data[i]); - if ((i + 1) % 12 == 0) fprintf(src, "\n"); - } - fprintf(src, "\n};\n\n"); - - if (!no_clean) { - std::remove(path.c_str()); + src << "const uint64_t " << name << "_len = " << data.size() << ";\n"; + src << "const unsigned char " << name << "_data[" << data.size() << "] = {\n" << std::hex; + auto bytes = reinterpret_cast(data.data()); + for (size_t i = 0; i < data.size(); ++i) { + src << "0x" << static_cast(bytes[i]) << ","; + if ((i + 1) % 12 == 0) src << "\n"; + } + src << std::dec << "\n};\n\n"; } } std::string suffixes[2] = {"_f32", "_f16"}; - for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) { - fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op); - fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op); - std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; - std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = "; + for (auto op : {"add", "sub", "mul", "div", "add_rms"}) { + hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; + hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; + + std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp"; + if (basename(input_filepath) != op_file) { + continue; + } + std::stringstream data = make_generic_stringstream(); + std::stringstream len = make_generic_stringstream(); + data << "const void * " << op << "_data[2][2][2][2] = "; + len << "const uint64_t " << op << "_len[2][2][2][2] = "; for (uint32_t t0 = 0; t0 < 2; ++t0) { if (t0 == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } for (uint32_t t1 = 0; t1 < 2; ++t1) { if (t1 == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } for (uint32_t t2 = 0; t2 < 2; ++t2) { if (t2 == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } for (uint32_t rte = 0; rte < 2; ++rte) { if (rte == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } - data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); - len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); - data += "_data,"; - len += "_len,"; + data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + data << "_data,"; + len << "_len,"; if (rte == 1) { - data += "}, "; - len += "}, "; + data << "}, "; + len << "}, "; } } if (t2 == 1) { - data += "}, "; - len += "}, "; + data << "}, "; + len << "}, "; } } if (t1 == 1) { - data += "}, "; - len += "}, "; + data << "}, "; + len << "}, "; } } if (t0 == 1) { - data += "};\n"; - len += "};\n"; + data << "};\n"; + len << "};\n"; } } - fputs(data.c_str(), src); - fputs(len.c_str(), src); + src << data.str(); + src << len.str(); } std::vector btypes = {"f16", "f32"}; @@ -951,20 +1033,25 @@ void write_output_files() { if (btype == "q8_1" && !is_legacy_quant(tname)) { continue; } - fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str()); - fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str()); - std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n"; - std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n"; - fputs(data.c_str(), src); - fputs(len.c_str(), src); + hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; + hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n"; + if (basename(input_filepath) == "mul_mat_vec.comp") { + src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; + src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; + } } } - fclose(hdr); - fclose(src); -} + if (input_filepath == "") { + write_file_if_changed(target_hpp, hdr.str()); + } + if (target_cpp != "") { + write_binary_file(target_cpp, src.str()); + } } +} // namespace + int main(int argc, char** argv) { std::map args; for (int i = 1; i < argc; ++i) { @@ -982,8 +1069,8 @@ int main(int argc, char** argv) { if (args.find("--glslc") != args.end()) { GLSLC = args["--glslc"]; // Path to glslc } - if (args.find("--input-dir") != args.end()) { - input_dir = args["--input-dir"]; // Directory containing shader sources + if (args.find("--source") != args.end()) { + input_filepath = args["--source"]; // The shader source file to compile } if (args.find("--output-dir") != args.end()) { output_dir = args["--output-dir"]; // Directory for containing SPIR-V output @@ -994,14 +1081,6 @@ int main(int argc, char** argv) { if (args.find("--target-cpp") != args.end()) { target_cpp = args["--target-cpp"]; // Path to generated cpp file } - if (args.find("--no-clean") != args.end()) { - no_clean = true; // Keep temporary SPIR-V files in output-dir after build - } - - if (!directory_exists(input_dir)) { - std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; - return EXIT_FAILURE; - } if (!directory_exists(output_dir)) { if (!create_directory(output_dir)) { diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 78a985a4d1..c6a95d5151 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -50,5 +50,13 @@ if (GGML_WEBGPU_DEBUG) target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) endif() +if (GGML_WEBGPU_CPU_PROFILE) + target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_CPU_PROFILE=1) +endif() + +if (GGML_WEBGPU_GPU_PROFILE) + target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_GPU_PROFILE=1) +endif() + target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR}) target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET}) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index de68c5689b..05e16cd432 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -11,10 +11,12 @@ #include +#include #include #include #include #include +#include #include #include @@ -25,12 +27,44 @@ # define WEBGPU_LOG_DEBUG(msg) ((void) 0) #endif // GGML_WEBGPU_DEBUG +#ifdef GGML_WEBGPU_CPU_PROFILE +// total timing (aggregated) +# define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now(); + +# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) \ + auto cpu_total_end_##id = std::chrono::high_resolution_clock::now(); \ + double cpu_total_time_##id = \ + std::chrono::duration(cpu_total_end_##id - cpu_total_start_##id).count(); \ + (ctx)->cpu_time_ms[#id] += cpu_total_time_##id; + +// fine-grained timing (not included in totals) +# define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now(); + +# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) \ + auto cpu_detail_end_##id = std::chrono::high_resolution_clock::now(); \ + double cpu_detail_time_##id = \ + std::chrono::duration(cpu_detail_end_##id - cpu_detail_start_##id).count(); \ + (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id; +#else +# define WEBGPU_CPU_PROFILE_TOTAL_START(id) +# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) +# define WEBGPU_CPU_PROFILE_DETAIL_START(id) +# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) +#endif // GGML_WEBGPU_CPU_PROFILE + +#ifdef GGML_WEBGPU_GPU_PROFILE +# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24 +# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps +#endif + /* Constants */ -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16 -#define WEBGPU_WAIT_ANY_BATCH_SIZE 64 -#define WEBGPU_MUL_MAT_WG_SIZE 64 -#define WEBGPU_NUM_PARAM_BUFS 100 +#define WEBGPU_MUL_MAT_WG_SIZE 256 +#define WEBGPU_NUM_PARAM_BUFS 32u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u +#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 +// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool +#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 @@ -66,6 +100,11 @@ struct webgpu_pool_bufs { wgpu::Buffer dev_buf; }; +// The futures to wait on for a single queue submission +struct webgpu_submission_futures { + std::vector futures; +}; + // Holds a pool of parameter buffers for WebGPU operations struct webgpu_buf_pool { std::vector free; @@ -112,6 +151,83 @@ struct webgpu_buf_pool { } }; +#ifdef GGML_WEBGPU_GPU_PROFILE +struct webgpu_gpu_profile_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + wgpu::QuerySet query_set; +}; + +// Holds a pool of parameter buffers for WebGPU operations +struct webgpu_gpu_profile_buf_pool { + std::vector free; + + std::mutex mutex; + + std::condition_variable cv; + + void init(wgpu::Device device, + int num_bufs, + size_t buf_size, + wgpu::BufferUsage dev_buf_usage, + wgpu::BufferUsage host_buf_usage) { + for (int i = 0; i < num_bufs; i++) { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); + // Create a query set for 2 timestamps + wgpu::QuerySetDescriptor ts_query_set_desc = {}; + + ts_query_set_desc.type = wgpu::QueryType::Timestamp; + ts_query_set_desc.count = 2; + wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); + + free.push_back({ host_buf, dev_buf, ts_query_set }); + } + } + + webgpu_gpu_profile_bufs alloc_bufs() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !free.empty(); }); + webgpu_gpu_profile_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + void free_bufs(std::vector bufs) { + std::lock_guard lock(mutex); + free.insert(free.end(), bufs.begin(), bufs.end()); + cv.notify_all(); + } + + void cleanup() { + std::lock_guard lock(mutex); + for (auto & bufs : free) { + bufs.host_buf.Destroy(); + bufs.dev_buf.Destroy(); + bufs.query_set.Destroy(); + } + free.clear(); + } +}; +#endif + +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; +}; + +struct webgpu_command { + wgpu::CommandBuffer commands; + webgpu_pool_bufs params_bufs; + std::optional set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_gpu_profile_bufs timestamp_query_bufs; + std::string pipeline_name; +#endif +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { wgpu::Instance instance; @@ -125,45 +241,50 @@ struct webgpu_context_struct { uint32_t max_wg_size_x; std::recursive_mutex mutex; + std::atomic_uint inflight_threads = 0; webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; - wgpu::ComputePipeline memset_pipeline; - wgpu::ComputePipeline mul_mat_pipeline[30][2]; - wgpu::ComputePipeline set_rows_pipeline; - wgpu::ComputePipeline get_rows_pipeline[30]; - wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; - wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type - wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace - wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace - wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split - wgpu::ComputePipeline scale_pipeline[2]; // inplace - wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace + webgpu_pipeline memset_pipeline; + webgpu_pipeline mul_mat_pipeline[30][2]; + webgpu_pipeline set_rows_pipeline; + webgpu_pipeline get_rows_pipeline[30]; + webgpu_pipeline get_rows_f32_no_vec_pipeline; + webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type + webgpu_pipeline add_pipeline[2][2]; // type, inplace + webgpu_pipeline sub_pipeline[2][2]; // type, inplace + webgpu_pipeline mul_pipeline[2][2]; // type, inplace + webgpu_pipeline div_pipeline[2][2]; // type, inplace + webgpu_pipeline rms_norm_pipeline[2]; // inplace + webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace + webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split + webgpu_pipeline scale_pipeline[2]; // inplace + webgpu_pipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace size_t memset_bytes_per_thread; // Staging buffer for reading data from the GPU wgpu::Buffer get_tensor_staging_buf; - // Command buffers which need to be submitted - std::vector staged_command_bufs; - - // Parameter buffers associated with the staged command buffers - std::vector staged_param_bufs; - // Buffers associated with set_rows operations, used to store potential errors - std::vector staged_set_row_error_bufs; - - std::vector callback_futures; - #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; #endif + +#ifdef GGML_WEBGPU_CPU_PROFILE + // Profiling: labeled CPU time in ms (total) + std::unordered_map cpu_time_ms; + // Profiling: detailed CPU time in ms + std::unordered_map cpu_detail_ms; +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + // Profiling: pool of timestamp query buffers (one per operation) + webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; +#endif }; typedef std::shared_ptr webgpu_context; @@ -199,12 +320,10 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ static void ggml_webgpu_create_pipeline(wgpu::Device & device, - wgpu::ComputePipeline & pipeline, + webgpu_pipeline & pipeline, const char * shader_code, const char * label, const std::vector & constants = {}) { - WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()"); - wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code; @@ -222,7 +341,7 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline_desc.compute.constants = constants.data(); pipeline_desc.compute.constantCount = constants.size(); } - pipeline = device.CreateComputePipeline(&pipeline_desc); + pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } static void ggml_webgpu_create_buffer(wgpu::Device & device, @@ -230,8 +349,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, size_t size, wgpu::BufferUsage usage, const char * label) { - WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()"); - wgpu::BufferDescriptor buffer_desc; buffer_desc.size = size; buffer_desc.usage = usage; @@ -247,83 +364,35 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ // Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { - std::lock_guard lock(ctx->mutex); - if (ctx->callback_futures.empty()) { - // no existing callbacks, wait on queue submission - ctx->instance.WaitAny( - ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); - } else { - // WebGPU implementations may limit the number of futures that can be waited on at once, - // so wait in batches (64 is what Dawn supports). - for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) { - size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size()); - ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX); +static void ggml_backend_webgpu_wait(webgpu_context & ctx, + std::vector & futures, + bool block = true) { + // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, + // inflight_max may be 0, meaning that we must wait on all futures. + uint64_t timeout_ms = block ? UINT64_MAX : 0; + uint inflight_threads = ctx->inflight_threads; + uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); + while (futures.size() >= inflight_max && futures.size() > 0) { + ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); + futures.erase(futures.begin()); + } + size_t i = 0; + while (i < futures.size()) { + auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); + switch (waitStatus) { + case wgpu::WaitStatus::Success: + futures.erase(futures.begin() + i); + break; + case wgpu::WaitStatus::TimedOut: + i++; + break; + case wgpu::WaitStatus::Error: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + break; + default: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + break; } - ctx->callback_futures.clear(); - } -} - -static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { - std::lock_guard lock(ctx->mutex); - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()"); - if (ctx->staged_command_bufs.empty()) { - // Nothing to submit - return; - } - ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data()); - - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. - if (ctx->staged_set_row_error_bufs.size() > 0) { - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (auto & error_bufs : ctx->staged_set_row_error_bufs) { - // Copy the error buffer to the host buffer - encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize()); - } - wgpu::CommandBuffer commands = encoder.Finish(); - ctx->queue.Submit(1, &commands); - } - - ctx->staged_command_bufs.clear(); - std::vector staged_param_bufs = std::move(ctx->staged_param_bufs); - std::vector staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs); - - // Free the staged parameter buffers once the submission completes - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); - } - // Free the staged buffers - ctx->param_buf_pool.free_bufs(staged_param_bufs); - }); - ctx->callback_futures.push_back({ p_f }); - - // Check for errrors in SET_ROWS operations - for (auto & error_bufs : staged_set_row_error_bufs) { - wgpu::Future f = error_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); - } else { - const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->set_rows_error_buf_pool.free_bufs({ error_bufs }); - } - }); - ctx->callback_futures.push_back({ f }); } } @@ -347,7 +416,6 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and // debug statements in the shader, and then call this function after encoding the commands and submitting them. static void ggml_backend_webgpu_debug(webgpu_context & ctx) { - ggml_backend_webgpu_submit_queue(ctx); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); @@ -364,13 +432,85 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { } #endif -static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx, - wgpu::ComputePipeline & pipeline, - std::vector params, - std::vector bind_group_entries, - uint32_t wg_x, - const char * bind_group_label = nullptr, - bool submit_and_wait = false) { +static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector commands) { + std::vector command_buffers; + std::vector params_bufs; + std::vector set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + std::vector> pipeline_name_and_ts_bufs; +#endif + + for (const auto & command : commands) { + command_buffers.push_back(command.commands); + params_bufs.push_back(command.params_bufs); + if (command.set_rows_error_bufs) { + set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); + } + } + ctx->queue.Submit(command_buffers.size(), command_buffers.data()); + + std::vector futures; + + wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); + } + // Free the staged buffers + ctx->param_buf_pool.free_bufs({ params_bufs }); + }); + futures.push_back({ p_f }); + + for (const auto & bufs : set_rows_error_bufs) { + wgpu::Future f = bufs.host_buf.MapAsync( + wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); + } else { + const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->set_rows_error_buf_pool.free_bufs({ bufs }); + } + }); + futures.push_back({ f }); + } + +#ifdef GGML_WEBGPU_GPU_PROFILE + for (const auto & command : commands) { + auto label = command.pipeline_name; + auto ts_bufs = command.timestamp_query_bufs; + + wgpu::Future f = ts_bufs.host_buf.MapAsync( + wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); + } else { + const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange(); + // WebGPU timestamps are in ns; convert to ms + double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; + ctx->shader_gpu_time_ms[label] += elapsed_ms; + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); + } + }); + futures.push_back({ f }); + } +#endif + return { futures }; +} + +static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx, + webgpu_pipeline & pipeline, + std::vector params, + std::vector bind_group_entries, + uint32_t wg_x, + std::optional set_rows_error_bufs = std::nullopt) { webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); @@ -388,44 +528,58 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & .size = params_bufs.dev_buf.GetSize() }); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipeline.GetBindGroupLayout(0); + bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0); bind_group_desc.entryCount = bind_group_entries.size(); bind_group_desc.entries = bind_group_entries.data(); - if (bind_group_label) { - bind_group_desc.label = bind_group_label; - } + bind_group_desc.label = pipeline.name.c_str(); wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); + +#ifdef GGML_WEBGPU_GPU_PROFILE + // --- Profiling: GPU timestamp queries --- + // Allocate a timestamp query buffer (2 timestamps: start/end) + webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); + if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + ts_bufs.host_buf.Unmap(); + } + + wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, + .beginningOfPassWriteIndex = 0, + .endOfPassWriteIndex = 1 }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); +#else wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipeline); +#endif + pass.SetPipeline(pipeline.pipeline); pass.SetBindGroup(0, bind_group); pass.DispatchWorkgroups(wg_x, 1, 1); pass.End(); - wgpu::CommandBuffer commands = encoder.Finish(); - if (submit_and_wait) { - // Submit and wait immediately - ctx->queue.Submit(1, &commands); - ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); - } - ctx->param_buf_pool.free_bufs({ params_bufs }); - }), - UINT64_MAX); - } else { - // Lock the context mutex when pushing to the staging vectors. - std::lock_guard lock(ctx->mutex); - // Enqueue commands and only submit if we have enough staged commands - ctx->staged_command_bufs.push_back(commands); - ctx->staged_param_bufs.push_back(params_bufs); - if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - ggml_backend_webgpu_submit_queue(ctx); - } + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Resolve the query set into the device buffer + encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); + encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); +#endif + + // If there are SET_ROWS operations in this submission, copy their error buffers to the host. + if (set_rows_error_bufs) { + encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, + set_rows_error_bufs->host_buf.GetSize()); } + + wgpu::CommandBuffer commands = encoder.Finish(); + webgpu_command result = {}; + result.commands = commands; + result.params_bufs = params_bufs; + result.set_rows_error_bufs = set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + result.timestamp_query_bufs = ts_bufs; + result.pipeline_name = pipeline.name; +#endif + return result; } static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, @@ -439,7 +593,10 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, }; size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread; uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true); + + webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipeline, params, entries, wg_x); + std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }) }; + ggml_backend_webgpu_wait(ctx, futures); } /** End WebGPU Actions */ @@ -455,8 +612,48 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); - // TODO: cleanup +#ifdef GGML_WEBGPU_CPU_PROFILE + std::cout << "\n[ggml_webgpu cpu profiling summary]\n"; + double total_cpu = 0.0; + for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + total_cpu += kv.second; + } + std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n"; + std::cout << "ggml_webgpu: cpu breakdown:\n"; + for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + } + if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) { + std::cout << "ggml_webgpu: cpu detailed breakdown:\n"; + } + for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) { + double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + } +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; + double total_gpu = 0.0; + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + total_gpu += kv.second; + } + std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; + std::cout << "\nggml_webgpu: gpu breakdown:\n"; + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + } +#endif + +#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE) + std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n"; +#endif + +#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE) GGML_UNUSED(ctx); +#endif } static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -489,7 +686,7 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); } -static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { @@ -518,14 +715,16 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x, - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x); } -static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { - return; + return std::nullopt; } webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); @@ -568,13 +767,13 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; - std::lock_guard lock(ctx->mutex); - ctx->staged_set_row_error_bufs.push_back(error_bufs); - - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs); } -static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), @@ -609,14 +808,17 @@ static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size; - wgpu::ComputePipeline pipeline = ctx->get_rows_pipeline[src->type]; + webgpu_pipeline pipeline = ctx->get_rows_pipeline[src->type]; if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) { pipeline = ctx->get_rows_f32_no_vec_pipeline; } - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), @@ -653,16 +855,15 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x, - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); } -static void ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst, - wgpu::ComputePipeline & pipeline, - bool inplace) { +static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + webgpu_pipeline & pipeline, + bool inplace) { std::vector params = { (uint32_t) ggml_nelements(dst), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -700,10 +901,10 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -735,15 +936,14 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src), - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src)); } -static void ggml_webgpu_rope(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int has_freq_factor = (src2 != nullptr); @@ -821,13 +1021,13 @@ static void ggml_webgpu_rope(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + webgpu_pipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { const int split = (src1 != nullptr); std::vector params = { @@ -874,13 +1074,13 @@ static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + webgpu_pipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -915,15 +1115,14 @@ static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tens size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x, - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x); } -static void ggml_webgpu_soft_max(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here const int has_sink = (src2 != nullptr); @@ -988,14 +1187,14 @@ static void ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, - ggml_nrows(dst), ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, + ggml_nrows(dst)); } -// Returns true if node has enqueued work into the queue, false otherwise -static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +// Returns the encoded command, or std::nullopt if the operation is a no-op +static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { - return false; + return std::nullopt; } WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); @@ -1010,60 +1209,49 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_RESHAPE: - return false; + return std::nullopt; case GGML_OP_CPY: case GGML_OP_CONT: - ggml_webgpu_cpy(ctx, src0, node); - break; + return ggml_webgpu_cpy(ctx, src0, node); case GGML_OP_SET_ROWS: - ggml_webgpu_set_rows(ctx, src0, src1, node); - break; + return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: - ggml_webgpu_get_rows(ctx, src0, src1, node); - break; + return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: - ggml_webgpu_mul_mat(ctx, src0, src1, node); - break; + return ggml_webgpu_mul_mat(ctx, src0, src1, node); case GGML_OP_ADD: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); } case GGML_OP_SUB: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); } case GGML_OP_MUL: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); } case GGML_OP_DIV: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); } case GGML_OP_RMS_NORM: - ggml_webgpu_rms_norm(ctx, src0, node); - break; + return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: - ggml_webgpu_rope(ctx, src0, src1, src2, node); - break; + return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: - ggml_webgpu_glu(ctx, src0, src1, node); - break; + return ggml_webgpu_glu(ctx, src0, src1, node); case GGML_OP_SCALE: - ggml_webgpu_scale(ctx, src0, node); - break; + return ggml_webgpu_scale(ctx, src0, node); + case GGML_OP_SOFT_MAX: + return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); default: - return false; + return std::nullopt; } - return true; } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { @@ -1072,13 +1260,35 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); webgpu_context ctx = backend_ctx->webgpu_ctx; + WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); + + ctx->inflight_threads++; + + std::vector commands; + std::vector futures; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_webgpu_encode_node(ctx, cgraph->nodes[i]); + if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + commands.push_back(*cmd); + } + // compute the batch size based on the number of inflight threads + uint inflight_threads = ctx->inflight_threads; + uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), + WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); + if (commands.size() >= batch_size) { + futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); + // Process events and check for completed submissions + ctx->instance.ProcessEvents(); + ggml_backend_webgpu_wait(ctx, futures, false); + commands.clear(); + } } - - ggml_backend_webgpu_submit_queue(ctx); - ggml_backend_webgpu_wait_on_submission(ctx); - + if (!commands.empty()) { + webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands); + futures.push_back(new_futures); + } + ggml_backend_webgpu_wait(ctx, futures); + ctx->inflight_threads--; + WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); return GGML_STATUS_SUCCESS; } @@ -1104,7 +1314,6 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* GGML Backend Buffer Interface */ static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()"); ggml_backend_webgpu_buffer_context * ctx = static_cast(buffer->context); ctx->buffer.Destroy(); } @@ -1125,6 +1334,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe return; } + WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); @@ -1135,6 +1346,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe // This is a trick to set all bytes of a u32 to the same 1 byte value. uint32_t val32 = (uint32_t) value * 0x01010101; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); + WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx); } static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -1144,6 +1356,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, size_t size) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; @@ -1166,8 +1379,17 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, remaining_size); } else { // wait for WriteBuffer to complete - ggml_backend_webgpu_wait_on_submission(webgpu_ctx); + webgpu_ctx->instance.WaitAny( + webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", + std::string(message).c_str()); + } + }), + UINT64_MAX); } + WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx); } static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, @@ -1177,7 +1399,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t size) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - + WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; wgpu::Device device = webgpu_ctx->device; @@ -1217,12 +1439,15 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, // Copy the data from the mapped range to the output buffer std::memcpy(data, mapped_range, size); webgpu_ctx->get_tensor_staging_buf.Unmap(); + WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx); } static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); + WEBGPU_CPU_PROFILE_TOTAL_START(clear); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); + WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx); } static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { @@ -1806,6 +2031,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SCALE: supports_op = op->type == GGML_TYPE_F32; break; + case GGML_OP_SOFT_MAX: + supports_op = op->type == GGML_TYPE_F32; + break; default: break; } @@ -1869,6 +2097,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); + WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device); + ggml_backend_webgpu_reg_context * reg_ctx = static_cast(reg->context); webgpu_context ctx = reg_ctx->webgpu_ctx; @@ -1895,7 +2125,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization }; - wgpu::DeviceDescriptor dev_desc; +#ifdef GGML_WEBGPU_GPU_PROFILE + required_features.push_back(wgpu::FeatureName::TimestampQuery); +#endif + + wgpu::DeviceDescriptor dev_desc; dev_desc.requiredLimits = &ctx->limits; dev_desc.requiredFeatures = required_features.data(); dev_desc.requiredFeatureCount = required_features.size(); @@ -1909,8 +2143,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t dev_desc.SetUncapturedErrorCallback( [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { GGML_UNUSED(device); - GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), - std::string(message).c_str()); + GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); }); ctx->instance.WaitAny(ctx->adapter.RequestDevice( &dev_desc, wgpu::CallbackMode::AllowSpontaneous, @@ -1932,6 +2166,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Initialize buffer pool for timestamp queries (profiling) + ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, + wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); +#endif + ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); @@ -1949,6 +2192,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_rope_pipeline(ctx); ggml_webgpu_init_glu_pipeline(ctx); ggml_webgpu_init_scale_pipeline(ctx); + ggml_webgpu_init_soft_max_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers @@ -1975,6 +2219,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t /* .reg = */ reg, /* .context = */ &device_ctx, }; + + WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx); return &device; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl index 25e2185de8..141db9b39d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -870,7 +870,7 @@ struct MulMatParams { @group(0) @binding(3) var params: MulMatParams; -@compute @workgroup_size(64) +@compute @workgroup_size(256) fn main(@builtin(global_invocation_id) global_id: vec3) { let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; if (global_id.x >= total) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index 4f72bb1c85..712b921f1a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -84,7 +84,7 @@ fn main(@builtin(workgroup_id) wid: vec3, let i2 = i / params.ne1; let i1 = i % params.ne1; let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; - let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; let elems = (params.ne0 + wg_size - 1) / wg_size; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl index 64ab576c08..c74dc4cc92 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl @@ -300,6 +300,7 @@ fn main(@builtin(workgroup_id) wid: vec3, workgroupBarrier(); } let row_max = scratch[0]; + workgroupBarrier(); var sum = 0.0f; col = lid.x; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1600405ea8..f5e5fba800 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -128,6 +128,8 @@ class Keys: ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx" ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs" EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input" + DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in" + DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -261,6 +263,7 @@ class Keys: class ClipVision: IMAGE_SIZE = "clip.vision.image_size" + PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" PATCH_SIZE = "clip.vision.patch_size" EMBEDDING_LENGTH = "clip.vision.embedding_length" FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" @@ -406,6 +409,7 @@ class MODEL_ARCH(IntEnum): SMOLLM3 = auto() GPT_OSS = auto() LFM2 = auto() + LFM2MOE = auto() DREAM = auto() SMALLTHINKER = auto() LLADA = auto() @@ -431,6 +435,8 @@ class MODEL_TENSOR(IntEnum): TOKEN_TYPES = auto() POS_EMBD = auto() OUTPUT = auto() + DENSE_2_OUT = auto() # embeddinggemma 2_Dense + DENSE_3_OUT = auto() # embeddinggemma 3_Dense OUTPUT_NORM = auto() ROPE_FREQS = auto() ROPE_FACTORS_LONG = auto() @@ -748,6 +754,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.GPT_OSS: "gpt-oss", MODEL_ARCH.LFM2: "lfm2", + MODEL_ARCH.LFM2MOE: "lfm2moe", MODEL_ARCH.DREAM: "dream", MODEL_ARCH.SMALLTHINKER: "smallthinker", MODEL_ARCH.LLADA: "llada", @@ -774,6 +781,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.POS_EMBD: "position_embd", MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense + MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 2_Dense MODEL_TENSOR.ROPE_FREQS: "rope_freqs", MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", @@ -1756,6 +1765,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_ARCH.GEMMA_EMBEDDING: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.DENSE_2_OUT, + MODEL_TENSOR.DENSE_3_OUT, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q_NORM, @@ -2697,6 +2708,29 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.OUTPUT, ], + MODEL_ARCH.LFM2MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.SHORTCONV_CONV, + MODEL_TENSOR.SHORTCONV_INPROJ, + MODEL_TENSOR.SHORTCONV_OUTPROJ, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.ATTN_NORM, # operator_norm + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.SMALLTHINKER: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 30fc1a05ec..306679e218 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -730,6 +730,10 @@ class GGUFWriter: def add_sliding_window_pattern(self, value: Sequence[bool]) -> None: self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value) + def add_dense_features_dims(self, dense:str, in_f:int, out_f:int) -> None: + self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f) + self.add_uint32(Keys.LLM.DENSE_FEAT_OUT_SIZE.format(arch=self.arch, dense=dense), out_f) + def add_logit_scale(self, value: float) -> None: self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value) @@ -1037,6 +1041,9 @@ class GGUFWriter: def add_vision_image_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value) + def add_vision_preproc_image_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value) + def add_vision_image_mean(self, values: Sequence[float]) -> None: self.add_array(Keys.ClipVision.IMAGE_MEAN, values) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 67b2741340..c05aa6cc48 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -76,7 +76,12 @@ class TensorNameMap: "lm_head", # llama4 "model.transformer.ff_out", # llada ), - + MODEL_TENSOR.DENSE_2_OUT: ( + "dense_2_out", # embeddinggemma + ), + MODEL_TENSOR.DENSE_3_OUT: ( + "dense_3_out", # embeddinggemma + ), # Output norm MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox @@ -358,6 +363,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.router", # openai-moe "model.layers.{bid}.mlp.gate.wg", # hunyuan "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker + "model.layers.{bid}.feed_forward.gate", # lfm2moe ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -367,6 +373,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_EXP_PROBS_B: ( "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1 "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe + "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe ), # Feed-forward up diff --git a/include/llama.h b/include/llama.h index 1df8f96920..14e12d7c51 100644 --- a/include/llama.h +++ b/include/llama.h @@ -296,6 +296,7 @@ extern "C" { bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) + bool no_host; // bypass host buffer allowing extra buffers to be used }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt index 56b6752ac0..6c6bea9490 100644 --- a/requirements/requirements-all.txt +++ b/requirements/requirements-all.txt @@ -14,3 +14,5 @@ -r ./requirements-tool_bench.txt -r ./requirements-gguf_editor_gui.txt + +-r ../examples/model-conversion/requirements.txt diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4fd083aa04..869e4dccf0 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -93,6 +93,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, + { LLM_ARCH_LFM2MOE, "lfm2moe" }, { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_LLADA, "llada" }, @@ -218,6 +219,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, + // sentence-transformers dense modules feature dims + { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -1070,6 +1076,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, + { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -2104,6 +2112,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, } }, + { + LLM_ARCH_LFM2MOE, + { + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, + { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, + { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + } + }, { LLM_ARCH_SMALLTHINKER, { @@ -2254,6 +2288,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -2493,6 +2529,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: return true; default: diff --git a/src/llama-arch.h b/src/llama-arch.h index bc4b04bb4e..c3ae71655b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -97,6 +97,7 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, + LLM_ARCH_LFM2MOE, LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, LLM_ARCH_LLADA, @@ -270,6 +271,12 @@ enum llm_kv { LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, + + // sentence-transformers dense layers in and out features + LLM_KV_DENSE_2_FEAT_IN, + LLM_KV_DENSE_2_FEAT_OUT, + LLM_KV_DENSE_3_FEAT_IN, + LLM_KV_DENSE_3_FEAT_OUT, }; enum llm_tensor { @@ -277,6 +284,8 @@ enum llm_tensor { LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, + LLM_TENSOR_DENSE_2_OUT, + LLM_TENSOR_DENSE_3_OUT, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 66e6c6a38f..956c4e085e 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -590,7 +590,7 @@ int32_t llm_chat_apply_template( ss << message->content << "<|end_of_text|>\n"; } if (add_ass) { - ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + ss << "<|start_of_role|>assistant<|end_of_role|>"; } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d8a8b5e647..e7526e7d0a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2346,6 +2346,12 @@ llama_context * llama_init_from_model( return nullptr; } + if (params.pooling_type != model->hparams.pooling_type) { + //user-specified pooling-type is different from the model default + LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, + model->hparams.pooling_type, params.pooling_type); + } + try { auto * ctx = new llama_context(*model, params); return ctx; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 90cd885a60..a24853c63a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +void llm_graph_context::build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const { + if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) { + return; + } + ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; + GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd"); + + cur = ggml_mul_mat(ctx0, dense_2, cur); + cur = ggml_mul_mat(ctx0, dense_3, cur); + cb(cur, "result_embd_pooled", -1); + res->t_embd_pooled = cur; + ggml_build_forward_expand(gf, cur); +} + + void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, diff --git a/src/llama-graph.h b/src/llama-graph.h index 34b984afeb..dc84b79428 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -814,6 +814,14 @@ struct llm_graph_context { ggml_tensor * cls_b, ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + + // + // dense (out) + // + + void build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const; }; // TODO: better name diff --git a/src/llama-hparams.h b/src/llama-hparams.h index f29b23eeff..4e7f73ec23 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -169,6 +169,12 @@ struct llama_hparams { uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; + // needed for sentence-transformers dense layers + uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense + uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense + uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense + uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense + // xIELU std::array xielu_alpha_n; std::array xielu_alpha_p; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 816f2d5de5..736693e174 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index cb8832a353..dfb8439e01 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index e23e74982b..d67f5a5f47 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -382,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -859,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); + if (cell_count == 0) { + return true; + } + llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4c2d481a41..a5fe5b749c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -114,6 +114,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; @@ -310,7 +311,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -331,11 +332,13 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // generally, this will be done using the first device in the list // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); - if (buft) { - buft_list.emplace_back(dev, buft); - break; + if (!no_host) { + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } } } @@ -1215,12 +1218,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(6); hparams.causal_attn = false; // embeddings do not use causal attention - hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + + GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); + GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); switch (hparams.n_layer) { case 24: type = LLM_TYPE_0_3B; break; @@ -1993,14 +2005,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { for (uint32_t il = 0; il < hparams.n_layer; ++il) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } + hparams.n_layer_dense_lead = hparams.n_layer; switch (hparams.n_ff()) { case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; case 8192: type = LLM_TYPE_1_2B; break; case 10752: type = LLM_TYPE_2_6B; break; - default: type = LLM_TYPE_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_LFM2MOE: + { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + + type = LLM_TYPE_8B_A1B; + } break; case LLM_ARCH_SMALLTHINKER: { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -2083,7 +2110,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -3668,6 +3695,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -5812,6 +5844,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); @@ -5823,11 +5856,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - // ffn is same for transformer and conv layers + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -6308,7 +6353,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } - if (arch == LLM_ARCH_SMALLTHINKER) { + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } @@ -18600,6 +18645,8 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const bool is_moe_layer = il >= static_cast(hparams.n_layer_dense_lead); + auto * prev_cur = cur; cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); @@ -18614,7 +18661,16 @@ struct llm_build_lfm2 : public llm_graph_context { } cur = ggml_add(ctx0, prev_cur, cur); - cur = ggml_add(ctx0, cur, build_feed_forward(cur, il)); + + auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(ffn_norm_out, "model.layers.{}.ffn_norm", il); + + ggml_tensor * ffn_out = is_moe_layer ? + build_moe_feed_forward(ffn_norm_out, il) : + build_dense_feed_forward(ffn_norm_out, il); + cb(ffn_norm_out, "model.layers.{}.ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); } cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); @@ -18629,23 +18685,32 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - ggml_tensor * build_feed_forward(ggml_tensor * cur, - int il) const { - cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "model.layers.{}.ffn_norm", il); + ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, + int il) const { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + static_cast(hparams.expert_gating_func), + il); + } + ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, + int il) const { GGML_ASSERT(!model.layers[il].ffn_up_b); GGML_ASSERT(!model.layers[il].ffn_gate_b); GGML_ASSERT(!model.layers[il].ffn_down_b); - cur = build_ffn(cur, + return build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "model.layers.{}.feed_forward.w2", il); - - return cur; } ggml_tensor * build_attn_block(ggml_tensor * cur, @@ -19815,6 +19880,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { llm = std::make_unique(*this, params); } break; @@ -19841,6 +19907,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // if the gguf model was converted with --sentence-transformers-dense-modules + // there will be two additional dense projection layers + // dense linear projections are applied after pooling + // TODO: move reranking logic here and generalize + llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + return llm->res->get_gf(); } @@ -19865,6 +19937,7 @@ llama_model_params llama_model_default_params() { /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, + /*.no_host =*/ false, }; return result; @@ -20036,6 +20109,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: case LLM_ARCH_SEED_OSS: diff --git a/src/llama-model.h b/src/llama-model.h index eec564e70b..7f48662f28 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -107,6 +107,7 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_106B_A12B, // GLM-4.5-Air @@ -437,6 +438,12 @@ struct llama_model { std::vector layers; + //Dense linear projections for SentenceTransformers models like embeddinggemma + // For Sentence Transformers models structure see + // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; + llama_model_params params; // gguf metadata diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 2186f827bf..55d2e355fd 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -2541,8 +2541,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ if (n_non_eog == 0) { cur_p->size = 1; cur_p->data[0].id = ctx->vocab->token_eot(); + if (cur_p->data[0].id == LLAMA_TOKEN_NULL) { + cur_p->data[0].id = ctx->vocab->token_eos(); + } cur_p->data[0].logit = 1.0f; + GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL); + return; } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index da938af03b..7fffd17149 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -347,6 +347,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: case LLAMA_VOCAB_PRE_TYPE_TRILLION: + case LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -1961,6 +1962,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "trillion") { pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-docling") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING; + clean_spaces = false; } else if ( tokenizer_pre == "bailingmoe" || tokenizer_pre == "llada-moe") { @@ -2166,6 +2171,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end|>" || t.first == "" || t.first == "<|endoftext|>" + || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" || t.first == "<|end▁of▁sentence|>" // DeepSeek diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 0d2f28c36c..5e468675e4 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -8,46 +8,47 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, }; struct LLM_KV; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c1e45972e5..2fa16b497a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -131,6 +131,50 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } +// generate an F16 mask where certain blocks are randomly masked with -INF value +static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { + GGML_ASSERT(tensor->type == GGML_TYPE_F16); + + GGML_TENSOR_LOCALS( int32_t, ne, tensor, ne); + + std::vector data_f32(ne0*ne1*ne2*ne3); + std::vector data_f16(ne0*ne1*ne2*ne3); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(min, max); + + for (size_t i = 0; i < data_f32.size(); i++) { + data_f32[i] = dis(gen); + } + + // block size + const int blck0 = 128; + const int blck1 = 64; + + // number of INF blocks + const int n_inf_blocks = 0.1*(ne0*ne1*ne2*ne3)/(blck0*blck1); + + for (int b = 0; b < n_inf_blocks; b++) { + const int p3 = (rd() % ne3); + const int p2 = (rd() % ne2); + const int p1 = (rd() % ne1); + const int p0 = (rd() % ne0); + + for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) { + const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0; + + for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) { + data_f32[idx + i0] = -INFINITY; + } + } + } + + ggml_fp32_to_fp16_row(data_f32.data(), data_f16.data(), ne0*ne1*ne2*ne3); + + ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t)); +} + static std::vector tensor_to_float(const ggml_tensor * t) { std::vector tv; tv.reserve(ggml_nelements(t)); @@ -5111,6 +5155,8 @@ struct test_flash_attn_ext : public test_case { if (strcmp(t->name, "s") == 0) { // make the sink values more noticable in order to trigger a test failure when the implementation is wrong init_tensor_uniform(t, -10.0f, 10.0f); + } else if (strcmp(t->name, "m") == 0) { + init_tensor_kq_mask(t); } else { init_tensor_uniform(t); } @@ -6727,7 +6773,8 @@ static std::vector> make_test_cases_eval() { if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes for (int nr2 : { 1, 4, 16 }) { if (nr2 == 16 && hsk != 128) continue; - for (int kv : { 512, 1024, }) { + //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) { + for (int kv : { 113, 512, 1024, }) { if (nr2 != 1 && kv != 512) continue; for (int nb : { 1, 3, 32, 35, }) { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp index 547ebb4871..0b275befb8 100644 --- a/tests/test-chat-parser.cpp +++ b/tests/test-chat-parser.cpp @@ -106,6 +106,34 @@ static void test_reasoning() { assert_equals("Cogito", builder.result().content); assert_equals("Ergo sum", builder.consume_rest()); } + { + const std::string variant("content_only_inline_think"); + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ false, + }; + const std::string input = "PenseBonjour"; + auto msg = common_chat_parse(input, false, syntax); + assert_equals(variant, std::string("Pense"), msg.reasoning_content); + assert_equals(variant, std::string("Bonjour"), msg.content); + } + { + const std::string variant("llama_3_inline_think"); + common_chat_syntax syntax = { + /* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ false, + }; + const std::string input = "PlanRéponse"; + auto msg = common_chat_parse(input, false, syntax); + assert_equals(variant, std::string("Plan"), msg.reasoning_content); + assert_equals(variant, std::string("Réponse"), msg.content); + } // Test DeepSeek V3.1 parsing - reasoning content followed by "" and then regular content { common_chat_syntax syntax = { diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index b863367db6..a5382ae3a3 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -214,7 +214,7 @@ int main(void) { { /* .name= */ "ibm-granite/granite-3.0-8b-instruct", /* .template_str= */ "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}", - /* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n", + /* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>", /* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>", }, { diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 275ba367c0..0de07b9811 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -168,7 +168,7 @@ static std::vector parse_devices_arg(const std::string & val return devices; } -static std::vector register_rpc_device_list(const std::string & servers) { +static void register_rpc_server_list(const std::string & servers) { auto rpc_servers = string_split(servers, ','); if (rpc_servers.empty()) { throw std::invalid_argument("no RPC servers specified"); @@ -179,36 +179,15 @@ static std::vector register_rpc_device_list(const std::strin throw std::invalid_argument("failed to find RPC backend"); } - using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint); - auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); - if (!ggml_backend_rpc_add_device_fn) { - throw std::invalid_argument("failed to find RPC device add function"); + using add_rpc_server_fn = ggml_backend_reg_t (*)(const char * endpoint); + auto * ggml_backend_rpc_add_server_fn = (add_rpc_server_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server"); + if (!ggml_backend_rpc_add_server_fn) { + throw std::invalid_argument("failed to find RPC add server function"); } - - static std::unordered_set registered; - std::vector devices; for (const auto & server : rpc_servers) { - ggml_backend_dev_t dev = nullptr; - - std::string name = string_format("RPC[%s]", server.c_str()); - - if (registered.find(server) != registered.end()) { - dev = ggml_backend_dev_by_name(name.c_str()); - } - - if (!dev) { - dev = ggml_backend_rpc_add_device_fn(server.c_str()); - if (!dev) { - throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str())); - } - ggml_backend_device_register(dev); - registered.insert(server); - } - - devices.push_back(dev); + auto reg = ggml_backend_rpc_add_server_fn(server.c_str()); + ggml_backend_register(reg); } - - return devices; } static std::string devices_to_string(const std::vector & devices) { @@ -357,6 +336,7 @@ struct cmd_params { std::vector use_mmap; std::vector embeddings; std::vector no_op_offload; + std::vector no_host; ggml_numa_strategy numa; int reps; ggml_sched_priority prio; @@ -394,6 +374,7 @@ static const cmd_params cmd_params_defaults = { /* use_mmap */ { true }, /* embeddings */ { false }, /* no_op_offload */ { false }, + /* no_host */ { false }, /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, /* prio */ GGML_SCHED_PRIO_NORMAL, @@ -474,6 +455,8 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -ot --override-tensor =;...\n"); printf(" (default: disabled)\n"); printf(" -nopo, --no-op-offload <0|1> (default: 0)\n"); + printf(" --no-host <0|1> (default: %s)\n", + join(cmd_params_defaults.no_host, ",").c_str()); printf("\n"); printf( "Multiple values can be given for each parameter by separating them with ','\n" @@ -714,7 +697,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } try { - register_rpc_device_list(argv[i]); + register_rpc_server_list(argv[i]); } catch (const std::exception & e) { fprintf(stderr, "error: %s\n", e.what()); invalid_param = true; @@ -803,6 +786,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.no_op_offload.insert(params.no_op_offload.end(), p.begin(), p.end()); + } else if (arg == "--no-host") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.no_host.insert(params.no_host.end(), p.begin(), p.end()); } else if (arg == "-ts" || arg == "--tensor-split") { if (++i >= argc) { invalid_param = true; @@ -1024,6 +1014,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.no_op_offload.empty()) { params.no_op_offload = cmd_params_defaults.no_op_offload; } + if (params.no_host.empty()) { + params.no_host = cmd_params_defaults.no_host; + } if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } @@ -1065,6 +1058,7 @@ struct cmd_params_instance { bool use_mmap; bool embeddings; bool no_op_offload; + bool no_host; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -1077,6 +1071,7 @@ struct cmd_params_instance { mparams.main_gpu = main_gpu; mparams.tensor_split = tensor_split.data(); mparams.use_mmap = use_mmap; + mparams.no_host = no_host; if (n_cpu_moe <= 0) { if (tensor_buft_overrides.empty()) { @@ -1122,6 +1117,7 @@ struct cmd_params_instance { split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap && tensor_split == other.tensor_split && devices == other.devices && + no_host == other.no_host && vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides); } @@ -1157,6 +1153,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & ts : params.tensor_split) for (const auto & ot : params.tensor_buft_overrides) for (const auto & mmp : params.use_mmap) + for (const auto & noh : params.no_host) for (const auto & embd : params.embeddings) for (const auto & nopo : params.no_op_offload) for (const auto & nb : params.n_batch) @@ -1199,6 +1196,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .no_op_offload= */ nopo, + /* .no_host = */ noh, }; instances.push_back(instance); } @@ -1232,6 +1230,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .no_op_offload= */ nopo, + /* .no_host = */ noh, }; instances.push_back(instance); } @@ -1265,6 +1264,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .no_op_offload= */ nopo, + /* .no_host = */ noh, }; instances.push_back(instance); } @@ -1303,6 +1303,7 @@ struct test { bool use_mmap; bool embeddings; bool no_op_offload; + bool no_host; int n_prompt; int n_gen; int n_depth; @@ -1339,6 +1340,7 @@ struct test { use_mmap = inst.use_mmap; embeddings = inst.embeddings; no_op_offload = inst.no_op_offload; + no_host = inst.no_host; n_prompt = inst.n_prompt; n_gen = inst.n_gen; n_depth = inst.n_depth; @@ -1368,13 +1370,23 @@ struct test { static std::string get_backend() { std::vector backends; + bool rpc_used = false; for (size_t i = 0; i < ggml_backend_reg_count(); i++) { auto * reg = ggml_backend_reg_get(i); std::string name = ggml_backend_reg_name(reg); - if (name != "CPU") { - backends.push_back(ggml_backend_reg_name(reg)); + if (string_starts_with(name, "RPC")) { + if (ggml_backend_reg_dev_count(reg) > 0) { + rpc_used = true; + } + } else { + if (name != "CPU") { + backends.push_back(ggml_backend_reg_name(reg)); + } } } + if (rpc_used) { + backends.push_back("RPC"); + } return backends.empty() ? "CPU" : join(backends, ","); } @@ -1386,8 +1398,8 @@ struct test { "type_k", "type_v", "n_gpu_layers", "n_cpu_moe", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "devices", "tensor_split", "tensor_buft_overrides", "use_mmap", "embeddings", "no_op_offload", - "n_prompt", "n_gen", "n_depth", "test_time", "avg_ns", - "stddev_ns", "avg_ts", "stddev_ts" + "no_host", "n_prompt", "n_gen", "n_depth", "test_time", + "avg_ns", "stddev_ns", "avg_ts", "stddev_ts" }; return fields; } @@ -1402,7 +1414,7 @@ struct test { return INT; } if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" || - field == "use_mmap" || field == "embeddings") { + field == "use_mmap" || field == "embeddings" || field == "no_host") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1477,6 +1489,7 @@ struct test { std::to_string(use_mmap), std::to_string(embeddings), std::to_string(no_op_offload), + std::to_string(no_host), std::to_string(n_prompt), std::to_string(n_gen), std::to_string(n_depth), @@ -1665,6 +1678,9 @@ struct markdown_printer : public printer { if (field == "no_op_offload") { return 4; } + if (field == "no_host") { + return 4; + } int width = std::max((int) field.length(), 10); @@ -1699,6 +1715,9 @@ struct markdown_printer : public printer { if (field == "no_op_offload") { return "nopo"; } + if (field == "no_host") { + return "noh"; + } if (field == "devices") { return "dev"; } @@ -1779,6 +1798,9 @@ struct markdown_printer : public printer { if (params.no_op_offload.size() > 1 || params.no_op_offload != cmd_params_defaults.no_op_offload) { fields.emplace_back("no_op_offload"); } + if (params.no_host.size() > 1 || params.no_host != cmd_params_defaults.no_host) { + fields.emplace_back("no_host"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 664b0c9ac6..7a7523851c 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -31,6 +31,7 @@ // vision-specific #define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 210ecc883f..98e68af27a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -170,7 +170,9 @@ struct clip_hparams { int32_t projection_dim; int32_t n_head; int32_t n_layer; - int32_t proj_scale_factor = 0; // idefics3 + // idefics3 + int32_t preproc_image_size = 0; + int32_t proj_scale_factor = 0; float image_mean[3]; float image_std[3]; @@ -2250,6 +2252,7 @@ struct clip_model_loader { if (is_vision) { get_u32(KEY_IMAGE_SIZE, hparams.image_size); + get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.preproc_image_size, false); get_u32(KEY_PATCH_SIZE, hparams.patch_size); get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy @@ -3551,10 +3554,51 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); return true; - } - else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE + } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) { + // The refined size has two steps: + // 1. Resize w/ aspect-ratio preserving such that the longer side is + // the preprocessor longest size + // 2. Resize w/out preserving aspect ratio such that both sides are + // multiples of image_size (always rounding up) + // + // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737 + const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio( + original_size, params.image_size, params.preproc_image_size); + + llava_uhd::slice_instructions instructions; + instructions.overview_size = clip_image_size{params.image_size, params.image_size}; + instructions.refined_size = refined_size; + instructions.grid_size = clip_image_size{ + static_cast(std::ceil(static_cast(refined_size.width) / params.image_size)), + static_cast(std::ceil(static_cast(refined_size.height) / params.image_size)), + }; + for (int y = 0; y < refined_size.height; y += params.image_size) { + for (int x = 0; x < refined_size.width; x += params.image_size) { + instructions.slices.push_back(llava_uhd::slice_coordinates{ + /* x */x, + /* y */y, + /* size */clip_image_size{ + std::min(params.image_size, refined_size.width - x), + std::min(params.image_size, refined_size.height - y) + } + }); + } + } + auto imgs = llava_uhd::slice_image(img, instructions); + + // cast and normalize to f32 + for (size_t i = 0; i < imgs.size(); ++i) { + // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + + res_imgs->grid_x = instructions.grid_size.width; + res_imgs->grid_y = instructions.grid_size.height; + return true; + } else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3 - || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution ) { clip_image_u8 resized_image; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index cd022c5e24..4d487581ae 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -76,7 +76,7 @@ enum mtmd_slice_tmpl { MTMD_SLICE_TMPL_MINICPMV_2_5, MTMD_SLICE_TMPL_MINICPMV_2_6, MTMD_SLICE_TMPL_LLAMA4, - // TODO @ngxson : add support for idefics (SmolVLM) + MTMD_SLICE_TMPL_IDEFICS3, }; const char * mtmd_default_marker() { @@ -114,19 +114,22 @@ struct mtmd_context { // for llava-uhd style models, we need special tokens in-between slices // minicpmv calls them "slices", llama 4 calls them "tiles" mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE; - llama_token tok_ov_img_start = LLAMA_TOKEN_NULL; // overview image - llama_token tok_ov_img_end = LLAMA_TOKEN_NULL; // overview image - llama_token tok_slices_start = LLAMA_TOKEN_NULL; // start of all slices - llama_token tok_slices_end = LLAMA_TOKEN_NULL; // end of all slices - llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start - llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice end - llama_token tok_sli_img_mid = LLAMA_TOKEN_NULL; // between 2 slices - llama_token tok_row_end = LLAMA_TOKEN_NULL; // end of row + std::vector tok_ov_img_start; // overview image + std::vector tok_ov_img_end; // overview image + std::vector tok_slices_start; // start of all slices + std::vector tok_slices_end; // end of all slices + std::vector tok_sli_img_start; // single slice start + std::vector tok_sli_img_end; // single slice end + std::vector tok_sli_img_mid; // between 2 slices + std::vector tok_row_end; // end of row bool tok_row_end_trail = false; bool ov_img_first = false; bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE + // string template for slice image delimiters with row/col (idefics3) + std::string sli_img_start_tmpl; + // for whisper, we pre-calculate the mel filter bank whisper_preprocessor::whisper_filters w_filters; @@ -197,13 +200,13 @@ struct mtmd_context { // minicpmv 2.5 format: // (overview) (slice) (slice) \n ... slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_5; - tok_ov_img_start = lookup_token(""); - tok_ov_img_end = lookup_token(""); - tok_slices_start = lookup_token(""); - tok_slices_end = lookup_token(""); + tok_ov_img_start = {lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_slices_start = {lookup_token("")}; + tok_slices_end = {lookup_token("")}; tok_sli_img_start = tok_ov_img_start; tok_sli_img_end = tok_ov_img_end; - tok_row_end = lookup_token("\n"); + tok_row_end = {lookup_token("\n")}; tok_row_end_trail = false; // no trailing end-of-row token ov_img_first = true; @@ -211,11 +214,11 @@ struct mtmd_context { // minicpmv 2.6 format: // (overview) (slice) (slice) \n ... slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6; - tok_ov_img_start = lookup_token(""); - tok_ov_img_end = lookup_token(""); - tok_sli_img_start = lookup_token(""); - tok_sli_img_end = lookup_token(""); - tok_row_end = lookup_token("\n"); + tok_ov_img_start = {lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_sli_img_start = {lookup_token("")}; + tok_sli_img_end = {lookup_token("")}; + tok_row_end = {lookup_token("\n")}; tok_row_end_trail = false; // no trailing end-of-row token ov_img_first = true; @@ -230,9 +233,9 @@ struct mtmd_context { // <|image|> (overview) <-- overview image is last // <|image_end|> slice_tmpl = MTMD_SLICE_TMPL_LLAMA4; - tok_ov_img_start = lookup_token("<|image|>"); - tok_sli_img_mid = lookup_token("<|tile_x_separator|>"); - tok_row_end = lookup_token("<|tile_y_separator|>"); + tok_ov_img_start = {lookup_token("<|image|>")}; + tok_sli_img_mid = {lookup_token("<|tile_x_separator|>")}; + tok_row_end = {lookup_token("<|tile_y_separator|>")}; tok_row_end_trail = true; // add trailing end-of-row token ov_img_first = false; // overview image is last } @@ -245,8 +248,11 @@ struct mtmd_context { } else if (proj == PROJECTOR_TYPE_IDEFICS3) { // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215 - img_beg = ""; - img_end = ""; + slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3; + tok_ov_img_start = {lookup_token("\n\n"), lookup_token(""), lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_row_end = {lookup_token("\n")}; + sli_img_start_tmpl = ""; } else if (proj == PROJECTOR_TYPE_PIXTRAL) { // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md @@ -504,6 +510,7 @@ struct mtmd_tokenizer { ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 + || ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3 ) { const int n_col = batch_f32.grid_x; const int n_row = batch_f32.grid_y; @@ -517,53 +524,45 @@ struct mtmd_tokenizer { // add overview image (first) if (ctx->ov_img_first) { - if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_start}); - } + add_text(ctx->tok_ov_img_start); cur.entries.emplace_back(std::move(ov_chunk)); - if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_end}); - } + add_text(ctx->tok_ov_img_end); } // add slices (or tiles) if (!chunks.empty()) { GGML_ASSERT((int)chunks.size() == n_row * n_col); - if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_slices_start}); - } + add_text(ctx->tok_slices_start); for (int y = 0; y < n_row; y++) { for (int x = 0; x < n_col; x++) { const bool is_last_in_row = (x == n_col - 1); - if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_start}); + if (!ctx->tok_sli_img_start.empty()) { + add_text(ctx->tok_sli_img_start); + } else if (!ctx->sli_img_start_tmpl.empty()) { + // If using a template to preceed a slice image + const size_t sz = std::snprintf(nullptr, 0, ctx->sli_img_start_tmpl.c_str(), y+1, x+1) + 1; + std::unique_ptr buf(new char[sz]); + std::snprintf(buf.get(), sz, ctx->sli_img_start_tmpl.c_str(), y+1, x+1); + add_text(std::string(buf.get(), buf.get() + sz - 1), true); } cur.entries.emplace_back(std::move(chunks[y * n_col + x])); - if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_end}); - } - if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_mid}); + add_text(ctx->tok_sli_img_end); + if (!is_last_in_row) { + add_text(ctx->tok_sli_img_mid); } } - if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_row_end}); + if ((y != n_row - 1 || ctx->tok_row_end_trail)) { + add_text(ctx->tok_row_end); } } - if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_slices_end}); - } + add_text(ctx->tok_slices_end); } // add overview image (last) if (!ctx->ov_img_first) { - if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_start}); - } + add_text(ctx->tok_ov_img_start); cur.entries.emplace_back(std::move(ov_chunk)); - if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_end}); - } + add_text(ctx->tok_ov_img_end); } } else { @@ -780,7 +779,9 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); bool ok = false; - if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) { + if (clip_is_llava(ctx_clip) + || clip_is_minicpmv(ctx_clip) + || clip_is_glm(ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; for (size_t i = 0; i < entries.size(); i++) { diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index c64be03630..dbdf7656a6 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -69,6 +69,7 @@ add_test_vision "ggml-org/InternVL2_5-1B-GGUF:Q8_0" add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0" add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0" +add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" diff --git a/tools/rpc/README.md b/tools/rpc/README.md index 561f19fda6..afbb302f4b 100644 --- a/tools/rpc/README.md +++ b/tools/rpc/README.md @@ -4,7 +4,7 @@ > This example and the RPC backend are currently in a proof-of-concept development stage. As such, the functionality is fragile and > insecure. **Never run the RPC server on an open network or in a sensitive environment!** -The `rpc-server` allows running `ggml` backend on a remote host. +The `rpc-server` allows exposing `ggml` devices on a remote host. The RPC backend communicates with one or several instances of `rpc-server` and offloads computations to them. This can be used for distributed LLM inference with `llama.cpp` in the following way: @@ -14,28 +14,34 @@ flowchart TD rpcb<-->|TCP|srvb rpcb<-.->|TCP|srvn subgraph hostn[Host N] - srvn[rpc-server]<-.->backend3["Backend (CUDA,Metal,etc.)"] + srvn[rpc-server]<-.->dev4["CUDA0"] + srvn[rpc-server]<-.->dev5["CPU"] end subgraph hostb[Host B] - srvb[rpc-server]<-->backend2["Backend (CUDA,Metal,etc.)"] + srvb[rpc-server]<-->dev3["Metal"] end subgraph hosta[Host A] - srva[rpc-server]<-->backend["Backend (CUDA,Metal,etc.)"] + srva[rpc-server]<-->dev["CUDA0"] + srva[rpc-server]<-->dev2["CUDA1"] end subgraph host[Main Host] - local["Backend (CUDA,Metal,etc.)"]<-->ggml[llama-cli] + local["Local devices"]<-->ggml[llama-cli] ggml[llama-cli]<-->rpcb[RPC backend] end style hostn stroke:#66,stroke-width:2px,stroke-dasharray: 5 5 + classDef devcls fill:#5B9BD5 + class local,dev,dev2,dev3,dev4,dev5 devcls ``` -Each host can run a different backend, e.g. one with CUDA and another with Metal. -You can also run multiple `rpc-server` instances on the same host, each with a different backend. +By default, `rpc-server` exposes all available accelerator devices on the host. +If there are no accelerators, it exposes a single `CPU` device. ## Usage -On each host, build the corresponding backend with `cmake` and add `-DGGML_RPC=ON` to the build options. -For example, to build the CUDA backend with RPC support: +### Remote hosts + +On each remote host, build the backends for each accelerator by adding `-DGGML_RPC=ON` to the build options. +For example, to build the `rpc-server` with support for CUDA accelerators: ```bash mkdir build-rpc-cuda @@ -44,33 +50,38 @@ cmake .. -DGGML_CUDA=ON -DGGML_RPC=ON cmake --build . --config Release ``` -Then, start the `rpc-server` with the backend: +When started, the `rpc-server` will detect and expose all available `CUDA` devices: ```bash -$ bin/rpc-server -p 50052 -create_backend: using CUDA backend -ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no -ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes +$ bin/rpc-server +ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no +ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no ggml_cuda_init: found 1 CUDA devices: - Device 0: NVIDIA T1200 Laptop GPU, compute capability 7.5, VMM: yes -Starting RPC server on 0.0.0.0:50052 + Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes +Starting RPC server v3.0.0 + endpoint : 127.0.0.1:50052 + local cache : n/a +Devices: + CUDA0: NVIDIA GeForce RTX 5090 (32109 MiB, 31588 MiB free) ``` -When using the CUDA backend, you can specify the device with the `CUDA_VISIBLE_DEVICES` environment variable, e.g.: +You can control the set of exposed CUDA devices with the `CUDA_VISIBLE_DEVICES` environment variable or the `--device` command line option. The following two commands have the same effect: ```bash $ CUDA_VISIBLE_DEVICES=0 bin/rpc-server -p 50052 +$ bin/rpc-server --device CUDA0 -p 50052 ``` -This way you can run multiple `rpc-server` instances on the same host, each with a different CUDA device. +### Main host -On the main host build `llama.cpp` for the local backend and add `-DGGML_RPC=ON` to the build options. -Finally, when running `llama-cli`, use the `--rpc` option to specify the host and port of each `rpc-server`: +On the main host build `llama.cpp` with the backends for the local devices and add `-DGGML_RPC=ON` to the build options. +Finally, when running `llama-cli` or `llama-server`, use the `--rpc` option to specify the host and port of each `rpc-server`: ```bash -$ bin/llama-cli -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name is" --repeat-penalty 1.0 -n 64 --rpc 192.168.88.10:50052,192.168.88.11:50052 -ngl 99 +$ llama-cli -hf ggml-org/gemma-3-1b-it-GGUF -ngl 99 --rpc 192.168.88.10:50052,192.168.88.11:50052 ``` -This way you can offload model layers to both local and remote devices. +By default, llama.cpp distributes model weights and the KV cache across all available devices -- both local and remote -- in proportion to each device's available memory. +You can override this behavior with the `--tensor-split` option and set custom proportions when splitting tensor data across devices. ### Local cache @@ -83,3 +94,11 @@ $ bin/rpc-server -c ``` By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable. + +### Troubleshooting + +Use the `GGML_RPC_DEBUG` environment variable to enable debug messages from `rpc-server`: +```bash +$ GGML_RPC_DEBUG=1 bin/rpc-server +``` + diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp index dc8e077f34..0885156127 100644 --- a/tools/rpc/rpc-server.cpp +++ b/tools/rpc/rpc-server.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace fs = std::filesystem; @@ -131,24 +132,24 @@ static std::string fs_get_cache_directory() { } struct rpc_server_params { - std::string host = "127.0.0.1"; - int port = 50052; - size_t backend_mem = 0; - bool use_cache = false; - int n_threads = std::max(1U, std::thread::hardware_concurrency()/2); - std::string device; + std::string host = "127.0.0.1"; + int port = 50052; + bool use_cache = false; + int n_threads = std::max(1U, std::thread::hardware_concurrency()/2); + std::vector devices; + std::vector dev_mem; }; static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { fprintf(stderr, "Usage: %s [options]\n\n", argv[0]); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n", params.n_threads); - fprintf(stderr, " -d DEV, --device device to use\n"); - fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str()); - fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port); - fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n"); - fprintf(stderr, " -c, --cache enable local file cache\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads); + fprintf(stderr, " -d, --device comma-separated list of devices\n"); + fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str()); + fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port); + fprintf(stderr, " -m, --mem memory size for each device (in MB)\n"); + fprintf(stderr, " -c, --cache enable local file cache\n"); fprintf(stderr, "\n"); } @@ -174,17 +175,17 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & if (++i >= argc) { return false; } - params.device = argv[i]; - if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) { - fprintf(stderr, "error: unknown device: %s\n", params.device.c_str()); - fprintf(stderr, "available devices:\n"); - for (size_t i = 0; i < ggml_backend_dev_count(); i++) { - auto * dev = ggml_backend_dev_get(i); - size_t free, total; - ggml_backend_dev_memory(dev, &free, &total); - printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + const std::regex regex{ R"([,/]+)" }; + std::string dev_str = argv[i]; + std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1); + std::sregex_token_iterator end; + for ( ; iter != end; ++iter) { + try { + params.devices.push_back(*iter); + } catch (const std::exception & ) { + fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str()); + return false; } - return false; } } else if (arg == "-p" || arg == "--port") { if (++i >= argc) { @@ -200,7 +201,19 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & if (++i >= argc) { return false; } - params.backend_mem = std::stoul(argv[i]) * 1024 * 1024; + const std::regex regex{ R"([,/]+)" }; + std::string mem_str = argv[i]; + std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1); + std::sregex_token_iterator end; + for ( ; iter != end; ++iter) { + try { + size_t mem = std::stoul(*iter) * 1024 * 1024; + params.dev_mem.push_back(mem); + } catch (const std::exception & ) { + fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str()); + return false; + } + } } else if (arg == "-h" || arg == "--help") { print_usage(argc, argv, params); exit(0); @@ -213,45 +226,46 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & return true; } -static ggml_backend_t create_backend(const rpc_server_params & params) { - ggml_backend_t backend = nullptr; +static std::vector get_devices(const rpc_server_params & params) { + std::vector devices; + if (!params.devices.empty()) { + for (auto device : params.devices) { + ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str()); + if (dev) { + devices.push_back(dev); + } else { + fprintf(stderr, "error: unknown device: %s\n", device.c_str()); + fprintf(stderr, "available devices:\n"); + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + auto * dev = ggml_backend_dev_get(i); + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + } + return {}; + } + } + } - if (!params.device.empty()) { - ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str()); + // Try non-CPU devices first + if (devices.empty()) { + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) { + devices.push_back(dev); + } + } + } + + // If there are no accelerators, fallback to CPU device + if (devices.empty()) { + ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); if (dev) { - backend = ggml_backend_dev_init(dev, nullptr); - if (!backend) { - fprintf(stderr, "Failed to create backend for device %s\n", params.device.c_str()); - return nullptr; - } + devices.push_back(dev); } } - if (!backend) { - backend = ggml_backend_init_best(); - } - - if (backend) { - fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend)); - - // set the number of threads - ggml_backend_dev_t dev = ggml_backend_get_device(backend); - ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; - if (reg) { - auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); - if (ggml_backend_set_n_threads_fn) { - ggml_backend_set_n_threads_fn(backend, params.n_threads); - } - } - } - - return backend; -} - -static void get_backend_memory(ggml_backend_t backend, size_t * free_mem, size_t * total_mem) { - ggml_backend_dev_t dev = ggml_backend_get_device(backend); - GGML_ASSERT(dev != nullptr); - ggml_backend_dev_memory(dev, free_mem, total_mem); + return devices; } int main(int argc, char * argv[]) { @@ -273,18 +287,23 @@ int main(int argc, char * argv[]) { fprintf(stderr, "\n"); } - ggml_backend_t backend = create_backend(params); - if (!backend) { - fprintf(stderr, "Failed to create backend\n"); + auto devices = get_devices(params); + if (devices.empty()) { + fprintf(stderr, "No devices found\n"); return 1; } std::string endpoint = params.host + ":" + std::to_string(params.port); - size_t free_mem, total_mem; - if (params.backend_mem > 0) { - free_mem = params.backend_mem; - total_mem = params.backend_mem; - } else { - get_backend_memory(backend, &free_mem, &total_mem); + std::vector free_mem, total_mem; + for (size_t i = 0; i < devices.size(); i++) { + if (i < params.dev_mem.size()) { + free_mem.push_back(params.dev_mem[i]); + total_mem.push_back(params.dev_mem[i]); + } else { + size_t free, total; + ggml_backend_dev_memory(devices[i], &free, &total); + free_mem.push_back(free); + total_mem.push_back(total); + } } const char * cache_dir = nullptr; std::string cache_dir_str; @@ -309,8 +328,7 @@ int main(int argc, char * argv[]) { return 1; } - start_server_fn(backend, endpoint.c_str(), cache_dir, free_mem, total_mem); - - ggml_backend_free(backend); + start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), + devices.data(), free_mem.data(), total_mem.data()); return 0; } diff --git a/tools/server/README.md b/tools/server/README.md index 9f7ab229f7..f5ab9236d5 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -190,7 +190,7 @@ The project is under active development, and we are [looking for feedback and co | `--no-slots` | disables slots monitoring endpoint
(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) | | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) | | `--jinja` | use jinja template for chat (default: disabled)
(env: LLAMA_ARG_JINJA) | -| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)
(default: auto)
(env: LLAMA_ARG_THINK) | +| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `` tags in `message.content` while also populating `message.reasoning_content`
(default: deepseek)
(env: LLAMA_ARG_THINK) | | `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | | `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | | `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | @@ -393,7 +393,7 @@ node index.js ### GET `/health`: Returns health check result -This endpoint is public (no API key check). +This endpoint is public (no API key check). `/v1/health` also works. **Response format** @@ -1045,6 +1045,7 @@ Available metrics: - `llamacpp:kv_cache_tokens`: KV-cache tokens. - `llamacpp:requests_processing`: Number of requests processing. - `llamacpp:requests_deferred`: Number of requests deferred. +- `llamacpp:n_past_max`: High watermark of the context size observed. ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 2801319c98..5026edcebe 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a21147613d..60326e8e50 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -9,7 +9,6 @@ #include "sampling.h" #include "speculative.h" #include "mtmd.h" -#include "mtmd-helper.h" // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" @@ -158,7 +157,6 @@ struct slot_params { if (only_metrics) { return json { - {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, {"temperature", sampling.temp}, {"dynatemp_range", sampling.dynatemp_range}, @@ -181,7 +179,8 @@ struct slot_params { {"mirostat", sampling.mirostat}, {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, - {"max_tokens", n_predict}, // User configured n_predict + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? {"n_keep", n_keep}, {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, @@ -209,7 +208,6 @@ struct slot_params { } return json { - {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, {"temperature", sampling.temp}, {"dynatemp_range", sampling.dynatemp_range}, @@ -234,7 +232,8 @@ struct slot_params { {"mirostat_tau", sampling.mirostat_tau}, {"mirostat_eta", sampling.mirostat_eta}, {"stop", antiprompt}, - {"max_tokens", n_predict}, // User configured n_predict + {"max_tokens", n_predict}, + {"n_predict", n_predict}, // TODO: deduplicate? {"n_keep", n_keep}, {"n_discard", n_discard}, {"ignore_eos", sampling.ignore_eos}, @@ -265,15 +264,15 @@ struct server_task { int id = -1; // to be filled by server_queue int index = -1; // used when there are multiple prompts (batch request) - server_task_type type; - // used by SERVER_TASK_TYPE_CANCEL int id_target = -1; + int id_slot = -1; // used by SERVER_TASK_TYPE_INFERENCE slot_params params; - server_tokens prompt_tokens; - int id_selected_slot = -1; + server_tokens tokens; + + server_task_type type; // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE struct slot_action { @@ -289,6 +288,8 @@ struct server_task { // used by SERVER_TASK_TYPE_SET_LORA std::vector set_lora; + server_task() = default; + server_task(server_task_type type) : type(type) {} static slot_params params_from_json_cmpl( @@ -305,6 +306,7 @@ struct server_task { defaults.sampling = params_base.sampling; defaults.speculative = params_base.speculative; defaults.n_keep = params_base.n_keep; + defaults.n_predict = params_base.n_predict; defaults.antiprompt = params_base.antiprompt; // enabling this will output extra debug information in the HTTP responses from the server @@ -323,32 +325,32 @@ struct server_task { params.n_discard = json_value(data, "n_discard", defaults.n_discard); //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); + params.response_fields = json_value(data, "response_fields", std::vector()); - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); @@ -690,7 +692,7 @@ struct server_task_result { // using shared_ptr for polymorphism of server_task_result using server_task_result_ptr = std::unique_ptr; -inline std::string stop_type_to_str(stop_type type) { +static inline std::string stop_type_to_str(stop_type type) { switch (type) { case STOP_TYPE_EOS: return "eos"; case STOP_TYPE_WORD: return "word"; @@ -764,13 +766,6 @@ struct completion_token_output { } }; -struct ctx_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - std::vector data; -}; - struct server_task_result_cmpl_final : server_task_result { int index = 0; @@ -797,11 +792,12 @@ struct server_task_result_cmpl_final : server_task_result { slot_params generation_params; // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; + std::vector oaicompat_msg_diffs; virtual int get_index() override { @@ -1373,17 +1369,17 @@ struct server_task_result_slot_save_load : server_task_result { { "save_ms", t_ms } }}, }; - } else { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", n_tokens }, - { "n_read", n_bytes }, - { "timings", { - { "restore_ms", t_ms } - }}, - }; } + + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; } }; @@ -1404,15 +1400,218 @@ struct server_task_result_apply_lora : server_task_result { } }; +struct server_prompt_checkpoint { + llama_pos pos_min; + llama_pos pos_max; + + std::vector data; + + size_t size() const { + return data.size(); + } +}; + +struct server_prompt { + server_tokens tokens; + + std::vector data; + + std::list checkpoints; + + size_t size() const { + size_t res = data.size(); + + for (const auto & checkpoint : checkpoints) { + res += checkpoint.size(); + } + + return res; + } + + int n_tokens() const { + return tokens.size(); + } +}; + +struct server_prompt_cache { + server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { + this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 0 : limit_size_mib); + this->limit_tokens = limit_tokens; + } + + std::list states; + + // in bytes, 0 = no limit + size_t limit_size = 0; + + // in tokens, 0 = no limit + size_t limit_tokens = 0; + + size_t size() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.size(); + } + + return res; + } + + size_t n_tokens() const { + size_t res = 0; + + for (const auto & state : states) { + res += state.n_tokens(); + } + + return res; + } + + server_prompt * alloc(const server_prompt & prompt, size_t state_size) { + // first check if the current state is contained fully in the cache + for (auto it = states.begin(); it != states.end(); ++it) { + const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); + + if (cur_lcp_len == (int) prompt.tokens.size()) { + SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); + return nullptr; + } + } + + // next, remove any cached prompts that are fully contained in the current prompt + for (auto it = states.begin(); it != states.end();) { + const int len = it->tokens.get_common_prefix(prompt.tokens); + + if (len == (int) it->tokens.size()) { + SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); + + it = states.erase(it); + } else { + ++it; + } + } + + std::vector state_data; + + // check if we can allocate enough memory for the new state + try { + state_data.resize(state_size); + } catch (const std::bad_alloc & e) { + SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); + + limit_size = std::max(1, 0.4*size()); + + SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); + + update(); + + return nullptr; + } + + // TODO: for some reason we can't copy server_tokens, so we have to do this workaround + auto & cur = states.emplace_back(); + cur = { + /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.data =*/ std::move(state_data), + /*.checkpoints =*/ prompt.checkpoints, + }; + + return &cur; + } + + bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { + const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); + + float f_keep_best = float(lcp_best) / prompt.tokens.size(); + float sim_best = float(lcp_best) / tokens_new.size(); + + SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + auto it_best = states.end(); + + // find the most similar cached prompt, that would also preserve the most context + for (auto it = states.begin(); it != states.end(); ++it) { + const int lcp_cur = it->tokens.get_common_prefix(tokens_new); + + const float f_keep_cur = float(lcp_cur) / it->tokens.size(); + const float sim_cur = float(lcp_cur) / tokens_new.size(); + + // don't trash large prompts + if (f_keep_cur < 0.25f) { + continue; + } + + if (f_keep_best < f_keep_cur && sim_best < sim_cur) { + f_keep_best = f_keep_cur; + sim_best = sim_cur; + + it_best = it; + } + } + + if (it_best != states.end()) { + SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); + + const size_t size = it_best->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); + if (n != size) { + SRV_WRN("failed to restore state with size %zu\n", size); + + return false; + } + + it_best->data.clear(); + it_best->data.shrink_to_fit(); + + prompt = std::move(*it_best); + + states.erase(it_best); + } + + return true; + } + + void update() { + if (limit_size > 0) { + // always keep at least one state, regardless of the limits + while (states.size() > 1 && size() > limit_size) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + if (limit_tokens > 0) { + while (states.size() > 1 && n_tokens() > limit_tokens) { + if (states.empty()) { + break; + } + + SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); + + states.pop_front(); + } + } + + SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens)\n", + states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens); + + for (const auto & state : states) { + SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); + } + } +}; + struct server_slot { int id; - int id_task = -1; - - // only used for completion/embedding/infill/rerank - server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; llama_batch batch_spec = {}; + // TODO: change to unique_ptrs for consistency: llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1421,15 +1620,8 @@ struct server_slot { common_speculative * spec = nullptr; - std::vector lora; - int32_t alora_invocation_start = -1; - - // the index relative to completion multi-task request - size_t index = 0; - - struct slot_params params; - - slot_state state = SLOT_STATE_IDLE; + std::unique_ptr task; + std::unique_ptr task_prev; // used for debugging // used to determine the slot that has been used the longest int64_t t_last_used = -1; @@ -1437,38 +1629,66 @@ struct server_slot { // generation props int32_t n_ctx = 0; // context size per slot int32_t n_past = 0; + int32_t n_keep = 0; int32_t n_decoded = 0; int32_t n_remaining = -1; int32_t i_batch = -1; - int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated - int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_cache = 0; int32_t n_prompt_tokens_processed = 0; - // input prompt tokens - server_tokens prompt_tokens; + int32_t n_prompt_tokens() const { + return task->tokens.size(); + } size_t last_nl_pos = 0; std::string generated_text; llama_tokens generated_tokens; + common_chat_msg chat_msg; - server_tokens cache_tokens; - std::vector generated_token_probs; - std::vector ctx_checkpoints; - bool has_next_token = true; bool has_new_line = false; bool truncated = false; + stop_type stop; std::string stopping_word; + // state + slot_state state = SLOT_STATE_IDLE; + + server_prompt prompt; + + void prompt_save(server_prompt_cache & prompt_cache) const { + assert(prompt.data.size() == 0); + + const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size); + if (cur == nullptr) { + return; + } + + llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + } + + void prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { + bool res = prompt_cache.load(prompt, tokens, ctx, id); + if (!res) { + SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); + } + } + + std::vector lora; + int32_t alora_invocation_start = -1; + // sampling json json_schema; @@ -1480,7 +1700,7 @@ struct server_slot { std::vector generated_tool_call_ids; // stats - size_t n_sent_text = 0; // number of sent text character + size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; int64_t t_start_generation; @@ -1497,19 +1717,17 @@ struct server_slot { void reset() { SLT_DBG(*this, "%s", "\n"); - n_prompt_tokens = 0; n_prompt_tokens_cache = 0; - last_nl_pos = 0; - generated_text = ""; - has_new_line = false; - truncated = false; - stop = STOP_TYPE_NONE; - stopping_word = ""; - n_past = 0; - n_sent_text = 0; - task_type = SERVER_TASK_TYPE_COMPLETION; - chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_past = 0; + n_sent_text = 0; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; generated_tokens.clear(); generated_token_probs.clear(); @@ -1521,16 +1739,23 @@ struct server_slot { n_draft_total = 0; n_draft_accepted = 0; + task.reset(); + task_prev.reset(); + // clear alora start alora_invocation_start = -1; } bool need_embd() const { - return server_task_type_need_embd(task_type); + GGML_ASSERT(task); + + return server_task_type_need_embd(task->type); } bool need_logits() const { - return server_task_type_need_logits(task_type); + GGML_ASSERT(task); + + return server_task_type_need_logits(task->type); } // if the context does not have a memory module then all embeddings have to be computed within a single ubatch @@ -1542,18 +1767,22 @@ struct server_slot { } bool can_batch_with(server_slot & other_slot) const { - return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora); + GGML_ASSERT(task); + + return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); } bool has_budget(const common_params & global_params) { - if (params.n_predict == -1 && global_params.n_predict == -1) { + GGML_ASSERT(task); + + if (task->params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } n_remaining = -1; - if (params.n_predict != -1) { - n_remaining = params.n_predict - n_decoded; + if (task->params.n_predict != -1) { + n_remaining = task->params.n_predict - n_decoded; } else if (global_params.n_predict != -1) { n_remaining = global_params.n_predict - n_decoded; } @@ -1566,7 +1795,7 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + return ctx_dft; } void add_token(const completion_token_output & token) { @@ -1579,11 +1808,17 @@ struct server_slot { void release() { if (is_processing()) { + GGML_ASSERT(task); + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); t_last_used = ggml_time_us(); t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; state = SLOT_STATE_IDLE; + + task_prev = std::move(task); + task.reset(); + callback_on_release(id); } } @@ -1592,19 +1827,19 @@ struct server_slot { result_timings timings; timings.cache_n = n_prompt_tokens_cache; - timings.prompt_n = n_prompt_tokens_processed; - timings.prompt_ms = t_prompt_processing; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - timings.predicted_n = n_decoded; - timings.predicted_ms = t_token_generation; + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; timings.predicted_per_token_ms = t_token_generation / n_decoded; - timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; // Add speculative metrics if (n_draft_total > 0) { - timings.draft_n = n_draft_total; + timings.draft_n = n_draft_total; timings.draft_n_accepted = n_draft_accepted; } @@ -1612,14 +1847,16 @@ struct server_slot { } const common_chat_msg & update_chat_msg(std::vector & diffs) { + GGML_ASSERT(task); + auto previous_msg = chat_msg; SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); auto new_msg = common_chat_parse( generated_text, /* is_partial= */ stop != STOP_TYPE_EOS, - params.oaicompat_chat_syntax); + task->params.oaicompat_chat_syntax); if (!new_msg.empty()) { - new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id); + new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); chat_msg = new_msg; diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); } @@ -1627,9 +1864,11 @@ struct server_slot { } size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + GGML_ASSERT(task); + size_t stop_pos = std::string::npos; - for (const std::string & word : params.antiprompt) { + for (const std::string & word : task->params.antiprompt) { size_t pos; if (is_full_stop) { @@ -1682,43 +1921,36 @@ struct server_slot { } json to_json(bool only_metrics = false) const { - if (only_metrics) { - return json { - {"id", id}, - {"id_task", id_task}, - {"n_ctx", n_ctx}, - {"speculative", can_speculate()}, - {"is_processing", is_processing()}, - {"params", params.to_json(true)}, - {"next_token", - { - {"has_next_token", has_next_token}, - {"has_new_line", has_new_line}, - {"n_remain", n_remaining}, - {"n_decoded", n_decoded}, - } - }, - }; - } + json res; - return json { + res = { {"id", id}, - {"id_task", id_task}, {"n_ctx", n_ctx}, {"speculative", can_speculate()}, {"is_processing", is_processing()}, - {"params", params.to_json()}, - {"prompt", prompt_tokens.detokenize(ctx, true)}, - {"next_token", + }; + + const auto & ptask = task ? task : task_prev; + + if (ptask) { + res["id_task"] = ptask->id; + res["params"] = ptask->params.to_json(only_metrics); + res["next_token"] = { { {"has_next_token", has_next_token}, {"has_new_line", has_new_line}, {"n_remain", n_remaining}, {"n_decoded", n_decoded}, - {"stopping_word", stopping_word}, } - }, - }; + }; + + if (!only_metrics) { + res["prompt"] = ptask->tokens.detokenize(ctx, true); + res["generated"] = generated_text; + } + } + + return res; } }; @@ -1937,7 +2169,7 @@ private: void cleanup_pending_task(int id_target) { // no need lock because this is called exclusively by post() auto rm_func = [id_target](const server_task & task) { - return task.id_target == id_target; + return task.id == id_target; }; queue_tasks.erase( std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), @@ -2109,11 +2341,14 @@ struct server_context { // slots / clients std::vector slots; - json default_generation_settings_for_props; + + int slots_debug = 0; server_queue queue_tasks; server_response queue_results; + std::unique_ptr prompt_cache; + server_metrics metrics; // Necessary similarity of prompt for slot selection @@ -2268,9 +2503,8 @@ struct server_context { slot.id = i; slot.ctx = ctx; slot.n_ctx = n_ctx_slot; - slot.n_predict = params_base.n_predict; slot.mctx = mctx; - slot.cache_tokens.has_mtmd = mctx != nullptr; + slot.prompt.tokens.has_mtmd = mctx != nullptr; if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); @@ -2286,16 +2520,13 @@ struct server_context { SRV_ERR("%s", "failed to create speculator\n"); return; } - for (auto &pair : params_base.speculative.replacements) { + for (auto & pair : params_base.speculative.replacements) { common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); } } SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); - slot.params.sampling = params_base.sampling; - slot.params.n_keep = params_base.n_keep; - slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); }; @@ -2305,7 +2536,14 @@ struct server_context { slots.push_back(std::move(slot)); } - default_generation_settings_for_props = slots[0].to_json(); + { + const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); + slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; + + if (slots_debug) { + SRV_WRN("slots debug = %d\n", slots_debug); + } + } // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) @@ -2316,11 +2554,25 @@ struct server_context { metrics.init(); + if (params_base.cache_ram_mib != 0) { + if (params_base.cache_ram_mib < 0) { + SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); + } else { + SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); + } + SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); + + prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); + } else { + SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); + } + SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); + // thinking is enabled if: // 1. It's not explicitly disabled (reasoning_budget == 0) // 2. The chat template supports it const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - SRV_INF("Enable thinking? %d\n", enable_thinking); + SRV_INF("thinking = %d\n", enable_thinking); oai_parser_opt = { /* use_jinja */ params_base.use_jinja, @@ -2347,10 +2599,11 @@ struct server_context { server_slot * get_available_slot(const server_task & task) { server_slot * ret = nullptr; + bool update_cache = false; + // find the slot that has at least n% prompt similarity if (ret == nullptr && slot_prompt_similarity != 0.0f) { - int lcs_len = 0; - float similarity = 0; + float sim_best = 0; for (server_slot & slot : slots) { // skip the slot if it is not available @@ -2358,27 +2611,34 @@ struct server_context { continue; } + const auto & tokens = slot.prompt.tokens; + // skip the slot if it does not contains cached tokens - if (slot.cache_tokens.empty()) { + if (tokens.empty()) { continue; } - // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); - - // fraction of the common subsequence length compared to the current slot's prompt length - float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + // fraction of the Longest Common Prefix length with respect to the input prompt length + const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); // select the current slot if the criteria match - if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { - lcs_len = cur_lcs_len; - similarity = cur_similarity; + if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { + sim_best = sim_cur; + ret = &slot; } } if (ret != nullptr) { - SLT_INF(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %.3f (> %.3f thold)\n", lcs_len, similarity, slot_prompt_similarity); + const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); + + SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", + sim_best, slot_prompt_similarity, f_keep); + + // if we are about to lose a large portion of the existing context - save it in the prompt cache + if (f_keep < 0.5f) { + update_cache = true; + } } } @@ -2401,6 +2661,36 @@ struct server_context { if (ret != nullptr) { SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); + + update_cache = true; + } + } + + if (ret) { + const auto & tokens = ret->prompt.tokens; + + update_cache = update_cache && prompt_cache; + + // cache prompts only for completion tasks + update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; + + // don't update the cache if the slot's context is empty + update_cache = update_cache && tokens.size() > 0; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && (ret->mctx == nullptr); + + if (update_cache) { + SRV_WRN("%s", "updating prompt cache\n"); + + const int64_t t_start = ggml_time_us(); + + ret->prompt_save(*prompt_cache); + ret->prompt_load(*prompt_cache, task.tokens); + + prompt_cache->update(); + + SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); } } @@ -2409,27 +2699,21 @@ struct server_context { bool launch_slot_with_task(server_slot & slot, server_task && task) { slot.reset(); - slot.id_task = task.id; - slot.index = task.index; - slot.task_type = task.type; - slot.params = std::move(task.params); - slot.prompt_tokens = std::move(task.prompt_tokens); - if (!are_lora_equal(slot.params.lora, slot.lora)) { + if (!are_lora_equal(task.params.lora, slot.lora)) { // if lora has changed, check to see if the cache should be cleared - if (lora_should_clear_cache(slot.lora, slot.params.lora)) { - SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size()); - slot.cache_tokens.clear(); + if (lora_should_clear_cache(slot.lora, task.params.lora)) { + SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); + slot.prompt.tokens.clear(); } else { - SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size()); + SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); } - slot.lora = slot.params.lora; + slot.lora = task.params.lora; } // if using alora, make sure it's only a single one requested and active - size_t alora_invocation_start = slot.prompt_tokens.size(); + size_t alora_invocation_start = task.tokens.size(); if (lora_all_alora(slot.lora)) { - const auto & enabled_ids = lora_get_enabled_ids(slot.lora); // TODO: This will error out if a user requests two aloras, but only // provides the activation string for one. We could, instead search @@ -2448,10 +2732,10 @@ struct server_context { // scan backwards through the prompt tokens to find the last // occurrence of the invocation sequence int match_idx = static_cast(n_invocation_tokens) - 1; - for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) { + for (int i = task.tokens.size() - 1; i >= 0; --i) { // the token in this position matches the next token to find in // the invocation sequence - if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) { + if (task.tokens[i] == invocation_tokens[match_idx]) { // if it's a full match, we've found the start if (match_idx == 0) { alora_invocation_start = i; @@ -2466,7 +2750,7 @@ struct server_context { } // if the activation string is not found, disable the alora - if (alora_invocation_start == slot.prompt_tokens.size()) { + if (alora_invocation_start == task.tokens.size()) { SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]); slot.lora[enabled_ids[0]].scale = 0.0f; } else { @@ -2475,24 +2759,20 @@ struct server_context { } } - if (!slot.prompt_tokens.validate(ctx)) { + if (!task.tokens.validate(ctx)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); - if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { - // Might be better to reject the request with a 400 ? - SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); - slot.params.n_predict = slot.n_predict; - } - + // initialize samplers { if (slot.smpl != nullptr) { common_sampler_free(slot.smpl); } - slot.smpl = common_sampler_init(model, slot.params.sampling); + slot.smpl = common_sampler_init(model, task.params.sampling); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -2500,12 +2780,15 @@ struct server_context { } } + // initialize draft batch if (slot.ctx_dft) { llama_batch_free(slot.batch_spec); - slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); } + slot.task = std::make_unique(std::move(task)); + slot.state = SLOT_STATE_STARTED; SLT_INF(slot, "%s", "processing task\n"); @@ -2527,7 +2810,7 @@ struct server_context { slot.sampled = result.tok; slot.generated_text += token_str; - if (slot.params.return_tokens) { + if (slot.task->params.return_tokens) { slot.generated_tokens.push_back(result.tok); } slot.has_next_token = true; @@ -2564,7 +2847,7 @@ struct server_context { } slot.add_token(result); - if (slot.params.stream) { + if (slot.task->params.stream) { send_partial_response(slot, result, false); } } @@ -2586,12 +2869,12 @@ struct server_context { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); } if (slot.has_new_line) { // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent - if (slot.params.n_indent > 0) { + if (slot.task->params.n_indent > 0) { // check the current indentation // TODO: improve by not doing it more than once for each new line if (slot.last_nl_pos > 0) { @@ -2603,7 +2886,7 @@ struct server_context { pos++; } - if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; @@ -2630,11 +2913,11 @@ struct server_context { slot.has_new_line = true; // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); } } @@ -2645,7 +2928,7 @@ struct server_context { slot.has_next_token = false; SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); + slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); } if (llama_vocab_is_eog(vocab, result.tok)) { @@ -2657,7 +2940,7 @@ struct server_context { const auto n_ctx_train = llama_model_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction @@ -2665,7 +2948,7 @@ struct server_context { SLT_WRN(slot, "n_predict (%d) is set for infinite generation. " "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", - slot.params.n_predict, n_ctx_train); + slot.task->params.n_predict, n_ctx_train); } SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); @@ -2674,7 +2957,7 @@ struct server_context { } void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.params.sampling.n_probs; + size_t n_probs = slot.task->params.sampling.n_probs; size_t n_vocab = llama_vocab_n_tokens(vocab); if (post_sampling) { @@ -2728,7 +3011,7 @@ struct server_context { } void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx); + send_error(slot.task->id, error, type, slot.n_prompt_tokens(), slot.n_ctx); } void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { @@ -2749,7 +3032,7 @@ struct server_context { } // if multimodal is enabled, send an error and return false - bool ensure_no_mtmd(const int id_task) { + bool check_no_mtmd(const int id_task) { if (mctx) { send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); return false; @@ -2760,14 +3043,14 @@ struct server_context { void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; + res->id = slot.task->id; + res->index = slot.task->index; if (is_progress) { res->is_progress = true; - res->progress.total = slot.n_prompt_tokens; + res->progress.total = slot.n_prompt_tokens(); res->progress.cache = slot.n_prompt_tokens_cache; - res->progress.processed = slot.cache_tokens.size(); + res->progress.processed = slot.prompt.tokens.size(); res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt / 1000); } else { res->content = tkn.text_to_send; @@ -2777,21 +3060,21 @@ struct server_context { } res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; - res->post_sampling_probs = slot.params.post_sampling_probs; + res->n_prompt_tokens = slot.n_prompt_tokens(); + res->post_sampling_probs = slot.task->params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->verbose = slot.task->params.verbose; + res->oaicompat = slot.task->params.oaicompat; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { + if (slot.task->params.sampling.n_probs > 0) { res->prob_output = tkn; // copy the token probs } // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { res->timings = slot.get_timings(); } @@ -2800,36 +3083,37 @@ struct server_context { void send_final_response(server_slot & slot) { auto res = std::make_unique(); - res->id = slot.id_task; - res->id_slot = slot.id; - res->index = slot.index; + res->id = slot.task->id; + res->id_slot = slot.id; + + res->index = slot.task->index; res->content = slot.generated_text; res->tokens = std::move(slot.generated_tokens); res->timings = slot.get_timings(); - res->prompt = slot.prompt_tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.params.response_fields); + res->prompt = slot.task->tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.task->params.response_fields); res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_prompt_tokens = slot.n_prompt_tokens(); res->n_tokens_cached = slot.n_past; res->has_new_line = slot.has_new_line; res->stopping_word = slot.stopping_word; res->stop = slot.stop; - res->post_sampling_probs = slot.params.post_sampling_probs; + res->post_sampling_probs = slot.task->params.post_sampling_probs; - res->verbose = slot.params.verbose; - res->stream = slot.params.stream; - res->include_usage = slot.params.include_usage; - res->oaicompat = slot.params.oaicompat; - res->oaicompat_model = slot.params.oaicompat_model; - res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; - res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + res->verbose = slot.task->params.verbose; + res->stream = slot.task->params.stream; + res->include_usage = slot.task->params.include_usage; + res->oaicompat = slot.task->params.oaicompat; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); // populate res.probs_output - if (slot.params.sampling.n_probs > 0) { - if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + if (slot.task->params.sampling.n_probs > 0) { + if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); @@ -2843,17 +3127,17 @@ struct server_context { } } - res->generation_params = slot.params; // copy the parameters + res->generation_params = slot.task->params; // copy the parameters queue_results.send(std::move(res)); } void send_embedding(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; - res->oaicompat = slot.params.oaicompat; + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.n_prompt_tokens(); + res->oaicompat = slot.task->params.oaicompat; const int n_embd = llama_model_n_embd(model); @@ -2880,12 +3164,12 @@ struct server_context { // normalize only when there is pooling if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.params.embd_normalize); + common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; - } else { - res->embedding.emplace_back(embd, embd + n_embd); } + + res->embedding.emplace_back(embd, embd + n_embd); } SLT_DBG(slot, "%s", "sending embeddings\n"); @@ -2895,9 +3179,9 @@ struct server_context { void send_rerank(const server_slot & slot, const llama_batch & batch) { auto res = std::make_unique(); - res->id = slot.id_task; - res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.n_prompt_tokens(); for (int i = 0; i < batch.n_tokens; ++i) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { @@ -3034,7 +3318,7 @@ struct server_context { case SERVER_TASK_TYPE_EMBEDDING: case SERVER_TASK_TYPE_RERANK: { - const int id_slot = task.id_selected_slot; + const int id_slot = task.id_slot; server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); @@ -3061,7 +3345,7 @@ struct server_context { { // release slot linked with the task id for (auto & slot : slots) { - if (slot.id_task == task.id_target) { + if (slot.task && slot.task->id == task.id_target) { slot.release(); break; } @@ -3079,7 +3363,7 @@ struct server_context { int n_processing_slots = 0; for (server_slot & slot : slots) { - json slot_data = slot.to_json(true); + json slot_data = slot.to_json(slots_debug == 0); if (slot.is_processing()) { n_processing_slots++; @@ -3121,7 +3405,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_SAVE: { - if (!ensure_no_mtmd(task.id)) { + if (!check_no_mtmd(task.id)) { break; } @@ -3138,13 +3422,13 @@ struct server_context { break; } - const size_t token_count = slot->cache_tokens.size(); + const size_t token_count = slot->prompt.tokens.size(); const int64_t t_start = ggml_time_us(); std::string filename = task.slot_action.filename; std::string filepath = task.slot_action.filepath; - const llama_tokens & tokens = slot->cache_tokens.get_text_tokens(); + const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); const int64_t t_end = ggml_time_us(); @@ -3162,7 +3446,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { - if (!ensure_no_mtmd(task.id)) break; + if (!check_no_mtmd(task.id)) break; int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -3186,13 +3470,13 @@ struct server_context { size_t token_count = 0; size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); if (nread == 0) { - slot->cache_tokens.clear(); // KV may already been invalidated? + slot->prompt.tokens.clear(); // KV may already been invalidated? send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); break; } tokens.resize(token_count); - slot->cache_tokens.clear(); - slot->cache_tokens.insert(tokens); + slot->prompt.tokens.clear(); + slot->prompt.tokens.insert(tokens); const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; @@ -3209,7 +3493,9 @@ struct server_context { } break; case SERVER_TASK_TYPE_SLOT_ERASE: { - if (!ensure_no_mtmd(task.id)) break; + if (!check_no_mtmd(task.id)) { + break; + } int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { @@ -3224,9 +3510,9 @@ struct server_context { } // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); + const size_t n_erased = slot->prompt.tokens.size(); llama_memory_seq_rm(llama_get_memory(ctx), slot->id, -1, -1); - slot->cache_tokens.clear(); + slot->prompt.tokens.clear(); auto res = std::make_unique(); res->id = task.id; @@ -3282,8 +3568,8 @@ struct server_context { if (!params_base.ctx_shift) { // this check is redundant (for good) // we should never get here, because generation should already stopped in process_token() - slot.release(); send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + slot.release(); continue; } @@ -3294,9 +3580,16 @@ struct server_context { } // Shift context - const int n_keep = slot.params.n_keep + add_bos_token; + int n_keep = slot.task->params.n_keep < 0 ? slot.n_prompt_tokens() : slot.task->params.n_keep; + + if (add_bos_token) { + n_keep += 1; + } + + n_keep = std::min(slot.n_ctx - 4, n_keep); + const int n_left = slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); @@ -3305,14 +3598,14 @@ struct server_context { // add generated tokens to cache { - llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy + llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { new_tokens[i - n_discard] = new_tokens[i]; } - new_tokens.resize(slot.cache_tokens.size() - n_discard); - slot.cache_tokens.clear(); - slot.cache_tokens.insert(new_tokens); + new_tokens.resize(slot.prompt.tokens.size() - n_discard); + slot.prompt.tokens.clear(); + slot.prompt.tokens.insert(new_tokens); } slot.n_past -= n_discard; @@ -3328,7 +3621,8 @@ struct server_context { server_slot * slot_batched = nullptr; auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end(); + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); }; // frist, add sampled tokens from any ongoing sequences @@ -3349,10 +3643,10 @@ struct server_context { common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; - slot.cache_tokens.push_back(slot.sampled); + slot.prompt.tokens.push_back(slot.sampled); SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); + slot.n_ctx, slot.n_past, (int) slot.prompt.tokens.size(), slot.truncated); } // process in chunks of params.n_batch @@ -3375,7 +3669,7 @@ struct server_context { // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - auto & prompt_tokens = slot.prompt_tokens; + const auto & input_tokens = slot.task->tokens; // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -3383,104 +3677,64 @@ struct server_context { slot.t_start_generation = 0; slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); slot.state = SLOT_STATE_PROCESSING_PROMPT; - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", + slot.n_ctx, slot.task->params.n_keep, slot.n_prompt_tokens()); // print prompt tokens (for debugging) /*if (1) { // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + for (int i = 0; i < std::min(16, input_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); } } else { // all - for (int i = 0; i < (int) prompt_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + for (int i = 0; i < (int) input_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); } }*/ // empty prompt passed -> release the slot and send empty response - if (prompt_tokens.empty()) { + if (input_tokens.empty()) { SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - slot.release(); slot.print_timings(); send_final_response(slot); + slot.release(); + continue; } // TODO: support memory-less logits computation if (slot.need_logits() && !llama_get_memory(ctx)) { - slot.release(); send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); + slot.release(); continue; } if (!slot.can_split()) { - if (slot.n_prompt_tokens > n_ubatch) { - slot.release(); + if (slot.n_prompt_tokens() > n_ubatch) { send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + slot.release(); continue; } - if (slot.n_prompt_tokens > slot.n_ctx) { - slot.release(); + if (slot.n_prompt_tokens() > slot.n_ctx) { send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); continue; } } else { - if (!params_base.ctx_shift) { - // if context shift is disabled, we make sure prompt size is smaller than KV size - // TODO: there should be a separate parameter that control prompt truncation - // context shift should be applied only during the generation phase - if (slot.n_prompt_tokens >= slot.n_ctx) { - slot.release(); - send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - continue; - } - } - if (slot.params.n_keep < 0) { - slot.params.n_keep = slot.n_prompt_tokens; - } - slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - - // if input prompt is too big, truncate it - if (slot.n_prompt_tokens >= slot.n_ctx) { - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - const int n_left = slot.n_ctx - slot.params.n_keep; - - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - - const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); - llama_tokens new_tokens( - curr_tokens.begin(), - curr_tokens.begin() + slot.params.n_keep); - - new_tokens.insert( - new_tokens.end(), - curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - curr_tokens.end()); - - prompt_tokens.clear(); - prompt_tokens.insert(new_tokens); - - slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); - - SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); - - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); + if (slot.n_prompt_tokens() >= slot.n_ctx) { + send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); + continue; } - if (slot.params.cache_prompt) { + if (slot.task->params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + slot.n_past = slot.prompt.tokens.get_common_prefix(input_tokens); // if there is an alora invoked, don't cache after the invocation start if (slot.alora_invocation_start >= 0) { @@ -3500,13 +3754,13 @@ struct server_context { SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - while (head_c < slot.cache_tokens.size() && - head_p < prompt_tokens.size()) { + while (head_c < slot.prompt.tokens.size() && + head_p < input_tokens.size()) { size_t n_match = 0; - while (head_c + n_match < slot.cache_tokens.size() && - head_p + n_match < prompt_tokens.size() && - slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + while (head_c + n_match < slot.prompt.tokens.size() && + head_p + n_match < input_tokens.size() && + slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { n_match++; } @@ -3523,7 +3777,7 @@ struct server_context { llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); for (size_t i = 0; i < n_match; i++) { - slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]); + slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); slot.n_past++; } @@ -3547,41 +3801,83 @@ struct server_context { // the largest pos_min required for a checkpoint to be useful const auto pos_min_thold = std::max(0, slot.n_past - n_swa); - if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { + if (slot.n_past > 0 && slot.n_past < (int) slot.prompt.tokens.size()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min); + SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } + // when the prompt prefix does not match, print the tokens around the mismatch + // this is useful for debugging prompt caching + { + const int np0 = std::max(slot.n_past - 4, 0); + const int np1 = std::min(slot.n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); + + std::stringstream ss0; + std::stringstream ss1; + + std::stringstream st0; + std::stringstream st1; + + ss0 << "old: ... "; + ss1 << "new: ... "; + + for (int i = np0; i < np1; i++) { + if (i == slot.n_past) { + ss0 << " | "; + ss1 << " | "; + } + + { + const auto token = slot.prompt.tokens[i]; + const auto piece = common_token_to_piece(ctx, token); + ss0 << piece; + st0 << std::setw(8) << token; + } + + { + const auto token = slot.task->tokens[i]; + const auto piece = common_token_to_piece(ctx, token); + ss1 << piece; + st1 << std::setw(8) << token; + } + } + + SLT_WRN(slot, "%s\n", ss0.str().c_str()); + SLT_WRN(slot, "%s\n", ss1.str().c_str()); + + SLT_WRN(slot, "%s\n", st0.str().c_str()); + SLT_WRN(slot, "%s\n", st1.str().c_str()); + } + if (pos_min > pos_min_thold) { - SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); + SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint const auto it = std::find_if( - slot.ctx_checkpoints.rbegin(), - slot.ctx_checkpoints.rend(), + slot.prompt.checkpoints.rbegin(), + slot.prompt.checkpoints.rend(), [&](const auto & cur) { // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] return cur.pos_min < pos_min_thold; } ); - bool do_reset = it == slot.ctx_checkpoints.rend(); - //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false"); + bool do_reset = it == slot.prompt.checkpoints.rend(); if (!do_reset) { // restore the context checkpoint - const size_t ctx_checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + const size_t checkpoint_size = it->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != ctx_checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); + if (n != checkpoint_size) { + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); do_reset = true; //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); } } @@ -3595,19 +3891,21 @@ struct server_context { { // erase any checkpoints with pos_min > pos_min_thold - for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) { - const auto & cur = slot.ctx_checkpoints[i]; + for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { + const auto & cur = *it; if (cur.pos_min > pos_min_thold) { SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); - slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i); + it = slot.prompt.checkpoints.erase(it); + } else { + ++it; } } } } // [TAG_PROMPT_LOGITS] - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens); + if (slot.n_past == slot.n_prompt_tokens() && slot.n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens()); slot.n_past--; SLT_WRN(slot, "n_past was set to %d\n", slot.n_past); } @@ -3618,7 +3916,7 @@ struct server_context { if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) { continue; } } @@ -3636,28 +3934,28 @@ struct server_context { SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past); // remove the non-common part from the cache - slot.cache_tokens.keep_first(slot.n_past); + slot.prompt.tokens.keep_first(slot.n_past); // check if we should process the image - if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { + if (slot.n_past < slot.n_prompt_tokens() && input_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { // process the image int32_t new_n_past; - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); - int32_t n_pos = new_n_past - slot.n_past; - + int32_t res = input_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); if (res != 0) { SLT_ERR(slot, "failed to process image, res = %d\n", res); - slot.release(); send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + slot.release(); continue; } // add the image chunk to cache { - const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past); - slot.cache_tokens.push_back(chunk.get()); // copy + const auto & chunk = input_tokens.find_chunk(slot.n_past); + slot.prompt.tokens.push_back(chunk.get()); // copy } + const int32_t n_pos = new_n_past - slot.n_past; + slot.n_past += n_pos; slot.n_prompt_tokens_processed += n_pos; } @@ -3676,10 +3974,27 @@ struct server_context { alora_disabled_id = enabled_loras[0]; } + bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + + // make a checkpoint of the parts of the memory that cannot be rolled back. + // checkpoints are created only if: + // - the model uses SWA and we are not using `swa_full` + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + do_checkpoint = do_checkpoint && ( + llama_model_is_recurrent(model) || + llama_model_is_hybrid(model) || + (llama_model_n_swa(model) > 0 && !params_base.swa_full) + ); + // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_past < slot.n_prompt_tokens() && batch.n_tokens < n_batch) { // get next token to process - llama_token cur_tok = slot.prompt_tokens[slot.n_past]; + llama_token cur_tok = input_tokens[slot.n_past]; if (cur_tok == LLAMA_TOKEN_NULL) { break; // end of text chunk } @@ -3693,31 +4008,33 @@ struct server_context { } // embedding requires all tokens in the batch to be output - const bool need_embd = server_task_type_need_embd(slot.task_type); - - common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); - slot.cache_tokens.push_back(cur_tok); + common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, slot.need_embd()); + slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; slot.n_past++; + + // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. + if (do_checkpoint && slot.n_prompt_tokens() - slot.n_past == 64) { + break; + } } // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / slot.n_prompt_tokens()); // entire prompt has been processed - if (slot.n_past == slot.n_prompt_tokens) { + if (slot.n_past == slot.n_prompt_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); - GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system - for (int i = 0; i < slot.n_prompt_tokens; ++i) { - llama_token id = slot.prompt_tokens[i]; + for (int i = 0; i < slot.n_prompt_tokens(); ++i) { + llama_token id = input_tokens[i]; if (id != LLAMA_TOKEN_NULL) { common_sampler_accept(slot.smpl, id, false); } @@ -3730,6 +4047,40 @@ struct server_context { slot.i_batch = batch.n_tokens - 1; SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + + // no need for empty or small checkpoints + do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + + // no need to create checkpoints that are too close together + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + + if (do_checkpoint) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); + + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } + + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.data = */ std::vector(checkpoint_size), + }); + + llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + } } } @@ -3802,8 +4153,8 @@ struct server_context { if (!err.empty()) { SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); for (auto & slot : slots) { - slot.release(); send_error(slot, err); + slot.release(); } break; } @@ -3826,7 +4177,7 @@ struct server_context { for (auto & slot : slots) { // optionally send prompt processing progress if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.params.stream && slot.params.return_progress) { + if (slot.task->params.stream && slot.task->params.return_progress) { send_partial_response(slot, {}, true); } } @@ -3836,7 +4187,7 @@ struct server_context { } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding send_embedding(slot, batch_view); slot.release(); @@ -3844,7 +4195,7 @@ struct server_context { continue; // continue loop of slots } - if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { send_rerank(slot, batch_view); slot.release(); slot.i_batch = -1; @@ -3853,40 +4204,6 @@ struct server_context { // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; - - // make a checkpoint of the parts of the memory that cannot be rolled back. - // checkpoints are created only if: - // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type - const bool do_checkpoint = - (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full); - - if (do_checkpoint && params_base.n_ctx_checkpoints > 0) { - while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.ctx_checkpoints.front(); - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - - slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin()); - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{ - /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id), - /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id), - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots } @@ -3909,23 +4226,24 @@ struct server_context { metrics.on_prompt_eval(slot); } - slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; + slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; completion_token_output result; result.tok = id; result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - if (slot.params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); + if (slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); } if (!process_token(result, slot)) { // release slot because of stop condition - slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + slot.release(); + continue; } } @@ -3946,7 +4264,7 @@ struct server_context { } // determine the max draft that fits the current slot state - int n_draft_max = slot.params.speculative.n_max; + int n_draft_max = slot.task->params.speculative.n_max; // note: n_past is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts @@ -3958,8 +4276,8 @@ struct server_context { SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - if (n_draft_max < slot.params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + if (n_draft_max < slot.task->params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); continue; } @@ -3967,16 +4285,16 @@ struct server_context { llama_token id = slot.sampled; struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + params_spec.p_min = slot.task->params.speculative.p_min; - const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); + const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); // ignore small drafts - if (slot.params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + if (slot.task->params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); continue; } @@ -4005,8 +4323,8 @@ struct server_context { // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; - slot.cache_tokens.push_back(id); - slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); + slot.prompt.tokens.push_back(id); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1); @@ -4020,11 +4338,11 @@ struct server_context { // TODO: set result.probs if (!process_token(result, slot)) { - // release slot because of stop condition - slot.release(); slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + slot.release(); + break; } } @@ -4050,7 +4368,7 @@ struct server_context { static void log_server_request(const httplib::Request & req, const httplib::Response & res) { // skip GH copilot requests when using default port - if (req.path == "/v1/health" || req.path == "/v1/completions") { + if (req.path == "/v1/health") { return; } @@ -4184,6 +4502,7 @@ int main(int argc, char ** argv) { auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { static const std::unordered_set public_endpoints = { "/health", + "/v1/health", "/models", "/v1/models", "/api/tags" @@ -4291,18 +4610,18 @@ int main(int argc, char ** argv) { } // TODO: get rid of this dynamic_cast - auto res_metrics = dynamic_cast(result.get()); - GGML_ASSERT(res_metrics != nullptr); + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); // optionally return "fail_on_no_slot" error if (req.has_param("fail_on_no_slot")) { - if (res_metrics->n_idle_slots == 0) { + if (res_task->n_idle_slots == 0) { res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); return; } } - res_ok(res, res_metrics->slots_data); + res_ok(res, res_task->slots_data); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { @@ -4330,56 +4649,56 @@ int main(int argc, char ** argv) { } // TODO: get rid of this dynamic_cast - auto res_metrics = dynamic_cast(result.get()); - GGML_ASSERT(res_metrics != nullptr); + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names json all_metrics_def = json { {"counter", {{ {"name", "prompt_tokens_total"}, {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total} + {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} }, { {"name", "prompt_seconds_total"}, {"help", "Prompt process time"}, - {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3} + {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} }, { {"name", "tokens_predicted_total"}, {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) res_metrics->n_tokens_predicted_total} + {"value", (uint64_t) res_task->n_tokens_predicted_total} }, { {"name", "tokens_predicted_seconds_total"}, {"help", "Predict process time"}, - {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3} + {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} }, { {"name", "n_decode_total"}, {"help", "Total number of llama_decode() calls"}, - {"value", res_metrics->n_decode_total} + {"value", res_task->n_decode_total} }, { {"name", "n_past_max"}, {"help", "Largest observed n_past."}, - {"value", res_metrics->n_past_max} + {"value", res_task->n_past_max} }, { {"name", "n_busy_slots_per_decode"}, {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_metrics->n_busy_slots_total / std::max((float) res_metrics->n_decode_total, 1.f)} + {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, {"help", "Average prompt throughput in tokens/s."}, - {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.} + {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} },{ {"name", "predicted_tokens_seconds"}, {"help", "Average generation throughput in tokens/s."}, - {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.} + {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} },{ {"name", "requests_processing"}, {"help", "Number of requests processing."}, - {"value", (uint64_t) res_metrics->n_processing_slots} + {"value", (uint64_t) res_task->n_processing_slots} },{ {"name", "requests_deferred"}, {"help", "Number of requests deferred."}, - {"value", (uint64_t) res_metrics->n_tasks_deferred} + {"value", (uint64_t) res_task->n_tasks_deferred} }}} }; @@ -4400,7 +4719,7 @@ int main(int argc, char ** argv) { } } - res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start)); + res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start)); res.set_content(prometheus.str(), "text/plain; version=0.0.4"); res.status = 200; // HTTP OK @@ -4524,9 +4843,22 @@ int main(int argc, char ** argv) { }; const auto handle_props = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + json default_generation_settings_for_props; + + { + slot_params params; + + params.sampling = ctx_server.params_base.sampling; + + default_generation_settings_for_props = json { + {"params", params.to_json(true)}, + {"n_ctx", ctx_server.slots[0].n_ctx}, + }; + } + // this endpoint is publicly available, please only return what is safe to be exposed json data = { - { "default_generation_settings", ctx_server.default_generation_settings_for_props }, + { "default_generation_settings", default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model.path }, { "modalities", json { @@ -4623,20 +4955,28 @@ int main(int argc, char ** argv) { // Everything else, including multimodal completions. inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); } - + const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel; tasks.reserve(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { + auto n_prompt_tokens = inputs[i].size(); + if (n_prompt_tokens >= n_ctx_slot) { + json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + error_data["n_prompt_tokens"] = n_prompt_tokens; + error_data["n_ctx"] = n_ctx_slot; + res_error(res, error_data); + return; + } server_task task = server_task(type); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; - task.prompt_tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( ctx_server.ctx, ctx_server.params_base, data); - task.id_selected_slot = json_value(data, "id_slot", -1); + task.id_slot = json_value(data, "id_slot", -1); // OAI-compat task.params.oaicompat = oaicompat; @@ -5005,9 +5345,9 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = std::move(tokenized_prompts[i]); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); // OAI-compat task.params.oaicompat = oaicompat; @@ -5103,10 +5443,10 @@ int main(int argc, char ** argv) { tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = std::move(tmp); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tmp); tasks.push_back(std::move(task)); } @@ -5232,6 +5572,7 @@ int main(int argc, char ** argv) { // register API routes svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check) + svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check) svr->Get (params.api_prefix + "/metrics", handle_metrics); svr->Get (params.api_prefix + "/props", handle_props); svr->Post(params.api_prefix + "/props", handle_props_change); @@ -5363,7 +5704,7 @@ int main(int argc, char ** argv) { #endif LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__, - is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() : + is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() : string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str()); // this call blocks the main thread until queue_tasks.terminate() is called diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 829af2ebe7..720b136b05 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -66,8 +66,7 @@ def test_server_slots(): assert len(res.body) == server.n_slots assert server.n_ctx is not None and server.n_slots is not None assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots - assert "params" in res.body[0] - assert res.body[0]["params"]["seed"] == server.seed + assert "params" not in res.body[0] def test_load_split_model(): diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 2979ed4bb7..d56d3d5f17 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -19,8 +19,8 @@ def create_server(): (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'), (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None), (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None), (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), ] @@ -54,7 +54,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), - ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), + ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"), ] ) def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): @@ -408,6 +408,28 @@ def test_context_size_exceeded(): assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots +def test_context_size_exceeded_stream(): + global server + server.start() + try: + for _ in server.make_stream_request("POST", "/chat/completions", data={ + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ] * 100, # make the prompt too long + "stream": True}): + pass + assert False, "Should have failed" + except ServerError as e: + assert e.code == 400 + assert "error" in e.body + assert e.body["error"]["type"] == "exceed_context_size_error" + assert e.body["error"]["n_prompt_tokens"] > 0 + assert server.n_ctx is not None + assert server.n_slots is not None + assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots + + @pytest.mark.parametrize( "n_batch,batch_count,reuse_cache", [ diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index 11483e679a..00ba78cf67 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -16,7 +16,7 @@ def create_server(): @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [ ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), + ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), ]) def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool): global server @@ -41,7 +41,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), ]) def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): global server diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index 92e49f2bb0..4adbbde64f 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -4,6 +4,12 @@ from utils import * server = ServerPreset.tinyllama2() +SHORT_TEXT = """ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. +""".strip() + LONG_TEXT = """ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. @@ -21,19 +27,18 @@ def create_server(): def test_ctx_shift_enabled(): - # the prompt is 301 tokens + # the prompt is 226 tokens # the slot context is 512/2 = 256 tokens - # the prompt is truncated to keep the last (301 - 256/2) = 173 tokens # 96 tokens are generated thanks to shifting the context when it gets full global server server.enable_ctx_shift = True server.start() res = server.make_request("POST", "/completion", data={ "n_predict": 96, - "prompt": LONG_TEXT, + "prompt": SHORT_TEXT, }) assert res.status_code == 200 - assert res.body["timings"]["prompt_n"] == 173 + assert res.body["timings"]["prompt_n"] == 226 assert res.body["timings"]["predicted_n"] == 96 assert res.body["truncated"] is True diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index abd6fff10d..4ba3d43c33 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -35,6 +35,12 @@ class ServerResponse: body: dict | Any +class ServerError(Exception): + def __init__(self, code, body): + self.code = code + self.body = body + + class ServerProcess: # default options debug: bool = False @@ -297,6 +303,8 @@ class ServerProcess: response = requests.post(url, headers=headers, json=data, stream=True) else: raise ValueError(f"Unimplemented method: {method}") + if response.status_code != 200: + raise ServerError(response.status_code, response.json()) for line_bytes in response.iter_lines(): line = line_bytes.decode("utf-8") if '[DONE]' in line: diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 4ca1423aaf..f175115f4f 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -31,10 +31,10 @@ using json = nlohmann::ordered_json; -#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) @@ -1102,6 +1102,7 @@ public: ~server_tokens() = default; // Prevent copying + // TODO: server_tokens should be copyable - remove this: server_tokens(const server_tokens&) = delete; server_tokens& operator=(const server_tokens&) = delete; @@ -1119,7 +1120,7 @@ public: } } - server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} // for debugging std::string str() const { @@ -1144,9 +1145,8 @@ public: auto it = map_pos_to_media.find(pos); if (it != map_pos_to_media.end()) { return it->second; - } else { - throw std::runtime_error("Chunk not found"); } + throw std::runtime_error("Chunk not found"); } void push_back(llama_token tok) { @@ -1170,7 +1170,7 @@ public: map_pos_to_media[start_pos] = std::move(new_chunk); } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { size_t n_tokens; - auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); + const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); for (size_t i = 0; i < n_tokens; ++i) { push_back(text_tokens[i]); } @@ -1190,7 +1190,7 @@ public: // We could also just check, but this will prevent silently dropping MTMD data. GGML_ASSERT(has_mtmd); for (auto it = tokens.map_pos_to_media.begin(); it != tokens.map_pos_to_media.end(); ) { - auto chunk = tokens.map_pos_to_media[it->first].get(); + auto * chunk = tokens.map_pos_to_media[it->first].get(); mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); map_pos_to_media[start_pos+it->first] = std::move(new_chunk); } @@ -1271,33 +1271,52 @@ public: } size_t get_common_prefix(const server_tokens & b) const { - size_t max_idx = std::min(tokens.size(), b.tokens.size()); - for (size_t i = 0; i < max_idx; ++i) { - auto & ai = tokens[i]; - auto & bi = b.tokens[i]; + const size_t max_idx = std::min(tokens.size(), b.tokens.size()); - if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - GGML_ASSERT(has_mtmd); - const auto & a_chunk = find_chunk(i); - const auto & b_chunk = b.find_chunk(i); - GGML_ASSERT(a_chunk && b_chunk); - std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); - std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); - size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); - size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); - if (ai_id == bi_id && a_pos == b_pos) { - GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen - i += a_pos - 1; // will be +1 by the for loop + if (!has_mtmd) { + for (size_t i = 0; i < max_idx; ++i) { + if (tokens[i] == b.tokens[i]) { continue; - } else { - return i; } - } else if (ai == bi) { - continue; - } else { + return i; } + + return max_idx; } + + for (size_t i = 0; i < max_idx; ++i) { + const llama_token ai = tokens[i]; + const llama_token bi = b.tokens[i]; + + if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { + const auto & a_chunk = find_chunk(i); + const auto & b_chunk = b.find_chunk(i); + + GGML_ASSERT(a_chunk && b_chunk); + + const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); + const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); + + const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get()); + const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get()); + + if (id_ai == id_bi && pos_a == pos_b) { + GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen + i += pos_a - 1; // will be +1 by the for loop + continue; + } + + return i; + } + + if (ai == bi) { + continue; + } + + return i; + } + return max_idx; // all tokens are equal } @@ -1308,7 +1327,7 @@ public: const int32_t n_vocab = llama_vocab_n_tokens(vocab); for (size_t i = 0; i < tokens.size(); ++i) { - auto & t = tokens[i]; + const auto & t = tokens[i]; if (t == LLAMA_TOKEN_NULL) { try { const auto & chunk = find_chunk(i); @@ -1330,8 +1349,8 @@ public: mtmd_context * mctx, llama_pos n_past, int32_t seq_id, - llama_pos & n_pos_out) { - auto & chunk = find_chunk(n_past); + llama_pos & n_pos_out) const { + const auto & chunk = find_chunk(n_past); const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio"; SRV_INF("processing %s...\n", name); diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte index c923bf9e04..fed0cf7126 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte @@ -1,7 +1,6 @@ { + const { updateConfig } = await import('$lib/stores/settings.svelte'); + updateConfig('disableReasoningFormat', false); + }} /> { + const { updateConfig } = await import('$lib/stores/settings.svelte'); + updateConfig('disableReasoningFormat', false); + }} /> { + const { updateConfig } = await import('$lib/stores/settings.svelte'); + updateConfig('disableReasoningFormat', false); + }} +/> + + { + const { updateConfig } = await import('$lib/stores/settings.svelte'); + updateConfig('disableReasoningFormat', true); + }} +/> + + { + const { updateConfig } = await import('$lib/stores/settings.svelte'); + updateConfig('disableReasoningFormat', false); // Phase 1: Stream reasoning content in chunks let reasoningText = 'I need to think about this carefully. Let me break down the problem:\n\n1. The user is asking for help with something complex\n2. I should provide a thorough and helpful response\n3. I need to consider multiple approaches\n4. The best solution would be to explain step by step\n\nThis approach will ensure clarity and understanding.'; @@ -187,126 +192,16 @@ message: processingMessage }} play={async () => { + const { updateConfig } = await import('$lib/stores/settings.svelte'); + updateConfig('disableReasoningFormat', false); // Import the chat store to simulate loading state const { chatStore } = await import('$lib/stores/chat.svelte'); - + // Set loading state to true to trigger the processing UI chatStore.isLoading = true; - + // Simulate the processing state hook behavior // This will show the "Generating..." text and parameter details - await new Promise(resolve => setTimeout(resolve, 100)); + await new Promise((resolve) => setTimeout(resolve, 100)); }} /> - - - - - - { - // Phase 1: Stream reasoning content - const thinkingContent = - 'Let me work through this problem systematically:\n\n1. First, I need to understand what the user is asking\n2. Then I should consider different approaches\n3. I need to evaluate the pros and cons\n4. Finally, I should provide a clear recommendation\n\nThis step-by-step approach will ensure accuracy.'; - - let currentContent = '\n'; - streamingThinkMessage.content = currentContent; - - for (let i = 0; i < thinkingContent.length; i++) { - currentContent += thinkingContent[i]; - streamingThinkMessage.content = currentContent; - await new Promise((resolve) => setTimeout(resolve, 5)); - } - - // Close the thinking block - currentContent += '\n\n\n'; - streamingThinkMessage.content = currentContent; - await new Promise((resolve) => setTimeout(resolve, 200)); - - // Phase 2: Stream main response content - const responseContent = - "Based on my analysis above, here's the solution:\n\n**Key Points:**\n- The approach should be systematic\n- We need to consider all factors\n- Implementation should be step-by-step\n\nThis ensures the best possible outcome."; - - for (let i = 0; i < responseContent.length; i++) { - currentContent += responseContent[i]; - streamingThinkMessage.content = currentContent; - await new Promise((resolve) => setTimeout(resolve, 10)); - } - - streamingThinkMessage.timestamp = Date.now(); - }} -> -
- -
-
- - { - // Phase 1: Stream [THINK] reasoning content - const thinkingContent = - 'Using the DeepSeek format now:\n\n- This demonstrates the [THINK] bracket format\n- Should parse identically to <think> tags\n- The UI should display this in the thinking section\n- Main content should be separate\n\nBoth formats provide the same functionality.'; - - let currentContent = '[THINK]\n'; - streamingBracketMessage.content = currentContent; - - for (let i = 0; i < thinkingContent.length; i++) { - currentContent += thinkingContent[i]; - streamingBracketMessage.content = currentContent; - await new Promise((resolve) => setTimeout(resolve, 5)); - } - - // Close the thinking block - currentContent += '\n[/THINK]\n\n'; - streamingBracketMessage.content = currentContent; - await new Promise((resolve) => setTimeout(resolve, 200)); - - // Phase 2: Stream main response content - const responseContent = - "Here's my response after using the [THINK] format:\n\n**Observations:**\n- Both <think> and [THINK] formats work seamlessly\n- The parsing logic handles both cases\n- UI display is consistent across formats\n\nThis demonstrates the enhanced thinking content support."; - - for (let i = 0; i < responseContent.length; i++) { - currentContent += responseContent[i]; - streamingBracketMessage.content = currentContent; - await new Promise((resolve) => setTimeout(resolve, 10)); - } - - streamingBracketMessage.timestamp = Date.now(); - }} -> -
- -
-