diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile index 9ce80a71eb..cd2f9aa79b 100644 --- a/.devops/intel.Dockerfile +++ b/.devops/intel.Dockerfile @@ -1,8 +1,8 @@ -ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04 +ARG ONEAPI_VERSION=2025.2.2-0-devel-ubuntu24.04 ## Build Image -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS build ARG GGML_SYCL_F16=OFF RUN apt-get update && \ @@ -31,7 +31,7 @@ RUN mkdir -p /app/full \ && cp requirements.txt /app/full \ && cp .devops/tools.sh /app/full/tools.sh -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS base +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS base RUN apt-get update \ && apt-get install -y libgomp1 curl\ diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix index 651a54db4c..41748e89d5 100644 --- a/.devops/nix/package.nix +++ b/.devops/nix/package.nix @@ -128,10 +128,6 @@ effectiveStdenv.mkDerivation (finalAttrs: { }; postPatch = '' - substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \ - --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";" - substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \ - --replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";" ''; # With PR#6015 https://github.com/ggml-org/llama.cpp/pull/6015, diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile index 106c62b4dc..df9058d946 100644 --- a/.devops/rocm.Dockerfile +++ b/.devops/rocm.Dockerfile @@ -1,8 +1,8 @@ ARG UBUNTU_VERSION=24.04 # This needs to generally match the container host's environment. -ARG ROCM_VERSION=6.4 -ARG AMDGPU_VERSION=6.4 +ARG ROCM_VERSION=7.0 +ARG AMDGPU_VERSION=7.0 # Target the ROCm build image ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete @@ -13,9 +13,8 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build # Unless otherwise specified, we make a fat build. # List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878 # This is mostly tied to rocBLAS supported archs. -# gfx803, gfx900, gfx1032, gfx1101, gfx1102,not officialy supported -# gfx906 is deprecated -#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html +# gfx803, gfx900, gfx906, gfx1032, gfx1101, gfx1102, not officially supported +# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151' #ARG ROCM_DOCKER_ARCH='gfx1151' @@ -36,13 +35,10 @@ WORKDIR /app COPY . . -RUN git clone https://github.com/rocm/rocwmma --branch develop --depth 1 - RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ cmake -S .
-B build \ -DGGML_HIP=ON \ -DGGML_HIP_ROCWMMA_FATTN=ON \ - -DCMAKE_HIP_FLAGS="-I$(pwd)/rocwmma/library/include/" \ -DAMDGPU_TARGETS="$ROCM_DOCKER_ARCH" \ -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON \ -DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \ diff --git a/.github/actions/install-exe/action.yml b/.github/actions/install-exe/action.yml new file mode 100644 index 0000000000..002bec83c7 --- /dev/null +++ b/.github/actions/install-exe/action.yml @@ -0,0 +1,36 @@ +name: "Install exe" +description: "Download and install exe" +inputs: + url: + description: "URL of the exe installer" + required: true + args: + description: "Installer arguments" + required: true + timeout: + description: "Timeout (in ms)" + required: false + default: "600000" + +runs: + using: "composite" + steps: + - name: Install EXE + shell: pwsh + run: | + $ErrorActionPreference = "Stop" + write-host "Downloading Installer EXE" + Invoke-WebRequest -Uri "${{ inputs.url }}" -OutFile "${env:RUNNER_TEMP}\temp-install.exe" + write-host "Installing" + $proc = Start-Process "${env:RUNNER_TEMP}\temp-install.exe" -ArgumentList '${{ inputs.args }}' -NoNewWindow -PassThru + $completed = $proc.WaitForExit(${{ inputs.timeout }}) + if (-not $completed) { + Write-Error "Installer timed out. Killing the process" + $proc.Kill() + exit 1 + } + if ($proc.ExitCode -ne 0) { + Write-Error "Installer failed with exit code $($proc.ExitCode)" + exit 1 + } + write-host "Completed installation" diff --git a/.github/actions/linux-setup-spacemit/action.yml b/.github/actions/linux-setup-spacemit/action.yml new file mode 100644 index 0000000000..e2193e8931 --- /dev/null +++ b/.github/actions/linux-setup-spacemit/action.yml @@ -0,0 +1,20 @@ +name: "Linux - Setup SpacemiT Toolchain" +description: "Setup SpacemiT Toolchain for Linux" +inputs: + path: + description: "Installation path" + required: true + version: + description: "SpacemiT toolchain version" + required: true + +runs: + using: "composite" + steps: + - name: Setup SpacemiT Toolchain + id: setup + uses: ./.github/actions/unarchive-tar + with: + url: https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ inputs.version }}.tar.xz + path: ${{ inputs.path }} + strip: 1 diff --git a/.github/actions/linux-setup-vulkan/action.yml b/.github/actions/linux-setup-vulkan/action.yml new file mode 100644 index 0000000000..4d29837feb --- /dev/null +++ b/.github/actions/linux-setup-vulkan/action.yml @@ -0,0 +1,20 @@ +name: "Linux - Setup Vulkan SDK" +description: "Setup Vulkan SDK for Linux" +inputs: + path: + description: "Installation path" + required: true + version: + description: "Vulkan SDK version" + required: true + +runs: + using: "composite" + steps: + - name: Setup Vulkan SDK + id: setup + uses: ./.github/actions/unarchive-tar + with: + url: https://sdk.lunarg.com/sdk/download/${{ inputs.version }}/linux/vulkan_sdk.tar.xz + path: ${{ inputs.path }} + strip: 1 diff --git a/.github/actions/unarchive-tar/action.yml b/.github/actions/unarchive-tar/action.yml new file mode 100644 index 0000000000..b97e402f46 --- /dev/null +++ b/.github/actions/unarchive-tar/action.yml @@ -0,0 +1,27 @@ +name: "Unarchive tar" +description: "Download and unarchive tar into directory" +inputs: + url: + description: "URL of the tar archive" + required: true + path: + description: "Directory to unarchive into" + required: true + type: + description: "Compression type (tar option)" + required: false + default: "J" + strip: + description: "Strip components" + required: false + default: "0" + 
+runs: + using: "composite" + steps: + - name: Unarchive into directory + shell: bash + run: | + mkdir -p ${{ inputs.path }} + cd ${{ inputs.path }} + curl --no-progress-meter ${{ inputs.url }} | tar -${{ inputs.type }}x --strip-components=${{ inputs.strip }} diff --git a/.github/actions/windows-setup-rocm/action.yml b/.github/actions/windows-setup-rocm/action.yml new file mode 100644 index 0000000000..b83e6e295b --- /dev/null +++ b/.github/actions/windows-setup-rocm/action.yml @@ -0,0 +1,15 @@ +name: "Windows - Setup ROCm" +description: "Setup ROCm for Windows" +inputs: + version: + description: "ROCm version" + required: true + +runs: + using: "composite" + steps: + - name: Setup ROCm + uses: ./.github/actions/install-exe + with: + url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-WinSvr2022-For-HIP.exe + args: -install diff --git a/.github/workflows/build-cache.yml b/.github/workflows/build-cache.yml new file mode 100644 index 0000000000..6a22e41c3b --- /dev/null +++ b/.github/workflows/build-cache.yml @@ -0,0 +1,89 @@ +name: Build Actions Cache + +on: + workflow_dispatch: # allows manual triggering + schedule: + - cron: '0 * * * *' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ubuntu-24-vulkan-cache: + runs-on: ubuntu-24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Get latest Vulkan SDK version + id: vulkan_sdk_version + run: | + echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" + + - name: Setup Cache + uses: actions/cache@v4 + id: cache-sdk + with: + path: ./vulkan_sdk + key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }} + + - name: Setup Vulkan SDK + if: steps.cache-sdk.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-vulkan + with: + path: ./vulkan_sdk + version: ${{ env.VULKAN_SDK_VERSION }} + + ubuntu-24-spacemit-cache: + runs-on: ubuntu-24.04 + + env: + # Make sure this is in sync with build-linux-cross.yml + SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup Cache + uses: actions/cache@v4 + id: cache-toolchain + with: + path: ./spacemit_toolchain + key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }} + + - name: Setup SpacemiT Toolchain + if: steps.cache-toolchain.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-spacemit + with: + path: ./spacemit_toolchain + version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + + windows-2022-rocm-cache: + runs-on: windows-2022 + + env: + # Make sure this is in sync with build.yml + HIPSDK_INSTALLER_VERSION: "25.Q3" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Setup Cache + uses: actions/cache@v4 + id: cache-rocm + with: + path: C:\Program Files\AMD\ROCm + key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }} + + - name: Setup ROCm + if: steps.cache-rocm.outputs.cache-hit != 'true' + uses: ./.github/actions/windows-setup-rocm + with: + version: ${{ env.HIPSDK_INSTALLER_VERSION }} diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 2b101876c5..937306f7af 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -258,31 +258,29 @@ jobs: runs-on: ubuntu-24.04 env: + # Make sure this is in sync with build-cache.yml 
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" - SPACEMIT_IME_TOOLCHAIN_PATH: "spacemit-toolchain-linux-glibc-x86_64" steps: - uses: actions/checkout@v4 - - name: Cache Toolchain + - name: Use SpacemiT Toolchain Cache uses: actions/cache@v4 - id: cache-spacemit-ime-cross-toolchain + id: cache-toolchain with: - path: ./${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} - key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + path: ./spacemit_toolchain + key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }} - - name: Setup Toolchain - if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true' - run: | - wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz - rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} - mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} - tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} --strip-components=1 - rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz + - name: Setup SpacemiT Toolchain + if: steps.cache-toolchain.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-spacemit + with: + path: ./spacemit_toolchain + version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} - name: Build run: | - export RISCV_ROOT_PATH=${PWD}/${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + export RISCV_ROOT_PATH=${PWD}/spacemit_toolchain cmake -B build -DLLAMA_CURL=OFF \ -DCMAKE_BUILD_TYPE=Release \ -DGGML_OPENMP=OFF \ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9a1a074cbe..ec2ab5a58d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -97,7 +97,7 @@ jobs: ctest -L 'main|curl' --verbose --timeout 900 macOS-latest-cmake-x64: - runs-on: macos-13 + runs-on: macos-15-intel steps: - name: Clone @@ -207,7 +207,7 @@ jobs: - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: - key: ubuntu-cpu-cmake + key: ubuntu-cpu-cmake-${{ matrix.build }} evict-old-files: 1d - name: Build Dependencies @@ -362,11 +362,11 @@ jobs: id: checkout uses: actions/checkout@v4 - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ubuntu-latest-cmake-rpc - evict-old-files: 1d + # - name: ccache + # uses: ggml-org/ccache-action@v1.2.16 + # with: + # key: ubuntu-latest-cmake-rpc + # evict-old-files: 1d - name: Dependencies id: depends @@ -387,8 +387,8 @@ jobs: cd build ctest -L main --verbose - ubuntu-22-cmake-vulkan: - runs-on: ubuntu-22.04 + ubuntu-24-cmake-vulkan: + runs-on: ubuntu-24.04 steps: - name: Clone @@ -398,20 +398,39 @@ jobs: - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: - key: ubuntu-22-cmake-vulkan + key: ubuntu-24-cmake-vulkan evict-old-files: 1d - name: Dependencies id: depends run: | - wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add - - sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list + sudo add-apt-repository -y ppa:kisak/kisak-mesa sudo apt-get update -y - sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev + sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev + + - name: Get latest Vulkan SDK version + id: vulkan_sdk_version + run: | + echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> 
"$GITHUB_ENV" + + - name: Use Vulkan SDK Cache + uses: actions/cache@v4 + id: cache-sdk + with: + path: ./vulkan_sdk + key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }} + + - name: Setup Vulkan SDK + if: steps.cache-sdk.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-vulkan + with: + path: ./vulkan_sdk + version: ${{ env.VULKAN_SDK_VERSION }} - name: Build id: cmake_build run: | + source ./vulkan_sdk/setup-env.sh cmake -B build \ -DGGML_VULKAN=ON cmake --build build --config Release -j $(nproc) @@ -421,11 +440,12 @@ jobs: run: | cd build export GGML_VK_VISIBLE_DEVICES=0 + export GGML_VK_DISABLE_F16=1 # This is using llvmpipe and runs slower than other backends ctest -L main --verbose --timeout 4200 - ubuntu-22-cmake-webgpu: - runs-on: ubuntu-22.04 + ubuntu-24-cmake-webgpu: + runs-on: ubuntu-24.04 steps: - name: Clone @@ -435,16 +455,34 @@ jobs: - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: - key: ubuntu-22-cmake-webgpu + key: ubuntu-24-cmake-webgpu evict-old-files: 1d - - name: Vulkan SDK Dependencies - id: vulkan-depends + - name: Dependencies + id: depends run: | - wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add - - sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list + sudo add-apt-repository -y ppa:kisak/kisak-mesa sudo apt-get update -y - sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev + sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev + + - name: Get latest Vulkan SDK version + id: vulkan_sdk_version + run: | + echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV" + + - name: Use Vulkan SDK Cache + uses: actions/cache@v4 + id: cache-sdk + with: + path: ./vulkan_sdk + key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }} + + - name: Setup Vulkan SDK + if: steps.cache-sdk.outputs.cache-hit != 'true' + uses: ./.github/actions/linux-setup-vulkan + with: + path: ./vulkan_sdk + version: ${{ env.VULKAN_SDK_VERSION }} - name: Dawn Dependency id: dawn-depends @@ -487,7 +525,7 @@ jobs: id: depends run: | sudo apt-get update - sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev + sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev rocwmma-dev - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1059,7 +1097,7 @@ jobs: shell: bash env: - WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe + WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/24751ead-ddc5-4479-b9e6-f9fe2ff8b9f2/intel-deep-learning-essentials-2025.2.1.25_offline.exe WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" steps: @@ -1090,6 +1128,7 @@ jobs: env: # The ROCm version must correspond to the version used in the HIP SDK. 
ROCM_VERSION: "6.4.2" + # Make sure this is in sync with build-cache.yml HIPSDK_INSTALLER_VERSION: "25.Q3" steps: @@ -1097,38 +1136,25 @@ jobs: id: checkout uses: actions/checkout@v4 - - name: Clone rocWMMA repository - id: clone_rocwmma + - name: Grab rocWMMA package + id: grab_rocwmma run: | - git clone https://github.com/rocm/rocwmma --branch rocm-${{ env.ROCM_VERSION }} --depth 1 + curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/${{ env.ROCM_VERSION }}/pool/main/r/rocwmma-dev/rocwmma-dev_1.7.0.60402-120~24.04_amd64.deb" + 7z x rocwmma.deb + 7z x data.tar - - name: Cache ROCm Installation - id: cache-rocm + - name: Use ROCm Installation Cache uses: actions/cache@v4 + id: cache-rocm with: path: C:\Program Files\AMD\ROCm key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }} - - name: Install ROCm + - name: Setup ROCm if: steps.cache-rocm.outputs.cache-hit != 'true' - id: depends - run: | - $ErrorActionPreference = "Stop" - write-host "Downloading AMD HIP SDK Installer" - Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" - write-host "Installing AMD HIP SDK" - $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru - $completed = $proc.WaitForExit(600000) - if (-not $completed) { - Write-Error "ROCm installation timed out after 10 minutes. Killing the process" - $proc.Kill() - exit 1 - } - if ($proc.ExitCode -ne 0) { - Write-Error "ROCm installation failed with exit code $($proc.ExitCode)" - exit 1 - } - write-host "Completed AMD HIP SDK installation" + uses: ./.github/actions/windows-setup-rocm + with: + version: ${{ env.HIPSDK_INSTALLER_VERSION }} - name: Verify ROCm id: verify @@ -1161,8 +1187,9 @@ jobs: cmake -G "Unix Makefiles" -B build -S . 
` -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` - -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" ` -DCMAKE_BUILD_TYPE=Release ` + -DROCM_DIR="${env:HIP_PATH}" ` -DGGML_HIP=ON ` -DGGML_HIP_ROCWMMA_FATTN=ON ` -DGGML_RPC=ON ` diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 446c666b90..f73a2bc9f4 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -89,12 +89,15 @@ jobs: TYPE="-${{ matrix.config.tag }}" fi PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:" + CACHETAGS="${PREFIX}buildcache${TYPE}" FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}" LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}" SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}" + echo "cache_output_tags=$CACHETAGS" >> $GITHUB_OUTPUT echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT + echo "cache_output_tags=$CACHETAGS" # print out for debugging echo "full_output_tags=$FULLTAGS" # print out for debugging echo "light_output_tags=$LIGHTTAGS" # print out for debugging echo "server_output_tags=$SERVERTAGS" # print out for debugging @@ -131,11 +134,14 @@ jobs: target: full provenance: false # using github experimental cache - cache-from: type=gha - cache-to: type=gha,mode=max + #cache-from: type=gha + #cache-to: type=gha,mode=max # return to this if the experimental github cache is having issues #cache-to: type=local,dest=/tmp/.buildx-cache #cache-from: type=local,src=/tmp/.buildx-cache + # using registry cache (no storage limit) + cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }} + cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max - name: Build and push Light Docker image (tagged + versioned) if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }} @@ -150,11 +156,14 @@ jobs: target: light provenance: false # using github experimental cache - cache-from: type=gha - cache-to: type=gha,mode=max + #cache-from: type=gha + #cache-to: type=gha,mode=max # return to this if the experimental github cache is having issues #cache-to: type=local,dest=/tmp/.buildx-cache #cache-from: type=local,src=/tmp/.buildx-cache + # using registry cache (no storage limit) + cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }} + cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max - name: Build and push Server Docker image (tagged + versioned) if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }} @@ -169,11 +178,14 @@ jobs: target: server provenance: false # using github experimental cache - cache-from: type=gha - cache-to: type=gha,mode=max + #cache-from: type=gha + #cache-to: type=gha,mode=max # return to this if the experimental github cache is having issues #cache-to: type=local,dest=/tmp/.buildx-cache #cache-from: type=local,src=/tmp/.buildx-cache + # using registry cache (no storage limit) + cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }} + cache-to: type=registry,ref=${{ 
steps.tag.outputs.cache_output_tags }},mode=max create_tag: name: Create and push git tag diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f461456edf..2ad3811594 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -75,7 +75,7 @@ jobs: name: llama-bin-macos-arm64.zip macOS-x64: - runs-on: macos-13 + runs-on: macos-15-intel steps: - name: Clone @@ -150,7 +150,7 @@ jobs: - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: - key: ubuntu-cpu-cmake + key: ubuntu-cpu-cmake-${{ matrix.build }} evict-old-files: 1d - name: Dependencies @@ -462,7 +462,7 @@ jobs: shell: bash env: - WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe + WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/24751ead-ddc5-4479-b9e6-f9fe2ff8b9f2/intel-deep-learning-essentials-2025.2.1.25_offline.exe WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" @@ -505,6 +505,7 @@ jobs: cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_tbb_thread.2.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_level_zero.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_level_zero_v2.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_opencl.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_loader.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_win_proxy_loader.dll" ./build/bin @@ -513,10 +514,15 @@ jobs: cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/svml_dispmd.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl-ls.exe" ./build/bin cp "${{ env.ONEAPI_ROOT }}/dnnl/latest/bin/dnnl.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/tcm/latest/bin/tcm.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/tcm/latest/bin/libhwloc-15.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/umf/latest/bin/umf.dll" ./build/bin + echo "cp oneAPI running time dll files to ./build/bin done" 7z a llama-bin-win-sycl-x64.zip ./build/bin/* @@ -543,10 +549,12 @@ jobs: id: checkout uses: actions/checkout@v4 - - name: Clone rocWMMA repository - id: clone_rocwmma + - name: Grab rocWMMA package + id: grab_rocwmma run: | - git clone https://github.com/rocm/rocwmma --branch develop --depth 1 + curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.0.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.0.0.70001-42~24.04_amd64.deb" + 7z x rocwmma.deb + 7z x data.tar - name: Cache ROCm Installation id: cache-rocm @@ -601,7 +609,7 @@ jobs: cmake -G "Unix Makefiles" -B build -S . 
` -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` - -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/ -Wno-ignored-attributes -Wno-nested-anon-types" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.0.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" ` -DCMAKE_BUILD_TYPE=Release ` -DGGML_BACKEND_DL=ON ` -DGGML_NATIVE=OFF ` diff --git a/CODEOWNERS b/CODEOWNERS index 6a6468fc27..3b696bf94a 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,7 +2,7 @@ # multiplie collaborators per item can be specified /.devops/*.Dockerfile @ngxson -/.github/actions/ @slaren +/.github/actions/ @slaren @CISC /.github/workflows/ @CISC /.github/workflows/release.yml @slaren /.github/workflows/winget.yml @slaren @@ -14,6 +14,7 @@ /common/build-info.* @ggerganov /common/common.* @ggerganov /common/console.* @ggerganov +/common/http.* @angt /common/llguidance.* @ggerganov /common/log.* @ggerganov /common/sampling.* @ggerganov @@ -58,6 +59,9 @@ /ggml/src/ggml-cuda/mmq.* @JohannesGaessler /ggml/src/ggml-cuda/mmvf.* @JohannesGaessler /ggml/src/ggml-cuda/mmvq.* @JohannesGaessler +/ggml/src/ggml-cuda/fattn-wmma* @IMbackK +/ggml/src/ggml-hip/ @IMbackK +/ggml/src/ggml-cuda/vendors/hip.h @IMbackK /ggml/src/ggml-impl.h @ggerganov @slaren /ggml/src/ggml-metal/ @ggerganov /ggml/src/ggml-opencl/ @lhez @max-krasnyansky @@ -66,6 +70,7 @@ /ggml/src/ggml-rpc/ @rgerganov /ggml/src/ggml-threading.* @ggerganov @slaren /ggml/src/ggml-vulkan/ @0cc4m +/ggml/src/ggml-webgpu/ @reeselevine /ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM /ggml/src/ggml.c @ggerganov @slaren /ggml/src/ggml.cpp @ggerganov @slaren diff --git a/ci/run.sh b/ci/run.sh index b0af51723b..f925956a84 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -34,9 +34,9 @@ mkdir -p "$2" OUT=$(realpath "$1") MNT=$(realpath "$2") -rm -f "$OUT/*.log" -rm -f "$OUT/*.exit" -rm -f "$OUT/*.md" +rm -f $OUT/*.log +rm -f $OUT/*.exit +rm -f $OUT/*.md sd=`dirname $0` cd $sd/../ @@ -512,12 +512,7 @@ function gg_run_rerank_tiny { gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin - gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json - gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt - gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json - gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json - - gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.json path_models="../models-mnt/rerank-tiny" @@ -607,6 +602,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then fi ret=0 + test $ret -eq 0 && gg_run ctest_debug test $ret -eq 0 && gg_run ctest_release @@ -624,4 +620,6 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then test $ret -eq 0 && gg_run ctest_with_model_release fi +cat $OUT/README.md + exit $ret diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 8ab3d44510..fe290bf8fd 
100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -56,6 +56,7 @@ add_library(${TARGET} STATIC common.h console.cpp console.h + http.h json-partial.cpp json-partial.h json-schema-to-grammar.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 3c932264d0..ecc296485c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -32,13 +32,11 @@ #include #include -//#define LLAMA_USE_CURL - #if defined(LLAMA_USE_CURL) #include #include #else -#include +#include "http.h" #endif #ifdef __linux__ @@ -54,6 +52,13 @@ #endif #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 +// isatty +#if defined(_WIN32) +#include +#else +#include +#endif + using json = nlohmann::ordered_json; std::initializer_list mmproj_examples = { @@ -100,6 +105,14 @@ static void write_file(const std::string & fname, const std::string & content) { } } +static bool is_output_a_tty() { +#if defined(_WIN32) + return _isatty(_fileno(stdout)); +#else + return isatty(1); +#endif +} + common_arg & common_arg::set_examples(std::initializer_list examples) { this->examples = std::move(examples); return *this; @@ -581,78 +594,11 @@ std::pair> common_remote_get_content(const std::string & #else -struct common_url { - std::string scheme; - std::string user; - std::string password; - std::string host; - std::string path; -}; - -static common_url parse_url(const std::string & url) { - common_url parts; - auto scheme_end = url.find("://"); - - if (scheme_end == std::string::npos) { - throw std::runtime_error("invalid URL: no scheme"); - } - parts.scheme = url.substr(0, scheme_end); - - if (parts.scheme != "http" && parts.scheme != "https") { - throw std::runtime_error("unsupported URL scheme: " + parts.scheme); +static void print_progress(size_t current, size_t total) { + if (!is_output_a_tty()) { + return; } - auto rest = url.substr(scheme_end + 3); - auto at_pos = rest.find('@'); - - if (at_pos != std::string::npos) { - auto auth = rest.substr(0, at_pos); - auto colon_pos = auth.find(':'); - if (colon_pos != std::string::npos) { - parts.user = auth.substr(0, colon_pos); - parts.password = auth.substr(colon_pos + 1); - } else { - parts.user = auth; - } - rest = rest.substr(at_pos + 1); - } - - auto slash_pos = rest.find('/'); - - if (slash_pos != std::string::npos) { - parts.host = rest.substr(0, slash_pos); - parts.path = rest.substr(slash_pos); - } else { - parts.host = rest; - parts.path = "/"; - } - return parts; -} - -static std::pair http_client(const std::string & url) { - common_url parts = parse_url(url); - - if (parts.host.empty()) { - throw std::runtime_error("error: invalid URL format"); - } - - if (!parts.user.empty()) { - throw std::runtime_error("error: user:password@ not supported yet"); // TODO - } - - httplib::Client cli(parts.scheme + "://" + parts.host); - cli.set_follow_location(true); - - // TODO cert - - return { std::move(cli), std::move(parts) }; -} - -static std::string show_masked_url(const common_url & parts) { - return parts.scheme + "://" + (parts.user.empty() ? 
"" : "****:****@") + parts.host + parts.path; -} - -static void print_progress(size_t current, size_t total) { // TODO isatty if (!total) { return; } @@ -740,7 +686,7 @@ static bool common_download_file_single_online(const std::string & url, static const int max_attempts = 3; static const int retry_delay_seconds = 2; - auto [cli, parts] = http_client(url); + auto [cli, parts] = common_http_client(url); httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}}; if (!bearer_token.empty()) { @@ -820,7 +766,7 @@ static bool common_download_file_single_online(const std::string & url, // start the download LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", - __func__, show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); + __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); if (!was_pull_successful) { if (i + 1 < max_attempts) { @@ -848,7 +794,7 @@ static bool common_download_file_single_online(const std::string & url, std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params) { - auto [cli, parts] = http_client(url); + auto [cli, parts] = common_http_client(url); httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; for (const auto & header : params.headers) { @@ -1669,18 +1615,14 @@ static void add_rpc_devices(const std::string & servers) { if (!rpc_reg) { throw std::invalid_argument("failed to find RPC backend"); } - typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); - ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); - if (!ggml_backend_rpc_add_device_fn) { - throw std::invalid_argument("failed to find RPC device add function"); + typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint); + ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server"); + if (!ggml_backend_rpc_add_server_fn) { + throw std::invalid_argument("failed to find RPC add server function"); } for (const auto & server : rpc_servers) { - ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); - if (dev) { - ggml_backend_device_register(dev); - } else { - throw std::invalid_argument("failed to register RPC device"); - } + auto reg = ggml_backend_rpc_add_server_fn(server.c_str()); + ggml_backend_register(reg); } } @@ -1986,13 +1928,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_SWA_FULL")); add_opt(common_arg( - {"--swa-checkpoints"}, "N", - string_format("max number of SWA checkpoints per slot to create (default: %d)\n" - "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints), + {"--ctx-checkpoints", "--swa-checkpoints"}, "N", + string_format("max number of context checkpoints to create per slot (default: %d)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints), [](common_params & params, int value) { - params.n_swa_checkpoints = value; + params.n_ctx_checkpoints = value; } - ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER})); + ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER})); 
add_opt(common_arg( {"--kv-unified", "-kvu"}, string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n" @@ -2642,6 +2584,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.no_extra_bufts = true; } ).set_env("LLAMA_ARG_NO_REPACK")); + add_opt(common_arg( + {"--no-host"}, + "bypass host buffer allowing extra buffers to be used", + [](common_params & params) { + params.no_host = true; + } + ).set_env("LLAMA_ARG_NO_HOST")); add_opt(common_arg( {"-ctk", "--cache-type-k"}, "TYPE", string_format( @@ -3910,7 +3859,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF"; params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf"; - params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; params.verbose_prompt = true; @@ -3924,7 +3872,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF"; params.model.hf_file = "e5-small-v2-q8_0.gguf"; - params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; params.verbose_prompt = true; @@ -3938,7 +3885,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF"; params.model.hf_file = "gte-small-q8_0.gguf"; - params.pooling_type = LLAMA_POOLING_TYPE_NONE; params.embd_normalize = 2; params.n_ctx = 512; params.verbose_prompt = true; diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 96ba8f533e..b3362519a6 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -75,6 +75,35 @@ bool common_chat_msg_parser::add_tool_calls(const json & arr) { } return true; } + +bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) { + if (!tool_call.is_object() || tool_call.size() != 1) { + return false; + } + + // Get the tool name (the single key in the object) + auto it = tool_call.begin(); + std::string name = it.key(); + + if (name.empty()) { + return false; + } + + // Get the arguments (the nested object) + const json & args_json = it.value(); + std::string arguments = ""; + + if (args_json.is_object()) { + arguments = args_json.dump(); + } else if (args_json.is_string()) { + arguments = args_json; + } else if (!args_json.is_null()) { + // For other types, convert to string representation + arguments = args_json.dump(); + } + + return add_tool_call(name, "", arguments); +} void common_chat_msg_parser::finish() { if (!is_partial_ && pos_ != input_.size()) { throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_)); diff --git a/common/chat-parser.h b/common/chat-parser.h index 0e64c341a5..c8cdc63fb5 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -64,6 +64,9 @@ class common_chat_msg_parser { // Adds an array of tool calls using their "name", "id" and "arguments" fields. 
bool add_tool_calls(const nlohmann::ordered_json & arr); + // Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } } + bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call); + void finish(); bool consume_spaces(); diff --git a/common/chat.cpp b/common/chat.cpp index e2bacdcf52..afbb2a2bdd 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -625,6 +625,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; case COMMON_CHAT_FORMAT_GENERIC: return "Generic"; case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo"; + case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral"; case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; @@ -638,6 +639,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; + case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; default: throw std::runtime_error("Unknown chat format"); } @@ -801,6 +803,7 @@ static std::string apply( } tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt; tmpl_inputs.extra_context = inputs.extra_context; + tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking; if (additional_context) { tmpl_inputs.extra_context.merge_patch(*additional_context); } @@ -982,6 +985,65 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; return data; } + +static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + data.prompt = apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_MAGISTRAL; + data.preserved_tokens = { + "[THINK]", + "[/THINK]", + }; + + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function.at("name")}, + }}, + {"arguments", function.at("parameters")}, + {"id", { + {"type", "string"}, + {"pattern", "^[a-zA-Z0-9]{9}$"}, + }}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? 
schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); + }); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"}); + data.preserved_tokens.push_back("[TOOL_CALLS]"); + } else { + data.grammar_lazy = false; + if (!inputs.json_schema.is_null()) { + if (!inputs.grammar.empty()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + data.grammar = json_schema_to_grammar(inputs.json_schema); + } else { + data.grammar = inputs.grammar; + } + } + + return data; +} + static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { if (!builder.syntax().parse_tool_calls) { builder.add_content(builder.consume_rest()); @@ -992,6 +1054,18 @@ static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { parse_prefixed_json_tool_call_array(builder, prefix); } +static void common_chat_parse_magistral(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("[THINK]", "[/THINK]"); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1264,6 +1338,75 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_ } return data; } + +static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // Generate the prompt using the apply() function with the template + data.prompt = apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_APERTUS; + + // Handle thinking tags appropriately based on inputs.enable_thinking + if (string_ends_with(data.prompt, "<|inner_prefix|>")) { + if (!inputs.enable_thinking) { + data.prompt += "<|inner_suffix|>"; + } else { + data.thinking_forced_open = true; + } + } + + // When tools are present, build grammar for the <|tools_prefix|> format + if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = true; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + { "type", "object" }, + { "properties", + { + { function.at("name"), function.at("parameters") } + } }, + { "required", json::array({ function.at("name") }) }, + }); + }); + auto schema = json{ + { "type", "array" }, + { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } }, + { "minItems", 1 }, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"<|inner_suffix|>\" space )? 
" : "") + + "\"<|tools_prefix|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tools_suffix|>\""); + }); + data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the <|inner_suffix|> tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? + "[\\s\\S]*?(<\\|inner_suffix\\|>\\s*)" : + "(?:<\\|inner_prefix\\|>[\\s\\S]*?<\\|inner_suffix\\|>\\s*)?") + + "(<\\|tools_prefix\\|>)[\\s\\S]*" }); + data.preserved_tokens = { + "<|system_start|>", + "<|system_end|>", + "<|developer_start|>", + "<|developer_end|>", + "<|user_start|>", + "<|user_end|>", + "<|assistant_start|>", + "<|assistant_end|>", + "<|inner_prefix|>", + "<|inner_suffix|>", + "<|tools_prefix|>", + "<|tools_suffix|>", + }; + } + return data; +} static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { if (!builder.syntax().parse_tool_calls) { builder.add_content(builder.consume_rest()); @@ -2323,6 +2466,37 @@ static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { builder.add_content(builder.consume_rest()); } +static void common_chat_parse_apertus(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + builder.consume_spaces(); + if (!builder.try_consume_literal("<|tools_suffix|>")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + for (const auto & value : tool_calls_data.json) { + if (value.is_object()) { + builder.add_tool_call_short_form(value); + } + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { // Parse thinking tags first - this handles the main reasoning content builder.try_parse_reasoning("", ""); @@ -2567,6 +2741,11 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_nemotron_v2(tmpl, params); } + // Apertus format detection + if (src.find("<|system_start|>") != std::string::npos && src.find("<|tools_prefix|>") != std::string::npos) { + return common_chat_params_init_apertus(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. 
if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -2595,6 +2774,10 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools); } + if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) { + return common_chat_params_init_magistral(tmpl, params); + } + // Plain handler (no tools) if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { return common_chat_params_init_without_tools(tmpl, params); @@ -2695,6 +2878,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_MISTRAL_NEMO: common_chat_parse_mistral_nemo(builder); break; + case COMMON_CHAT_FORMAT_MAGISTRAL: + common_chat_parse_magistral(builder); + break; case COMMON_CHAT_FORMAT_LLAMA_3_X: common_chat_parse_llama_3_1(builder); break; @@ -2734,6 +2920,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_NEMOTRON_V2: common_chat_parse_nemotron_v2(builder); break; + case COMMON_CHAT_FORMAT_APERTUS: + common_chat_parse_apertus(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index 5170fc14f4..a1afe574bd 100644 --- a/common/chat.h +++ b/common/chat.h @@ -101,6 +101,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_CONTENT_ONLY, COMMON_CHAT_FORMAT_GENERIC, COMMON_CHAT_FORMAT_MISTRAL_NEMO, + COMMON_CHAT_FORMAT_MAGISTRAL, COMMON_CHAT_FORMAT_LLAMA_3_X, COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, COMMON_CHAT_FORMAT_DEEPSEEK_R1, @@ -114,6 +115,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_SEED_OSS, COMMON_CHAT_FORMAT_NEMOTRON_V2, + COMMON_CHAT_FORMAT_APERTUS, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/common/common.cpp b/common/common.cpp index c1e736c44c..b0591e84b0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1133,6 +1133,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.use_extra_bufts = !params.no_extra_bufts; + mparams.no_host = params.no_host; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; diff --git a/common/common.h b/common/common.h index 40c6847f32..8a8ecd667f 100644 --- a/common/common.h +++ b/common/common.h @@ -392,6 +392,7 @@ struct common_params { bool check_tensors = false; // validate tensor data bool no_op_offload = false; // globally disable offload host tensor operations to device bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking) + bool no_host = false; // bypass host buffer allowing extra buffers to be used bool single_turn = false; // single turn chat conversation @@ -424,7 +425,7 @@ struct common_params { int32_t timeout_write = timeout_read; // http write timeout in seconds int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting - int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot + int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT diff --git a/common/http.h b/common/http.h new file mode 100644 index 
0000000000..8e29787dcc --- /dev/null +++ b/common/http.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +struct common_http_url { + std::string scheme; + std::string user; + std::string password; + std::string host; + std::string path; +}; + +static common_http_url common_http_parse_url(const std::string & url) { + common_http_url parts; + auto scheme_end = url.find("://"); + + if (scheme_end == std::string::npos) { + throw std::runtime_error("invalid URL: no scheme"); + } + parts.scheme = url.substr(0, scheme_end); + + if (parts.scheme != "http" && parts.scheme != "https") { + throw std::runtime_error("unsupported URL scheme: " + parts.scheme); + } + + auto rest = url.substr(scheme_end + 3); + auto at_pos = rest.find('@'); + + if (at_pos != std::string::npos) { + auto auth = rest.substr(0, at_pos); + auto colon_pos = auth.find(':'); + if (colon_pos != std::string::npos) { + parts.user = auth.substr(0, colon_pos); + parts.password = auth.substr(colon_pos + 1); + } else { + parts.user = auth; + } + rest = rest.substr(at_pos + 1); + } + + auto slash_pos = rest.find('/'); + + if (slash_pos != std::string::npos) { + parts.host = rest.substr(0, slash_pos); + parts.path = rest.substr(slash_pos); + } else { + parts.host = rest; + parts.path = "/"; + } + return parts; +} + +static std::pair common_http_client(const std::string & url) { + common_http_url parts = common_http_parse_url(url); + + if (parts.host.empty()) { + throw std::runtime_error("error: invalid URL format"); + } + + httplib::Client cli(parts.scheme + "://" + parts.host); + + if (!parts.user.empty()) { + cli.set_basic_auth(parts.user, parts.password); + } + + cli.set_follow_location(true); + + return { std::move(cli), std::move(parts) }; +} + +static std::string common_http_show_masked_url(const common_http_url & parts) { + return parts.scheme + "://" + (parts.user.empty() ? 
"" : "****:****@") + parts.host + parts.path; +} diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 411e36f8cf..a59ebfc0da 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -891,6 +891,9 @@ class TextModel(ModelBase): if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206": # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base res = "llada-moe" + if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e": + # ref: https://huggingface.co/ibm-granite/granite-docling-258M + res = "granite-docling" if res is None: logger.warning("\n") @@ -1325,6 +1328,7 @@ class MmprojModel(ModelBase): self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) # load preprocessor config + self.preprocessor_config = {} if not self.is_mistral_format: with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: self.preprocessor_config = json.load(f) @@ -1347,7 +1351,8 @@ class MmprojModel(ModelBase): self.gguf_writer.add_vision_projection_dim(self.n_embd_text) # vision config - self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"])) + self.image_size = self.find_vparam(["image_size"]) + self.gguf_writer.add_vision_image_size(self.image_size) self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) @@ -2378,6 +2383,10 @@ class SmolVLMModel(MmprojModel): self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2)) self.gguf_writer.add_vision_use_gelu(True) + # Add the preprocessor longest edge size + preproc_image_size = self.preprocessor_config.get("size", {}).get("longest_edge", self.image_size) + self.gguf_writer.add_vision_preproc_image_size(preproc_image_size) + def tensor_force_quant(self, name, new_name, bid, n_dims): if ".embeddings." 
in name: return gguf.GGMLQuantizationType.F32 @@ -4250,7 +4259,8 @@ class Plamo2Model(TextModel): # This logic matches modeling_plamo.py's is_mamba function mamba_step = hparams.get("mamba_step", 2) mamba_enabled = hparams.get("mamba_enabled", True) - mamba_layers = [] + num_key_value_heads = [] + num_attention_heads = [] if mamba_enabled: for i in range(block_count): @@ -4260,17 +4270,21 @@ class Plamo2Model(TextModel): else: is_mamba = (i % mamba_step) != (mamba_step // 2) if is_mamba: - mamba_layers.append(0) + num_key_value_heads.append(0) + num_attention_heads.append(0) else: - mamba_layers.append(hparams.get("num_key_value_heads", 4)) + num_key_value_heads.append(hparams.get("num_key_value_heads", 4)) + num_attention_heads.append(hparams.get("num_attention_heads", 32)) - if mamba_layers: - self.gguf_writer.add_head_count_kv(mamba_layers) + if num_key_value_heads and num_attention_heads: + self.gguf_writer.add_head_count_kv(num_key_value_heads) + self.gguf_writer.add_head_count(num_attention_heads) self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048)) self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096)) + self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128)) + self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128)) self.gguf_writer.add_block_count(block_count) - self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32)) self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06)) self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000)) @@ -8822,6 +8836,75 @@ class LFM2Model(TextModel): return [(self.map_tensor_name(name), data_torch)] +@ModelBase.register("Lfm2MoeForCausalLM") +class LFM2MoeModel(TextModel): + model_arch = gguf.MODEL_ARCH.LFM2MOE + + def set_gguf_parameters(self): + # set num_key_value_heads only for attention layers + self.hparams["num_key_value_heads"] = [ + self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0 + for layer_type in self.hparams["layer_types"] + ] + + super().set_gguf_parameters() + + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) + self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"]) + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"]) + + # cache for experts weights for merging + _experts_cache: dict[int, dict[str, Tensor]] = {} + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # conv op requires 2d tensor + if 'conv.conv' in name: + data_torch = data_torch.squeeze(1) + + if name.endswith(".expert_bias"): + name = name.replace(".expert_bias", ".expert_bias.bias") + + # merge expert weights + if 'experts' in name: + n_experts = self.hparams["num_experts"] + assert bid is not None + + expert_cache = self._experts_cache.setdefault(bid, {}) + expert_cache[name] = data_torch + expert_weights = ["w1", "w2", "w3"] + + # not enough expert weights to merge + if len(expert_cache) < n_experts * len(expert_weights): + return [] + + tensors: list[tuple[str, Tensor]] = [] + for w_name in expert_weights: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight" + 
datas.append(expert_cache[ename]) + del expert_cache[ename] + + data_torch = torch.stack(datas, dim=0) + merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight" + new_name = self.map_tensor_name(merged_name) + tensors.append((new_name, data_torch)) + + del self._experts_cache[bid] + return tensors + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + assert not self._experts_cache + + @ModelBase.register("Lfm2VlForConditionalGeneration") class LFM2VLModel(MmprojModel): def __init__(self, *args, **kwargs): @@ -8940,6 +9023,43 @@ class SmallThinkerModel(TextModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("ApertusForCausalLM") +class ApertusModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.APERTUS + undo_permute = False + + _alpha_n = {} + _alpha_p = {} + _beta = {} + _eps = {} + + def modify_tensors(self, data_torch, name, bid): + # Handle xIELU activation parameters + n_layers = self.hparams["num_hidden_layers"] + if name.endswith(".act_fn.alpha_n"): + self._alpha_n[bid] = data_torch.to("cpu").float().item() + if (len(self._alpha_n) == n_layers): + self.gguf_writer.add_xielu_alpha_n([self._alpha_n[k] for k in sorted(self._alpha_n)]) + return [] + if name.endswith(".act_fn.alpha_p"): + self._alpha_p[bid] = data_torch.to("cpu").float().item() + if (len(self._alpha_p) == n_layers): + self.gguf_writer.add_xielu_alpha_p([self._alpha_p[k] for k in sorted(self._alpha_p)]) + return [] + if name.endswith(".act_fn.beta"): + self._beta[bid] = data_torch.to("cpu").float().item() + if (len(self._beta) == n_layers): + self.gguf_writer.add_xielu_beta([self._beta[k] for k in sorted(self._beta)]) + return [] + if name.endswith(".act_fn.eps"): + self._eps[bid] = data_torch.to("cpu").float().item() + if (len(self._eps) == n_layers): + self.gguf_writer.add_xielu_eps([self._eps[k] for k in sorted(self._eps)]) + return [] + + return super().modify_tensors(data_torch, name, bid) + + class MistralModel(LlamaModel): model_arch = gguf.MODEL_ARCH.LLAMA model_name = "Mistral" @@ -9107,7 +9227,7 @@ class LazyTorchTensor(gguf.LazyBase): def from_safetensors_slice(cls, st_slice: Any) -> Tensor: dtype = cls._dtype_str_map[st_slice.get_dtype()] shape: tuple[int, ...] = tuple(st_slice.get_shape()) - lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) + lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] 
if len(s.get_shape()) == 0 else s[:]) return cast(torch.Tensor, lazy) @classmethod diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 21bb4a9f3e..28002f766e 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -140,6 +140,7 @@ models = [ {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", }, {"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", }, {"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", }, + {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 6e9b88935d..92ab27066b 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -145,12 +145,13 @@ The docker build option is currently limited to *Intel GPU* targets. ```sh # Using FP16 docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile . + +# Using FP32 +docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=OFF" --target light -f .devops/intel.Dockerfile . ``` *Notes*: -To build in default FP32 *(Slower than FP16 alternative)*, set `--build-arg="GGML_SYCL_F16=OFF"` in the previous command. - You can also use the `.devops/llama-server-intel.Dockerfile`, which builds the *"server"* alternative. Check the [documentation for Docker](../docker.md) to see the available images. @@ -160,7 +161,7 @@ Check the [documentation for Docker](../docker.md) to see the available images. # First, find all the DRI cards ls -la /dev/dri # Then, pick the card that you want to use (here for e.g. /dev/dri/card1). -docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-sycl -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 +docker run -it --rm -v "/path/to/models:/models" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card0:/dev/dri/card0 llama-cpp-sycl -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -c 4096 -s 0 ``` *Notes:* @@ -215,9 +216,19 @@ To target AMD GPUs with SYCL, the ROCm stack must be installed first. 2. **Install Intel® oneAPI Base toolkit** +SYCL backend depends on: + - Intel® oneAPI DPC++/C++ compiler/runtime. + - Intel® oneAPI DPC++/C++ library (oneDPL). + - Intel® oneAPI Deep Neural Network Library (oneDNN). + - Intel® oneAPI Math Kernel Library (oneMKL). + - **For Intel GPU** -The base toolkit can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page. +All of the above are included in both the **Intel® oneAPI Base toolkit** and the **Intel® Deep Learning Essentials** packages. + +It's recommended to install **Intel® Deep Learning Essentials**, which provides only the necessary libraries and has a smaller download size. + +The **Intel® oneAPI Base toolkit** and **Intel® Deep Learning Essentials** can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page.
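+After installing (per the instructions that follow), a quick way to confirm that the toolkit is usable is to load its environment and list the SYCL devices. A minimal sanity check, assuming the default `/opt/intel/oneapi` installation prefix, might look like:
+
+```sh
+# Load the oneAPI environment into the current shell
+source /opt/intel/oneapi/setvars.sh
+
+# An Intel GPU should show up as a [level_zero:gpu] entry
+sycl-ls
+```
+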
Please follow the instructions for downloading and installing the Toolkit for Linux, and preferably keep the default installation values unchanged, notably the installation path *(`/opt/intel/oneapi` by default)*. @@ -225,6 +236,12 @@ Following guidelines/code snippets assume the default installation values. Other Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI oneDNN for Intel GPUs. +|Verified release| +|-| +|2025.2.1| +|2025.1| +|2024.1| + - **Adding support to Nvidia GPUs** **oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup. @@ -255,10 +272,11 @@ sycl-ls When targeting an intel GPU, the user should expect one or more devices among the available SYCL devices. Please make sure that at least one GPU is present via `sycl-ls`, for instance `[level_zero:gpu]` in the sample output below: ``` -[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000] -[opencl:cpu][opencl:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000] -[opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50] -[level_zero:gpu][level_zero:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918] +[level_zero:gpu][level_zero:0] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) Arc(TM) A770 Graphics 12.55.8 [1.3.29735+27] +[level_zero:gpu][level_zero:1] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) UHD Graphics 730 12.2.0 [1.3.29735+27] +[opencl:cpu][opencl:0] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i5-13400 OpenCL 3.0 (Build 0) [2025.20.8.0.06_160000] +[opencl:gpu][opencl:1] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [24.39.31294] +[opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) UHD Graphics 730 OpenCL 3.0 NEO [24.39.31294] ``` - **Nvidia GPU** @@ -353,7 +371,7 @@ cmake --build build --config Release -j -v #### Retrieve and prepare model -You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). +You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). ##### Check device @@ -466,7 +484,17 @@ If you already have a recent version of Microsoft Visual Studio, you can skip th 3. Install Intelยฎ oneAPI Base toolkit -The base toolkit can be obtained from the official [Intelยฎ oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page. 
+SYCL backend depends on: + - Intel® oneAPI DPC++/C++ compiler/runtime. + - Intel® oneAPI DPC++/C++ library (oneDPL). + - Intel® oneAPI Deep Neural Network Library (oneDNN). + - Intel® oneAPI Math Kernel Library (oneMKL). + +All of the above are included in both the **Intel® oneAPI Base toolkit** and the **Intel® Deep Learning Essentials** packages. + +It's recommended to install **Intel® Deep Learning Essentials**, which provides only the necessary libraries and has a smaller download size. + +The **Intel® oneAPI Base toolkit** and **Intel® Deep Learning Essentials** can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page. Please follow the instructions for downloading and installing the Toolkit for Windows, and preferably keep the default installation values unchanged, notably the installation path *(`C:\Program Files (x86)\Intel\oneAPI` by default)*. diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 56420587a9..73032be68e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -209,7 +209,6 @@ option(GGML_HIP "ggml: use HIP" option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) -option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF) option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF) option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF) @@ -223,6 +222,9 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) option(GGML_WEBGPU "ggml: use WebGPU" OFF) option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) +option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF) +option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) + option(GGML_ZDNN "ggml: use zDNN" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 62b6d65e51..f1b7407859 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -215,6 +215,8 @@ extern "C" { // Backend registry // + GGML_API void ggml_backend_register(ggml_backend_reg_t reg); + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); // Backend (reg) enumeration diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 1e67411276..72eff00273 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -7,26 +7,25 @@ extern "C" { #endif -#define RPC_PROTO_MAJOR_VERSION 2 +#define RPC_PROTO_MAJOR_VERSION 3 #define RPC_PROTO_MINOR_VERSION 0 #define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 // backend API -GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint); +GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device); GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device); -GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint,
size_t * total); +GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); -GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, - const char * cache_dir, - size_t free_mem, size_t total_mem); +GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, + size_t n_threads, size_t n_devices, + ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); - -GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint); +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); #ifdef __cplusplus } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 34bd320458..60c6b63d05 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -576,6 +576,7 @@ extern "C" { GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, + GGML_UNARY_OP_XIELU, GGML_UNARY_OP_COUNT, }; @@ -1150,6 +1151,18 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // xIELU activation function + // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) + // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions + // that constrain the positive and negative source alpha values respectively + GGML_API struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps); + // gated linear unit ops // A: n columns, r rows, // result is n / 2 columns, r rows, diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index fa46f3b491..929bc44881 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -392,12 +392,8 @@ static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) { free(alloc); } -static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) { - size_t max_size = 0; - for (int i = 0; i < alloc->n_chunks; i++) { - max_size += alloc->chunks[i]->max_size; - } - return max_size; +static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc, int chunk) { + return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0; } @@ -417,10 +413,8 @@ static void ggml_vbuffer_free(struct vbuffer * buf) { free(buf); } -static int ggml_vbuffer_n_chunks(struct vbuffer * buf) { - int n = 0; - while (n < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[n]) n++; - return n; +static size_t ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) { + return buf->chunks[chunk] ? ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0; } static size_t ggml_vbuffer_size(struct vbuffer * buf) { @@ -885,12 +879,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } } - size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; - size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]); - // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views - if (new_size > cur_size || galloc->buffers[i] == NULL) { + bool realloc = galloc->buffers[i] == NULL; + size_t new_size = 0; + for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) { + size_t cur_chunk_size = galloc->buffers[i] ? 
ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0; + size_t new_chunk_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c); + new_size += new_chunk_size; + if (new_chunk_size > cur_chunk_size) { + realloc = true; + } + } + if (realloc) { #ifndef NDEBUG + size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 07784d6f66..6792ba986e 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -209,9 +209,6 @@ extern "C" { void * context; }; - // Internal backend registry API - GGML_API void ggml_backend_register(ggml_backend_reg_t reg); - // Add backend dynamic loading support to the backend // Initialize the backend diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 867e158dca..895a571375 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous is_contiguous_2d(op->src[1]) && // src1 must be contiguous op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && + op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { // src1 must be host buffer diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index dbc07301b2..eded6eb77e 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2187,6 +2187,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_XIELU: { n_tasks = n_threads; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 14f7dcf4f4..8e1a2de14f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } // V /= S - const float S_inv = 1.0f/S; + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); // dst indices @@ -8637,7 +8637,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = ggml_softplus(dt[h]); const float dA = expf(dt_soft_plus * A[h]); const int g = h / (nh / ng); // repeat_interleave @@ -8734,7 +8734,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? 
log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = ggml_softplus(dt[h]); const int g = h / (nh / ng); // repeat_interleave // dim @@ -8997,6 +8997,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_XIELU: + { + ggml_compute_forward_xielu(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 4fce569b3b..cf1a4615d0 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -52,6 +52,15 @@ static inline float op_sqrt(float x) { return sqrtf(x); } +static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) { + if (x > 0.0f) { + return alpha_p * x * x + beta * x; + } else { + const float min_x_eps = fminf(x, eps); + return (expm1f(min_x_eps) - x) * alpha_n + beta * x; + } +} + static inline float op_sin(float x) { return sinf(x); } @@ -121,6 +130,86 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) { } } +template +static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op(params, dst); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + +// Extend vec_unary_op to support functors +template +static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) { + constexpr auto src0_to_f32 = type_conversion_table::to_f32; + constexpr auto f32_to_dst = type_conversion_table::from_f32; + + for (int i = 0; i < n; i++) { + y[i] = f32_to_dst(op(src0_to_f32(x[i]))); + } +} + +// Extend apply_unary_op to support functors +template +static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(dst_t)); + GGML_ASSERT(nb00 == sizeof(src0_t)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op); + } +} + +// Generic dispatcher for functors +template +static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { + const ggml_tensor * src0 = dst->src[0]; + + /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == 
GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op_functor(params, dst, op); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } @@ -184,3 +273,17 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } + +void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) { + return op_xielu(f, alpha_n, alpha_p, beta, eps); + }; + + unary_op_functor(params, dst, xielu_op_params); +} + diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h index b1ade2c8e3..697c1e0da0 100644 --- a/ggml/src/ggml-cpu/unary-ops.h +++ b/ggml/src/ggml-cpu/unary-ops.h @@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 341e64e64f..f95ca94e54 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -654,11 +654,11 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { } // leftovers // maximum number of leftover elements will be less that ggml_f32_epr. 
Apply predicated svmad on available elements only - if (np < n) { - svbool_t pg = svwhilelt_b32(np, n); - ay1 = svld1_f32(pg, y + np); + for (int i = np; i < n; i += ggml_f32_epr) { + svbool_t pg = svwhilelt_b32(i, n); + ay1 = svld1_f32(pg, y + i); ay1 = svmul_f32_m(pg, ay1, vx); - svst1_f32(pg, y + np, ay1); + svst1_f32(pg, y + i, ay1); } #elif defined(__riscv_v_intrinsic) for (int i = 0, avl; i < n; i += avl) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index c4246b65eb..d51abbeafa 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -220,14 +220,6 @@ static const char * cu_get_error_str(CUresult err) { #define FAST_FP16_AVAILABLE #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) -#define FP16_MMA_AVAILABLE -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) - -#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) -#define FP16_MMA_AVAILABLE -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) - #if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) @@ -262,27 +254,6 @@ static bool fast_fp16_hardware_available(const int cc) { (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); } -// Any FP16 tensor core instructions are available for ggml code. -static bool fp16_mma_available(const int cc) { -#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) - return false; -#else - if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || - GGML_CUDA_CC_IS_MTHREADS(cc)) { - return true; - } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { -#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) - return true; -#else - return false; -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) - } else { - return false; - } -#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) -} - // To be used for feature selection of external libraries, e.g. cuBLAS. 
static bool fp16_mma_hardware_available(const int cc) { return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 131a5099a3..68de623d80 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -1,6 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" #include "fattn-tile.cuh" +#include "fattn-wmma-f16.cuh" // kq_stride == number of KQ rows to process per iteration // kq_nbatch == number of K columns to load in parallel for KQ calculation @@ -190,10 +191,10 @@ static __global__ void flash_attn_tile( #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE +#ifdef GGML_USE_WMMA_FATTN NO_DEVICE_CODE; return; -#endif // FP16_MMA_AVAILABLE +#endif // GGML_USE_WMMA_FATTN if (use_logit_softcap && !(D == 128 || D == 256)) { GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 59c62553b0..89ab0f1638 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -535,8 +535,6 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; if (logit_softcap == 0.0f) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 2219191fd9..6c90d6d52b 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -6,19 +6,19 @@ #include "fattn-common.cuh" #include "fattn-wmma-f16.cuh" -#ifdef FP16_MMA_AVAILABLE +#ifdef GGML_USE_WMMA_FATTN #if !defined(GGML_USE_HIP) #include -#ifdef GGML_USE_MUSA +#if defined(GGML_USE_MUSA) namespace wmma = mtmusa::wmma; #else // GGML_USE_MUSA namespace wmma = nvcuda::wmma; #endif // GGML_USE_MUSA -#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) +#elif defined(GGML_USE_HIP) #include namespace wmma = rocwmma; #endif // !defined(GGML_USE_HIP) -#endif // FP16_MMA_AVAILABLE +#endif // GGML_USE_WMMA_FATTN // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template @@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) +#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) } constexpr int get_max_power_of_2(int x) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh 
b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index beeea95eb1..1848d08836 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,3 +1,49 @@ #include "common.cuh" +#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#define GGML_USE_WMMA_FATTN +#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) + +#if defined(GGML_HIP_ROCWMMA_FATTN) +#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#define GGML_USE_WMMA_FATTN +#elif defined(CDNA) +#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance" +#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#if defined(RDNA3) +#define GGML_USE_WMMA_FATTN +#endif // defined(RDNA3) +#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#define GGML_USE_WMMA_FATTN +#elif defined(RDNA4) +#warning "rocwmma fattn is not supported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" +#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + +// WMMA flash attention requires FP16 matrix instructions to be available for ggml code. +static bool ggml_cuda_should_use_wmma_fattn(const int cc) { +#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) + return false; +#else + if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) { + return true; + } else if (GGML_CUDA_CC_IS_CDNA(cc)){ +#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { +#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + } else { + return false; + } +#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) +} + void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1cbd4f5bd6..0c8e7b3e41 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -208,6 +208,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const int cc = ggml_cuda_info().devices[device].cc; + // TODO: temporary until support is extended + // https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206 + if (K->ne[1] % FATTN_KQ_STRIDE != 0) { + return BEST_FATTN_KERNEL_NONE; + } + switch (K->ne[0]) { case 64: case 128: @@ -222,7 +228,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; } - if (!fp16_mma_available(cc) && !turing_mma_available(cc)) { + if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) { return BEST_FATTN_KERNEL_NONE; } break; @@ -300,7 +306,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // For large batch sizes, use the WMMA kernel if possible: - if (fp16_mma_available(cc)) { + if (ggml_cuda_should_use_wmma_fattn(cc)) { return BEST_FATTN_KERNEL_WMMA_F16; }
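With the `GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12` option removed, whether the WMMA FlashAttention path is compiled in is now decided by the target architecture and the rocWMMA version found in the ROCm installation, via the gating in fattn-wmma-f16.cuh above. For reference, a minimal HIP configure/build that opts into this path could look like the sketch below; only the two `GGML_HIP*` options are taken from this change, the remaining flags and the usual HIP toolchain setup are illustrative assumptions:

```sh
# Configure a HIP build with rocWMMA FlashAttention enabled; per-arch and
# per-rocWMMA-version gating is handled at compile time by fattn-wmma-f16.cuh.
cmake -S . -B build -DGGML_HIP=ON -DGGML_HIP_ROCWMMA_FATTN=ON -DCMAKE_BUILD_TYPE=Release
cmake --build build --config Release -j
```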
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b7e81b21bc..26e72bbc2b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2334,6 +2334,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_ELU: ggml_cuda_op_elu(ctx, dst); break; + case GGML_UNARY_OP_XIELU: + ggml_cuda_op_xielu(ctx, dst); + break; default: return false; } diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 039f284719..afe4aee240 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -13,7 +13,7 @@ It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models */ -template +template __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, float * weights, int32_t * ids, @@ -204,8 +204,6 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); - cudaStream_t stream = ctx.stream(); - const int n_expert_used = weights->ne[1]; if (with_norm) { diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 5aff8a876a..3c564566a5 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -1,4 +1,5 @@ #include "unary.cuh" +#include "convert.cuh" static __device__ __forceinline__ float op_abs(float x) { return fabsf(x); @@ -375,6 +376,59 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); } +/* CUDA kernel + launcher for xIELU */ + +template +static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xi = ggml_cuda_cast(x[i]); + + const float gate_pos = (xi > 0.0f); + const float y_pos = alpha_p * xi * xi + beta * xi; + const float min_v_eps = fminf(xi, eps); + const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi; + const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg; + + dst[i] = ggml_cuda_cast(out); +} + +template +static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) { + const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE; + xielu_kernel<<>>(x, dst, k, alpha_n, alpha_p, beta, eps); +} + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + if (src0->type == GGML_TYPE_F16) { + xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } else { + xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } +} + + + /* 
silu_back */ static __device__ __forceinline__ float op_silu_back(float grad, float x) { diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index da3caf1d89..8e7644fcd9 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -16,6 +16,7 @@ #define CUDA_SIN_BLOCK_SIZE 256 #define CUDA_COS_BLOCK_SIZE 256 #define CUDA_GLU_BLOCK_SIZE 256 +#define CUDA_XIELU_BLOCK_SIZE 256 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -72,3 +73,5 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 37386afcd4..890c103649 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -6,6 +6,10 @@ #include #include +#if defined(GGML_HIP_ROCWMMA_FATTN) +#include +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index d327b90cce..0e2b1847e0 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -39,12 +39,6 @@ endif() find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) -if (GGML_HIP_ROCWMMA_FATTN) - CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA) - if (NOT ${FOUND_ROCWMMA}) - message(FATAL_ERROR "rocwmma has not been found") - endif() -endif() if (${hip_VERSION} VERSION_LESS 6.1) message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") @@ -117,10 +111,6 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() -if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0) - add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12) -endif() - if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 86a1ebf62b..d0fb3bccad 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -102,6 +102,9 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } +static inline float ggml_softplus(float input) { + return (input > 20.0f) ? 
input : logf(1 + expf(input)); +} // // logging // diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index dc7d241c3a..95627d3866 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -112,7 +112,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * t } bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { - for (int i = 0; i < GGML_MAX_DIMS; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i]) { ggml_mem_ranges_add_src(mrs, tensor->src[i]); } @@ -173,7 +173,7 @@ static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * } bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { - for (int i = 0; i < GGML_MAX_DIMS; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i]) { if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) { return false; diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 819f31c8a3..46cc513459 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar char base[256]; char name[256]; - snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + const char * suffix = ""; + + if (op->src[1]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); snprintf(name, 256, "%s", base); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); @@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar } ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + char base[256]; char name[256]; - if (op->src[3]->ne[0] == 1) { - snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type)); - } else { - snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); - } - snprintf(name, 256, "%s", base); + const int nsg = (ne00 + 31)/32; + + snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d", base, nsg); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); if (res) { @@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg); return res; } @@ -918,6 +924,50 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_pad"); + + snprintf(name, 256, "%s_mask=%d_ncpsg=%d", + base, + has_mask, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = 
ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const ggml_tensor * op, @@ -925,6 +975,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -937,18 +988,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + // do bounds checks for the mask? + const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0); + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", "flash_attn_ext", ggml_type_name(op->src[1]->type), dk, dv); - snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d", + snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d", base, has_mask, has_sinks, has_bias, has_scap, + has_kvpad, + bc_mask, ns10, ns20, nsg); @@ -964,6 +1020,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); + + ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); @@ -983,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg, int32_t nwg) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -1002,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( dk, dv); - snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", + snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", base, has_mask, has_sinks, has_bias, has_scap, + has_kvpad, ns10, ns20, nsg, nwg); @@ -1023,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index f6ebf90a00..ef04950738 100644 --- 
a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -135,6 +135,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg); + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const struct ggml_tensor * op, @@ -142,6 +148,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( @@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg, int32_t nwg); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 523f9d71ba..9527973015 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; } case GGML_OP_GET_ROWS: - { - return op->ne[3] == 1; - } + return true; case GGML_OP_SET_ROWS: { if (op->src[0]->type != GGML_TYPE_F32) { diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 88c98423eb..1524b3ab51 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -69,11 +69,12 @@ #define N_SG_IQ4_XS 2 // function constants offsets -#define FC_FLASH_ATTN_EXT 100 -#define FC_FLASH_ATTN_EXT_VEC 200 -#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300 -#define FC_MUL_MV 400 -#define FC_MUL_MM 500 +#define FC_FLASH_ATTN_EXT_PAD 100 +#define FC_FLASH_ATTN_EXT 200 +#define FC_FLASH_ATTN_EXT_VEC 300 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400 +#define FC_MUL_MV 500 +#define FC_MUL_MM 600 // kernel argument structs // @@ -178,6 +179,7 @@ typedef struct { } ggml_metal_kargs_clamp; typedef struct { + int64_t nk0; int64_t ne00; int64_t ne01; int64_t ne02; @@ -243,6 +245,24 @@ typedef struct { int32_t sect_3; } ggml_metal_kargs_rope; +typedef struct { + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_pad; + typedef struct { int32_t ne01; int32_t ne02; @@ -261,6 +281,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -295,6 +316,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -572,32 +594,45 @@ typedef struct { int64_t n_seq_tokens; int64_t n_seqs; uint64_t s_off; + uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + uint64_t nb10; uint64_t nb11; uint64_t nb12; + uint64_t ns12; uint64_t nb13; + uint64_t nb20; uint64_t nb21; + uint64_t ns21; uint64_t nb22; + int64_t ne30; uint64_t nb31; uint64_t nb41; uint64_t nb42; + uint64_t 
ns42; uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t ns52; uint64_t nb53; + uint64_t nb0; } ggml_metal_kargs_ssm_scan; typedef struct { - int64_t ne00; + int32_t ne00t; + int32_t ne00; uint64_t nb01; uint64_t nb02; - int64_t ne10; + uint64_t nb03; + int32_t ne10; uint64_t nb10; uint64_t nb11; + uint64_t nb12; uint64_t nb1; uint64_t nb2; + uint64_t nb3; } ggml_metal_kargs_get_rows; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e85a223c01..125cc64dc5 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb); GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne); GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb); + GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb); + GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb); GGML_TENSOR_LOCALS( int64_t, ne, node, ne); GGML_TENSOR_LOCALS(uint64_t, nb, node, nb); @@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ggml_is_contiguous(node->src[1]), node->src[1]->name); } + if (node->src[2]) { + GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23, + ggml_is_contiguous(node->src[2]), node->src[2]->name); + } + if (node->src[3]) { + GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33, + ggml_is_contiguous(node->src[3]), node->src[3]->name); + } if (node) { GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, node->name); @@ -577,6 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, @@ -906,23 +919,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); ggml_metal_kargs_get_rows args = { - /*.ne00 =*/ ne00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, + /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? 
ne00/16 : ne00, + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, }; + const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const int nw0 = (args.ne00t + nth - 1)/nth; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1); return 1; } @@ -1117,7 +1138,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); @@ -1172,25 +1193,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float), + /*.nb00 =*/ nb00, /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, /*.nb11 =*/ nb11, /*.nb12 =*/ nb12, + /*.ns12 =*/ nb12/nb10, /*.nb13 =*/ nb13, + /*.nb20 =*/ nb20, /*.nb21 =*/ nb21, + /*.ns21 =*/ nb21/nb20, /*.nb22 =*/ nb22, + /*.ne30 =*/ ne30, /*.nb31 =*/ nb31, /*.nb41 =*/ nb41, /*.nb42 =*/ nb42, + /*.ns42 =*/ nb42/nb40, /*.nb43 =*/ nb43, /*.nb51 =*/ nb51, /*.nb52 =*/ nb52, + /*.ns52 =*/ nb52/nb50, /*.nb53 =*/ nb53, + /*.nb0 =*/ nb0, }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const size_t sms = ggml_metal_pipeline_get_smem(pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline); @@ -1206,13 +1238,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); - if (ne30 == 1) { - // Mamba-2 - ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); - } else { - GGML_ASSERT(d_inner == 1); - ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1); - } + ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); return 1; } @@ -1273,26 +1299,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); - // TODO: support - //const int32_t nk00 = ne00/ggml_blck_size(op->type); - const int32_t nk00 = ne00; - - int nth = 32; // SIMD width - - while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { - nth *= 2; + int64_t nk0 = ne00; + if (ggml_is_quantized(op->src[0]->type)) { + nk0 = ne00/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne00/ggml_blck_size(op->type); } - nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); // when rows 
are small, we can batch them together in a single threadgroup int nrptg = 1; // TODO: relax this constraint in the future if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) { - if (nth > nk00) { - nrptg = (nth + nk00 - 1)/nk00; - nth = nk00; + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nrptg--; @@ -1300,10 +1323,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { } } - nth = std::min(nth, nk00); + nth = std::min(nth, nk0); ggml_metal_kargs_cpy args = { - /*.ne00 =*/ nk00, + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, /*.ne03 =*/ ne03, @@ -1321,12 +1345,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3, }; + const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); return 1; } @@ -1875,20 +1901,69 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) { return (ne01 < 20) && (ne00 % 32 == 0); } +size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const bool has_kvpad = ne11 % 32 != 0; + + if (has_kvpad) { + res += 32*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } else { + const bool has_kvpad = ne11 % 64 != 0; + + if (has_kvpad) { + res += 64*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? 
ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } + + return res; +} + size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); - const int64_t nwg = 32; + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); - const int64_t ne01 = op->src[0]->ne[1]; - const int64_t ne02 = op->src[0]->ne[2]; - const int64_t ne03 = op->src[0]->ne[3]; - const int64_t ne20 = op->src[2]->ne[0]; + size_t res = 0; - // temp buffer for writing the results from each workgroup - // - ne20: the size of the Value head - // - + 2: the S and M values for each intermediate result - return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const int64_t nwg = 32; + + // temp buffer for writing the results from each workgroup + // - ne20: the size of the Value head + // - + 2: the S and M values for each intermediate result + res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + } + + return res; } int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { @@ -1910,8 +1985,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, nb, op, nb); - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == op->src[2]->type); @@ -1921,8 +1995,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne12 == ne22); GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16); - GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] && + "the Flash-Attention Metal kernel requires the mask to be at least n_queries big"); float scale; float max_bias; @@ -1949,6 +2023,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne01 < 65536); + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0; + ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0; + + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_pad = bid_dst; + bid_pad.offs += ggml_nbytes(op); + + ggml_metal_buffer_id bid_tmp = bid_pad; + bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! 
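Reviewer note (not part of the patch): the two helpers above over-allocate the FLASH_ATTN_EXT destination buffer so the kernels can read a zero-padded KV tail (when ne11 is not a multiple of the chunk size, 32 for the vec kernel and 64 for the half8x8 kernel) and, for the multi-workgroup vec path, a per-workgroup partial-result scratch area. The standalone C++ sketch below only illustrates that size arithmetic and the bid_pad/bid_tmp offset layout; every name and number in it is a hypothetical stand-in, not the ggml API and not code from this patch.

    // minimal sketch, assuming the layout described above: dst bytes, then pad, then tmp
    #include <cstdint>
    #include <cstdio>

    // chunk = 32 for the vec kernel, 64 for the half8x8 kernel (assumption mirroring the diff)
    static size_t extra_pad(int64_t ne11, int64_t chunk,
                            size_t k_row_bytes, int64_t n_k_rows,   // stand-ins for nb11, ne12*ne13
                            size_t v_row_bytes, int64_t n_v_rows,   // stand-ins for nb21, ne22*ne23
                            size_t mask_bytes) {                    // 0 when there is no mask
        if (ne11 % chunk == 0) {
            return 0; // KV length already aligned -> no pad buffer needed
        }
        return chunk*(k_row_bytes*n_k_rows + v_row_bytes*n_v_rows + mask_bytes);
    }

    int main() {
        const size_t dst_bytes = 4096; // stand-in for ggml_nbytes(op)
        const size_t pad_bytes = extra_pad(/*ne11=*/100, /*chunk=*/32,
                                           /*k_row_bytes=*/256, /*n_k_rows=*/8,
                                           /*v_row_bytes=*/256, /*n_v_rows=*/8,
                                           /*mask_bytes=*/0);

        // offsets mirror bid_pad/bid_tmp in the op: pad sits right after dst, tmp right after pad
        const size_t off_pad = dst_bytes;
        const size_t off_tmp = dst_bytes + pad_bytes;
        printf("pad bytes: %zu, pad offset: %zu, tmp offset: %zu\n", pad_bytes, off_pad, off_tmp);
        return 0;
    }

Under these assumptions the allocator request in ggml_backend_metal_buffer_type_get_alloc_size (patched further below) is simply dst + extra_pad + extra_tmp, which is why both helpers return 0 when no padding or no temp buffer is needed.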
@@ -1958,6 +2046,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; // 2*(2*ncpsg) @@ -2007,6 +2137,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, /*.ne32 =*/ ne32, /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, @@ -2023,24 +2154,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg); + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); + ggml_metal_encoder_set_buffer (enc, bid_pad, 6); + ggml_metal_encoder_set_buffer (enc, bid_dst, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -2056,6 +2180,48 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(nqptg % 1 == 0); GGML_ASSERT(ncpsg % 32 == 0); + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + 
ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask @@ -2120,6 +2286,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, /*.ne32 =*/ ne32, /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, @@ -2136,25 +2303,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg); + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); const size_t smem = FATTN_SMEM(nsg); @@ -2162,23 +2321,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); if (nwg == 1) { + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0); + // using 1 workgroup -> write the result directly into dst - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6); + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_dst, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, 
ne03*nwg, 32, nsg, 1); } else { // sanity checks + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); + GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31)); - ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); - // write the results from each workgroup into a temp buffer - ggml_metal_buffer_id bid_tmp = bid_dst; - bid_tmp.offs += ggml_nbytes(op); - ggml_metal_encoder_set_buffer(enc, bid_tmp, 6); + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 8df4c72e7c..6a6d8a7977 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -39,6 +39,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op); // return true if we should use the FA vector kernel for this op bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op); size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index e11555a78f..e53f37b29c 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -193,9 +193,8 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ } break; case GGML_OP_FLASH_ATTN_EXT: { - if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) { - res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); - } + res += ggml_metal_op_flash_attn_ext_extra_pad(tensor); + res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); } break; default: break; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 96df6f0ce6..c52c6b48ad 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2032,7 +2032,38 @@ kernel void kernel_ssm_conv_f32_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part +kernel void kernel_ssm_conv_f32_f32_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; + + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -2044,219 +2075,88 @@ kernel void 
kernel_ssm_scan_f32( device const void * src6, device float * dst, threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + constexpr short NW = N_SIMDWIDTH; - const int64_t i0 = tpitg.x; - const int64_t i1 = 0; - const int64_t ir = tgpig.x; // current head - const int64_t i3 = tgpig.y; // current seq + shared[tpitg.x] = 0.0f; - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); + const int32_t i0 = tpitg.x; + const int32_t i1 = tgpig.x; + const int32_t ir = tgpig.y; // current head + const int32_t i3 = tgpig.z; // current seq - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; + const int32_t nc = args.d_state; + const int32_t nr = args.d_inner; + const int32_t nh = args.n_head; + const int32_t ng = args.n_group; + const int32_t n_t = args.n_seq_tokens; - const int64_t s_off = args.s_off; + const int32_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave + + const int32_t i = i0 + i1*nc; + const int32_t g = ir / (nh / ng); // repeat_interleave + float s0 = s0_buff[i]; - float s = s_buff[i]; + float s = 0.0f; - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh} - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} + const float A0 = A[i0%args.ne30]; - const 
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; + device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns} + device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns} - const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - s = state; + device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns} - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - if (sgptg > 1) { - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. 
- threadgroup_barrier(mem_flags::mem_threadgroup); - - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); - if (tiisg == 0) { - y[0] = sumf; - } - } - } else if (tiisg == 0) { - y[0] = sumf; - } - - // recurse - s0 = s; - } - - // Assign the final state to the output buffer - s_buff[i] = s; -} - -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -kernel void kernel_ssm_scan_group_f32( - constant ggml_metal_kargs_ssm_scan & args, - device const void * src0, - device const void * src1, - device const void * src2, - device const void * src3, - device const void * src4, - device const void * src5, - device const void * src6, - device float * dst, - threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { - - const int64_t i0 = tpitg.x; - const int64_t i1 = tgpig.x; - const int64_t ir = tgpig.y; // current head - const int64_t i3 = tgpig.z; // current seq - - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); - - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; - - const int64_t s_off = args.s_off; - - device const int32_t * ids = (device const int32_t *) src6; - - device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave - float s0 = s0_buff[i]; - float s = s_buff[i]; - - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); - - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} - - const float dt_soft_plus = dt[0] <= 20.0f ? 
log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; - const float dA = exp(dt_soft_plus * A[0]); - - const float state = (s0 * dA) + (B[i0] * x_dt); - s = state; - - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. + for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float dt0 = dt[0]; + const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; + const float x_dt = x[0] * dtsp; + const float dA = exp(dtsp * A0); + + s = (s0 * dA) + (B[i0] * x_dt); + + const float sumf = simd_sum(s * C[i0]); + if (tiisg == 0) { - y[0] = sumf; + shared[t*NW + sgitg] = sumf; } + + // recurse + s0 = s; + + x += args.ns12; + dt += args.ns21; + B += args.ns42; + C += args.ns52; } - // recurse - s0 = s; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + + if (tiisg == 0 && i2 + sgitg < n_t) { + y[sgitg*nh*nr] = sumf; + } + + y += sgptg*nh*nr; } - // Assign the final state to the output buffer s_buff[i] = s; } @@ -4449,10 +4349,83 @@ kernel void kernel_leaky_relu_f32_4( dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); } +constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; + +constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]]; + +// pad the last chunk of C elements of k and v into a an extra pad buffer +kernel void kernel_flash_attn_ext_pad( + constant ggml_metal_kargs_flash_attn_ext_pad & args, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int32_t C = FC_flash_attn_ext_pad_ncpsg; + + device char * k_pad = dst; + device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3; + device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const int32_t icp = args.ne11 % C; + const int32_t ic0 = args.ne11 - icp; + + const int32_t i1 = tgpig[0]; + const int32_t i2 = tgpig[1]; + const int32_t i3 = tgpig[2]; + + if (i2 < args.ne_12_2 && i3 < args.ne_12_3) { + device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3; + device const char * v_src = v + args.nb21*(ic0 
+ i1) + args.nb22*i2 + args.nb23*i3; + + device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3; + device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3; + + if (i1 >= icp) { + // here it is not important the exact value that will be used as we rely on masking out the scores in the attention + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = 0; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = 0; + } + } else { + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = k_src[i]; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = v_src[i]; + } + } + } + + if (FC_flash_attn_ext_pad_has_mask) { + if (i2 < args.ne32 && i3 < args.ne33) { + for (int ib = i1; ib < args.ne31; ib += C) { + device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0; + device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3; + + for (int i = tiitg; i < C; i += ntg.x) { + if (i >= icp) { + mask_dst[i] = -MAXHALF; + } else { + mask_dst[i] = mask_src[i]; + } + } + } + } + } +} + constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; +constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]]; + +constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; @@ -4499,6 +4472,7 @@ void kernel_flash_attn_ext_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16, uint3 tgpig, @@ -4623,13 +4597,58 @@ void kernel_flash_attn_ext_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = 0; ic < args.ne11; ic += C) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C) { + int ic = ic0; + + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_has_mask) { + threadgroup half * sm = (threadgroup half *) (sm2); + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < C; i += NW) { + if (ic + i >= args.ne11) { + sm[2*j*SH + i] = -MAXHALF; + } + } + } + } else { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const half *) mask + + (iq1 + j)*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32)); + } + } + + ic = 0; + } + // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { const short j = jj*NSG + sgitg; - sm2[j*SH + tiisg] = 
pm2[jj][tiisg]; + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } + pm2[jj] += NW; } @@ -4657,7 +4676,7 @@ void kernel_flash_attn_ext_impl( // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + device const k_t * pk = (device const k_t *) (k + ic*args.nb11); threadgroup const q_t * pq = sq; threadgroup s_t * ps = ss; @@ -4729,7 +4748,7 @@ void kernel_flash_attn_ext_impl( qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); if (DK16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -4851,7 +4870,7 @@ void kernel_flash_attn_ext_impl( { auto sst = ss; - device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + device const v_t * pv = (device const v_t *) (v + ic*args.nb21); pv += 8*sgitg; @@ -4893,7 +4912,7 @@ void kernel_flash_attn_ext_impl( simdgroup_load(vs, ss + 8*cc, SH, 0, false); for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -5037,13 +5056,14 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_nsg) { // note: disabled cases to reduce library load time //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; @@ -5163,6 +5183,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_ constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; +constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]]; //constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; @@ -5200,6 +5221,7 @@ void kernel_flash_attn_ext_vec_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ 
-5306,11 +5328,37 @@ void kernel_flash_attn_ext_vec_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { - const int ic = ic0 + C*sgitg; + int ic = ic0 + C*sgitg; if (ic >= args.ne11) { break; } + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_vec_has_mask) { + if (ic + tiisg >= args.ne11) { + sm[tiisg] = -MAXHALF; + } + } else { + pm = (device const half *) (mask) + + iq1*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32); + } + + ic = 0; + } + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -5322,7 +5370,7 @@ void kernel_flash_attn_ext_vec_impl( // Q*K^T { - device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); threadgroup const q4_t * pq4 = sq4; pk4 += ty*NS10/4 + tx; @@ -5337,7 +5385,7 @@ void kernel_flash_attn_ext_vec_impl( mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); } } else { - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); + device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; @@ -5435,7 +5483,7 @@ void kernel_flash_attn_ext_vec_impl( } if (is_same::value) { - device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); + device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); pv4 += ty*NS20/4 + tx; @@ -5448,7 +5496,7 @@ void kernel_flash_attn_ext_vec_impl( } } else { FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21)); FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { const short i = ii*NL + tx; @@ -5620,13 +5668,14 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_vec_nsg) { // note: disabled cases to reduce library load time case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; @@ -5770,21 +5819,17 @@ kernel void kernel_flash_attn_ext_vec_reduce( } template -kernel void kernel_cpy( +kernel void kernel_cpy_t_t( constant ggml_metal_kargs_cpy & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 
tptg[[threads_per_threadgroup]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x; - - if (i01 >= args.ne01) { - return; - } + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -5795,190 +5840,70 @@ kernel void kernel_cpy( device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; + break; } } -typedef decltype(kernel_cpy) kernel_cpy_t; +typedef decltype(kernel_cpy_t_t) kernel_cpy_t; -template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -// TODO: templetify these kernels -kernel void kernel_cpy_f32_q8_0( +template +kernel void kernel_cpy_f32_q( constant ggml_metal_kargs_cpy & args, device const char * src0, - device char * dst, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? 
tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; - device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); - quantize_q8_0(src, dst_data[i00/QK8_0]); + quantize_func(src, dst_data[i00]); + + break; } } -kernel void kernel_cpy_f32_q4_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; +typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_0(src, dst_data[i00/QK4_0]); - } -} - -kernel void kernel_cpy_f32_q4_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - 
quantize_q4_1(src, dst_data[i00/QK4_1]); - } -} - -kernel void kernel_cpy_f32_q5_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_0(src, dst_data[i00/QK5_0]); - } -} - -kernel void kernel_cpy_f32_q5_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_1(src, dst_data[i00/QK5_1]); - } -} - -kernel void kernel_cpy_f32_iq4_nl( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_iq4_nl(src, dst_data[i00/QK4_NL]); - } -} +template [[host_name("kernel_cpy_f32_q8_0")]] kernel 
cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q; template kernel void kernel_cpy_q_f32( @@ -5986,11 +5911,12 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -6002,10 +5928,12 @@ kernel void kernel_cpy_q_f32( device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; + + break; } } @@ -7765,66 +7693,60 @@ kernel void kernel_mul_mv_mxfp4_f32( template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { + auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; + dequantize_func(psrc + ind/nl, ind%nl, temp); + pdst[ind] = temp; + + break; } } -template +template kernel void kernel_get_rows_f( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 
tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; - } -} + auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); -kernel void kernel_get_rows_i32( - constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device int32_t * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { + pdst[ind] = psrc[ind]; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + break; } } @@ -8310,12 +8232,13 @@ kernel void kernel_mul_mm_id( // get rows // -typedef decltype(kernel_get_rows_f) get_rows_f_t; +typedef decltype(kernel_get_rows_f) get_rows_f_t; -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; #endif typedef decltype(kernel_get_rows_q) get_rows_q_t; diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index cdb3818c78..f8477a2ef3 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -56,7 +56,7 @@ if (MUSAToolkit_FOUND) set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) foreach(SOURCE ${GGML_SOURCES_MUSA}) - set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu") + set(COMPILE_FLAGS "-Od3 -fno-strict-aliasing -ffast-math -fsigned-char -x musa -mtgpu -fmusa-flush-denormals-to-zero") foreach(ARCH 
${MUSA_ARCHITECTURES}) set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}") endforeach() diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a9405ab012..79d2148744 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_REPEAT: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded case GGML_OP_PAD: - return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - op->src[0]->ne[3] == 1 && op->ne[3] == 1 && - (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && - (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -5881,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t GGML_ASSERT(dst->extra); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -5899,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t const int s_ne0 = src0->ne[0]; const int s_ne1 = src0->ne[1]; const int s_ne2 = src0->ne[2]; + const int s_ne3 = src0->ne[3]; + + const int s_nb0 = src0->nb[0]; + const int s_nb1 = src0->nb[1]; + const int s_nb2 = src0->nb[2]; + const int s_nb3 = src0->nb[3]; const int d_ne0 = dst->ne[0]; const int d_ne1 = dst->ne[1]; const int d_ne2 = dst->ne[2]; + const int d_ne3 = dst->ne[3]; + + const int d_nb0 = dst->nb[0]; + const int d_nb1 = dst->nb[1]; + const int d_nb2 = dst->nb[2]; + const int d_nb3 = dst->nb[3]; + + const int lp0 = ((const int*)(dst->op_params))[0]; + const int rp0 = ((const int*)(dst->op_params))[1]; + const int lp1 = ((const int*)(dst->op_params))[2]; + const int rp1 = ((const int*)(dst->op_params))[3]; + const int lp2 = ((const int*)(dst->op_params))[4]; + const int rp2 = ((const int*)(dst->op_params))[5]; + const int lp3 = ((const int*)(dst->op_params))[6]; + const int rp3 = ((const int*)(dst->op_params))[7]; cl_kernel kernel = backend_ctx->kernel_pad; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); + 
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3)); size_t lws0 = 64; size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0; - size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 }; + size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 }; size_t local_work_size[] = { lws0, 1, 1 }; size_t * local_work_size_ptr = local_work_size; diff --git a/ggml/src/ggml-opencl/kernels/pad.cl b/ggml/src/ggml-opencl/kernels/pad.cl index 747fa7febc..31fb7ccd3b 100644 --- a/ggml/src/ggml-opencl/kernels/pad.cl +++ b/ggml/src/ggml-opencl/kernels/pad.cl @@ -1,30 +1,39 @@ kernel void kernel_pad( - global const void * src0_ptr, - ulong src0_offset, - global void * dst_ptr, - ulong dst_offset, - int s_ne0, int s_ne1, int s_ne2, - int d_ne0, int d_ne1, int d_ne2 + global void * src0, + ulong offset0, + global void * dst, + ulong offsetd, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne0, int ne1, int ne2, int ne3, + ulong nb0, ulong nb1, ulong nb2, ulong nb3, + int lp0, int rp0, + int lp1, int rp1, + int lp2, int rp2, + int lp3, int rp3 ) { - global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset); - global float * dst = (global float *)((global char *)dst_ptr + dst_offset); + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); - int nidx = get_global_id(0); - int idx_d1 = get_group_id(1); - int idx_d2 = get_group_id(2); + int i0 = get_global_id(0); + int i1 = get_group_id(1); + int i2 = get_group_id(2) % ne2; + int i3 = get_group_id(2) / ne2; - if (nidx >= d_ne0) { + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } - int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1; + uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + uint dst_idx = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; - bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2); + global float * src0_ptr = (global float *)((global char *)src0 + src0_idx); + global float * dst_ptr = (global float *)((global char *)dst + dst_idx); - if (in_src_bounds) { - int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * 
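The global work size above folds the last two tensor dimensions into the third NDRange axis, since OpenCL dispatches are limited to three dimensions; the kernel unfolds them with a modulo and a divide. A tiny round-trip sketch with made-up sizes:

#include <cassert>

int main() {
    const int ne2 = 3, ne3 = 5;                 // illustrative extents
    for (int i3 = 0; i3 < ne3; ++i3) {
        for (int i2 = 0; i2 < ne2; ++i2) {
            const int gid2 = i3 * ne2 + i2;     // index the host enqueues on axis 2
            assert(gid2 % ne2 == i2);           // what the kernel recovers as i2
            assert(gid2 / ne2 == i3);           // what the kernel recovers as i3
        }
    }
    return 0;
}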
s_ne1; - dst[dst_el_offset] = src0[src_el_offset]; - } else { - dst[dst_el_offset] = 0.0f; - } + bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) && + (i1 >= lp1 && i1 < ne1 - rp1) && + (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3); + + *dst_ptr = in_src_bounds ? *src0_ptr : 0.0f; } diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index f99681c84c..aad48d62a8 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -105,9 +105,12 @@ enum rpc_cmd { RPC_CMD_INIT_TENSOR, RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_HELLO, + RPC_CMD_DEVICE_COUNT, RPC_CMD_COUNT, }; +static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14"); + // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; @@ -117,7 +120,12 @@ struct rpc_msg_hello_rsp { uint8_t patch; }; +struct rpc_msg_device_count_rsp { + uint32_t device_count; +}; + struct rpc_msg_get_alloc_size_req { + uint32_t device; rpc_tensor tensor; }; @@ -130,6 +138,7 @@ struct rpc_msg_init_tensor_req { }; struct rpc_msg_alloc_buffer_req { + uint32_t device; uint64_t size; }; @@ -138,10 +147,18 @@ struct rpc_msg_alloc_buffer_rsp { uint64_t remote_size; }; +struct rpc_msg_get_alignment_req { + uint32_t device; +}; + struct rpc_msg_get_alignment_rsp { uint64_t alignment; }; +struct rpc_msg_get_max_size_req { + uint32_t device; +}; + struct rpc_msg_get_max_size_rsp { uint64_t max_size; }; @@ -192,6 +209,10 @@ struct rpc_msg_graph_compute_rsp { uint8_t result; }; +struct rpc_msg_get_device_memory_req { + uint32_t device; +}; + struct rpc_msg_get_device_memory_rsp { uint64_t free_mem; uint64_t total_mem; @@ -207,13 +228,15 @@ static ggml_guid_t ggml_backend_rpc_guid() { struct ggml_backend_rpc_buffer_type_context { std::string endpoint; + uint32_t device; std::string name; - size_t alignment; - size_t max_size; + size_t alignment; + size_t max_size; }; struct ggml_backend_rpc_context { std::string endpoint; + uint32_t device; std::string name; }; @@ -608,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con RPC_STATUS_ASSERT(status); } +static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer; +} + static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { - // check if src and dst are on the same server - ggml_backend_buffer_t src_buffer = src->buffer; - ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; - ggml_backend_buffer_t dst_buffer = dst->buffer; - ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; - if (src_ctx->sock != dst_ctx->sock) { - return false; + if (ggml_backend_buffer_is_rpc(src->buffer)) { + // check if src and dst are on the same server + ggml_backend_buffer_t src_buffer = src->buffer; + ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; + ggml_backend_buffer_t dst_buffer = dst->buffer; + ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; + if (src_ctx->sock != dst_ctx->sock) { + return false; + } + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + rpc_msg_copy_tensor_req request; + request.src = serialize_tensor(src); + request.dst = serialize_tensor(dst); + 
rpc_msg_copy_tensor_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.result; } - ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - rpc_msg_copy_tensor_req request; - request.src = serialize_tensor(src); - request.dst = serialize_tensor(dst); - rpc_msg_copy_tensor_rsp response; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); - RPC_STATUS_ASSERT(status); - return response.result; + return false; } static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { @@ -653,7 +683,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; - rpc_msg_alloc_buffer_req request = {size}; + rpc_msg_alloc_buffer_req request = {buft_ctx->device, size}; rpc_msg_alloc_buffer_rsp response; auto sock = get_socket(buft_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); @@ -669,9 +699,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back } } -static size_t get_alignment(const std::shared_ptr & sock) { +static size_t get_alignment(const std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_alignment_req request = {device}; rpc_msg_get_alignment_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response)); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); return response.alignment; } @@ -681,9 +712,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ return buft_ctx->alignment; } -static size_t get_max_size(const std::shared_ptr & sock) { +static size_t get_max_size(const std::shared_ptr & sock, uint32_t device) { + rpc_msg_get_max_size_req request = {device}; rpc_msg_get_max_size_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response)); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); return response.max_size; } @@ -700,7 +732,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty auto sock = get_socket(buft_ctx->endpoint); rpc_msg_get_alloc_size_req request; - + request.device = buft_ctx->device; request.tensor = serialize_tensor(tensor); rpc_msg_get_alloc_size_rsp response; @@ -754,7 +786,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector & tensors, tensors.push_back(serialize_tensor(tensor)); } -static void serialize_graph(const ggml_cgraph * cgraph, std::vector & output) { +static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector & output) { uint32_t n_nodes = cgraph->n_nodes; std::vector tensors; std::unordered_set visited; @@ -762,24 +794,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector & o add_tensor(cgraph->nodes[i], tensors, visited); } // serialization format: - // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * 
sizeof(rpc_tensor)) | + // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | uint32_t n_tensors = tensors.size(); - int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); + int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor); output.resize(output_size, 0); - memcpy(output.data(), &n_nodes, sizeof(n_nodes)); + uint8_t * dest = output.data(); + memcpy(dest, &device, sizeof(device)); + dest += sizeof(device); + memcpy(dest, &n_nodes, sizeof(n_nodes)); + dest += sizeof(n_nodes); for (uint32_t i = 0; i < n_nodes; i++) { - memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); + memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); } - uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); - *out_ntensors = n_tensors; - rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); + dest += n_nodes * sizeof(uint64_t); + memcpy(dest, &n_tensors, sizeof(n_tensors)); + dest += sizeof(n_tensors); + rpc_tensor * out_tensors = (rpc_tensor *)dest; memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); } static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; std::vector input; - serialize_graph(cgraph, input); + serialize_graph(rpc_ctx->device, cgraph, input); rpc_msg_graph_compute_rsp response; auto sock = get_socket(rpc_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); @@ -804,12 +841,13 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .graph_optimize = */ NULL, }; -ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { +ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) { static std::mutex mutex; std::lock_guard lock(mutex); + std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; // NOTE: buffer types are allocated and never freed; this is by design static std::unordered_map buft_map; - auto it = buft_map.find(endpoint); + auto it = buft_map.find(buft_name); if (it != buft_map.end()) { return it->second; } @@ -818,34 +856,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { GGML_LOG_ERROR("Failed to connect to %s\n", endpoint); return nullptr; } - size_t alignment = get_alignment(sock); - size_t max_size = get_max_size(sock); + size_t alignment = get_alignment(sock, device); + size_t max_size = get_max_size(sock, device); ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { /* .endpoint = */ endpoint, - /* .name = */ "RPC[" + std::string(endpoint) + "]", + /* .device = */ device, + /* .name = */ buft_name, /* .alignment = */ alignment, /* .max_size = */ max_size }; - + auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { /* .iface = */ ggml_backend_rpc_buffer_type_interface, - /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .device = */ ggml_backend_reg_dev_get(reg, device), /* .context = */ buft_ctx }; - buft_map[endpoint] = 
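Given the serialization format documented above, the payload length can be computed up front; a hedged sketch of that arithmetic (the helper name is made up, and sizeof(rpc_tensor) is passed in rather than assumed):

#include <cstddef>
#include <cstdint>

// Expected size of the graph-compute payload:
// | device (4) | n_nodes (4) | n_nodes * 8 | n_tensors (4) | n_tensors * sizeof(rpc_tensor) |
size_t rpc_graph_payload_size(uint32_t n_nodes, uint32_t n_tensors, size_t sizeof_rpc_tensor) {
    return 2*sizeof(uint32_t)                    // device + n_nodes
         + (size_t)n_nodes * sizeof(uint64_t)    // node ids
         + sizeof(uint32_t)                      // n_tensors
         + (size_t)n_tensors * sizeof_rpc_tensor;
}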
buft; + buft_map[buft_name] = buft; return buft; } -ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { +ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { + std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { - /* .endpoint = */ endpoint, - /* .name = */ "RPC[" + std::string(endpoint) + "]", + /* .endpoint = */ endpoint, + /* .device = */ device, + /* .name = */ dev_name }; - + auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_rpc_guid(), /* .iface = */ ggml_backend_rpc_interface, - /* .device = */ ggml_backend_rpc_add_device(endpoint), + /* .device = */ ggml_backend_reg_dev_get(reg, device), /* .context = */ ctx }; return backend; @@ -855,37 +896,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); } -static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { +static void get_device_memory(const std::shared_ptr & sock, uint32_t device, size_t * free, size_t * total) { + rpc_msg_get_device_memory_req request; + request.device = device; rpc_msg_get_device_memory_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response)); + bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); *free = response.free_mem; *total = response.total_mem; } -void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { +void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) { auto sock = get_socket(endpoint); if (sock == nullptr) { *free = 0; *total = 0; return; } - get_device_memory(sock, free, total); + get_device_memory(sock, device, free, total); } // RPC server-side implementation class rpc_server { public: - rpc_server(ggml_backend_t backend, const char * cache_dir) - : backend(backend), cache_dir(cache_dir) { + rpc_server(std::vector backends, const char * cache_dir) + : backends(std::move(backends)), cache_dir(cache_dir) { } ~rpc_server(); void hello(rpc_msg_hello_rsp & response); - void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); - void get_alignment(rpc_msg_get_alignment_rsp & response); - void get_max_size(rpc_msg_get_max_size_rsp & response); + bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); + bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response); + bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response); bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response); bool free_buffer(const rpc_msg_free_buffer_req & request); bool buffer_clear(const rpc_msg_buffer_clear_req & request); @@ -906,7 +949,7 @@ private: std::unordered_map & tensor_map); - ggml_backend_t backend; + std::vector backends; const char * cache_dir; std::unordered_set buffers; }; @@ -919,6 +962,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) { } bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + 
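A minimal caller-side sketch of the updated memory query above, assuming the matching declaration in ggml-rpc.h; the endpoint and device index are placeholders:

#include <cstddef>
#include <cstdio>
#include "ggml-rpc.h"

int main() {
    size_t free_mem = 0, total_mem = 0;
    // Query device 0 of a hypothetical server; the helper leaves zeros if the connection fails.
    ggml_backend_rpc_get_device_memory("192.168.1.10:50052", 0, &free_mem, &total_mem);
    printf("free: %zu MiB, total: %zu MiB\n", free_mem / (1024*1024), total_mem / (1024*1024));
    return 0;
}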
return false; + } ggml_backend_buffer_type_t buft; struct ggml_init_params params { /*.mem_size =*/ ggml_tensor_overhead(), @@ -935,10 +982,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); return false; } - LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data); + LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data); if (tensor->buffer == nullptr) { //No buffer allocated. - buft = ggml_backend_get_default_buffer_type(backend); + buft = ggml_backend_get_default_buffer_type(backends[dev_id]); } else { buft = tensor->buffer->buft; } @@ -948,33 +995,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ return true; } -void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); +bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size); response.remote_ptr = 0; response.remote_size = 0; if (buffer != nullptr) { response.remote_ptr = reinterpret_cast(buffer); response.remote_size = buffer->size; - LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size); + LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", + __func__, dev_id, request.size, response.remote_ptr, response.remote_size); buffers.insert(buffer); } else { - LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); + LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size); } + return true; } -void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) { - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); +bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); size_t alignment = ggml_backend_buft_get_alignment(buft); - LOG_DBG("[%s] alignment: %lu\n", __func__, alignment); + LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment); response.alignment = alignment; + return true; } -void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) { - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); +bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) { + uint32_t dev_id = request.device; + if (dev_id >= backends.size()) { + return false; + } + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]); size_t max_size = ggml_backend_buft_get_max_size(buft); - LOG_DBG("[%s] max_size: %lu\n", __func__, max_size); + LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size); response.max_size = max_size; + return true; } bool 
rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) { @@ -1332,23 +1395,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id, bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response) { // serialization format: - // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | - if (input.size() < sizeof(uint32_t)) { + // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | + if (input.size() < 2*sizeof(uint32_t)) { + return false; + } + const uint8_t * src = input.data(); + uint32_t device; + memcpy(&device, src, sizeof(device)); + src += sizeof(device); + if (device >= backends.size()) { return false; } uint32_t n_nodes; - memcpy(&n_nodes, input.data(), sizeof(n_nodes)); - if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { + memcpy(&n_nodes, src, sizeof(n_nodes)); + src += sizeof(n_nodes); + if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) { return false; } - const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes)); + const uint64_t * nodes = (const uint64_t *)src; + src += n_nodes*sizeof(uint64_t); uint32_t n_tensors; - memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors)); - if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { + memcpy(&n_tensors, src, sizeof(n_tensors)); + src += sizeof(n_tensors); + if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) { return false; } - const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors)); - LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors); + const rpc_tensor * tensors = (const rpc_tensor *)src; + LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors); size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); @@ -1380,7 +1453,7 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph return false; } } - ggml_status status = ggml_backend_graph_compute(backend, graph); + ggml_status status = ggml_backend_graph_compute(backends[device], graph); response.result = status; return true; } @@ -1391,9 +1464,9 @@ rpc_server::~rpc_server() { } } -static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, - sockfd_t sockfd, size_t free_mem, size_t total_mem) { - rpc_server server(backend, cache_dir); +static void rpc_serve_client(const std::vector & backends, const char * cache_dir, + sockfd_t sockfd, const std::vector & free_mem, const std::vector & total_mem) { + rpc_server server(backends, cache_dir); uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { return; @@ -1425,13 +1498,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, // HELLO command is handled above return; } + case RPC_CMD_DEVICE_COUNT: { + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_device_count_rsp response; + response.device_count = backends.size(); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; if (!recv_msg(sockfd, &request, 
sizeof(request))) { return; } rpc_msg_alloc_buffer_rsp response; - server.alloc_buffer(request, response); + if (!server.alloc_buffer(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1452,22 +1538,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, break; } case RPC_CMD_GET_ALIGNMENT: { - if (!recv_msg(sockfd, nullptr, 0)) { + rpc_msg_get_alignment_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_get_alignment_rsp response; - server.get_alignment(response); + if (!server.get_alignment(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_MAX_SIZE: { - if (!recv_msg(sockfd, nullptr, 0)) { + rpc_msg_get_max_size_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_get_max_size_rsp response; - server.get_max_size(response); + if (!server.get_max_size(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1593,12 +1685,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, break; } case RPC_CMD_GET_DEVICE_MEMORY: { - if (!recv_msg(sockfd, nullptr, 0)) { + rpc_msg_get_device_memory_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + auto dev_id = request.device; + if (dev_id >= backends.size()) { return; } rpc_msg_get_device_memory_rsp response; - response.free_mem = free_mem; - response.total_mem = total_mem; + response.free_mem = free_mem[dev_id]; + response.total_mem = total_mem[dev_id]; + LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id, + response.free_mem, response.total_mem); if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1612,16 +1711,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, } } -void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, - const char * cache_dir, - size_t free_mem, size_t total_mem) { +void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, + size_t n_threads, size_t n_devices, + ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) { + if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) { + fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n"); + return; + } + std::vector backends; + std::vector free_mem_vec(free_mem, free_mem + n_devices); + std::vector total_mem_vec(total_mem, total_mem + n_devices); printf("Starting RPC server v%d.%d.%d\n", RPC_PROTO_MAJOR_VERSION, RPC_PROTO_MINOR_VERSION, RPC_PROTO_PATCH_VERSION); printf(" endpoint : %s\n", endpoint); printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a"); - printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024)); + printf("Devices:\n"); + for (size_t i = 0; i < n_devices; i++) { + auto dev = devices[i]; + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024); + auto backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev)); + return; + } + backends.push_back(backend); + ggml_backend_reg_t reg = dev ? 
ggml_backend_dev_backend_reg(dev) : nullptr; + if (reg) { + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend, n_threads); + } + } + } std::string host; int port; @@ -1649,22 +1773,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint fprintf(stderr, "Failed to accept client connection\n"); return; } - printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); + printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem); + rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec); printf("Client connection closed\n"); fflush(stdout); } #ifdef _WIN32 WSACleanup(); #endif + for (auto backend : backends) { + ggml_backend_free(backend); + } } // device interface struct ggml_backend_rpc_device_context { std::string endpoint; + uint32_t device; std::string name; + std::string description; }; static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { @@ -1676,15 +1805,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - return ctx->name.c_str(); + return ctx->description.c_str(); } static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total); - - GGML_UNUSED(dev); + ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total); } static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) { @@ -1710,7 +1837,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - return ggml_backend_rpc_init(ctx->endpoint.c_str()); + return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device); GGML_UNUSED(params); } @@ -1718,7 +1845,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; - return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); + return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device); GGML_UNUSED(dev); } @@ -1736,7 +1863,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b } ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context; - return buft_ctx->endpoint == dev_ctx->endpoint; + return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device; } static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { @@ -1759,28 +1886,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { // backend reg interface -static 
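A hedged sketch of how a launcher might call the reworked ggml_backend_rpc_start_server above, collecting devices and their memory through the generic backend API; the function name and endpoint handling are illustrative, not the actual rpc-server tool:

#include <cstddef>
#include <vector>
#include "ggml-backend.h"
#include "ggml-rpc.h"

void serve_all_local_devices(const char * endpoint, size_t n_threads) {
    std::vector<ggml_backend_dev_t> devs;
    std::vector<size_t> free_mem, total_mem;
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        size_t free = 0, total = 0;
        ggml_backend_dev_memory(dev, &free, &total);
        devs.push_back(dev);
        free_mem.push_back(free);
        total_mem.push_back(total);
    }
    // New signature: one free/total entry per exposed device.
    ggml_backend_rpc_start_server(endpoint, /*cache_dir=*/nullptr, n_threads,
                                  devs.size(), devs.data(), free_mem.data(), total_mem.data());
}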
const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { - return "RPC"; +struct ggml_backend_rpc_reg_context { + std::string name; + std::vector devices; +}; - GGML_UNUSED(reg); +static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + return ctx ? ctx->name.c_str() : "RPC"; } static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) { - return 0; - - GGML_UNUSED(reg); + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + return ctx ? ctx->devices.size() : 0; } static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) { - GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead"); - - GGML_UNUSED(reg); - GGML_UNUSED(index); + ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context; + if (ctx == nullptr) { + GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead"); + } else { + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; + } } static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) { - if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) { - return (void *)ggml_backend_rpc_add_device; + if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) { + return (void *)ggml_backend_rpc_add_server; } if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) { return (void *)ggml_backend_rpc_start_server; @@ -1807,30 +1940,61 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) { return &ggml_backend_rpc_reg; } -ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) { - static std::unordered_map dev_map; - - static std::mutex mutex; - std::lock_guard lock(mutex); - - if (dev_map.find(endpoint) != dev_map.end()) { - return dev_map[endpoint]; - } - - ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context { - /* .endpoint = */ endpoint, - /* .name = */ "RPC[" + std::string(endpoint) + "]", - }; - - ggml_backend_dev_t dev = new ggml_backend_device { - /* .iface = */ ggml_backend_rpc_device_i, - /* .reg = */ ggml_backend_rpc_reg(), - /* .context = */ ctx, - }; - - dev_map[endpoint] = dev; - - return dev; +static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) { + auto sock = get_socket(endpoint); + rpc_msg_device_count_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + return response.device_count; } +static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = { + /* .get_name = */ ggml_backend_rpc_reg_get_name, + /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count, + /* .get_device = */ ggml_backend_rpc_reg_get_device, + /* .get_proc_address = */ ggml_backend_rpc_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { + static std::unordered_map reg_map; + static std::mutex mutex; + static uint32_t dev_id = 0; + std::lock_guard lock(mutex); + if (reg_map.find(endpoint) != reg_map.end()) { + return reg_map[endpoint]; + } + uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint); + if (dev_count == 0) { + return nullptr; + } + ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context; + ctx->name = "RPC[" + std::string(endpoint) + "]"; + for (uint32_t ind = 0; ind < dev_count; ind++) { + 
std::string dev_name = "RPC" + std::to_string(dev_id); + std::string dev_desc = std::string(endpoint); + ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc + }; + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_rpc_device_i, + /* .reg = */ ggml_backend_rpc_reg(), + /* .context = */ dev_ctx, + }; + ctx->devices.push_back(dev); + dev_id++; + } + ggml_backend_reg_t reg = new ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_rpc_reg_interface, + /* .context = */ ctx + }; + reg_map[endpoint] = reg; + return reg; +} + + GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index b97e7bf995..83a83887b5 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -1,5 +1,6 @@ cmake_minimum_required(VERSION 3.19) cmake_policy(SET CMP0114 NEW) +cmake_policy(SET CMP0116 NEW) find_package(Vulkan COMPONENTS glslc REQUIRED) @@ -54,25 +55,25 @@ if (Vulkan_FOUND) # Test all shader extensions test_shader_extension_support( "GL_KHR_cooperative_matrix" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp" "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" ) test_shader_extension_support( "GL_NV_cooperative_matrix2" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp" "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" ) test_shader_extension_support( "GL_EXT_integer_dot_product" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp" "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" ) test_shader_extension_support( "GL_EXT_bfloat16" - "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp" "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" ) @@ -160,7 +161,6 @@ if (Vulkan_FOUND) set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$") set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}") set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp") - set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp") set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders") set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv") @@ -176,24 +176,35 @@ if (Vulkan_FOUND) add_custom_command( OUTPUT ${_ggml_vk_header} - ${_ggml_vk_source} - COMMAND ${_ggml_vk_genshaders_cmd} - --glslc ${Vulkan_GLSLC_EXECUTABLE} - --input-dir ${_ggml_vk_input_dir} --output-dir ${_ggml_vk_output_dir} --target-hpp ${_ggml_vk_header} - --target-cpp ${_ggml_vk_source} - --no-clean - - DEPENDS ${_ggml_vk_shader_files} - ${_ggml_vk_shaders_gen_sources} + DEPENDS ${_ggml_vk_shaders_gen_sources} vulkan-shaders-gen - - COMMENT "Generate vulkan shaders" + COMMENT "Generate vulkan shaders header" ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header}) - target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header}) + foreach (file_full ${_ggml_vk_shader_files}) + get_filename_component(file ${file_full} NAME) + set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp") + + 
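With the registration path above, a client can enumerate a remote server's devices through the standard reg/device API; a minimal sketch, assuming ggml_backend_rpc_add_server is declared in ggml-rpc.h and using a placeholder endpoint:

#include <cstdio>
#include "ggml-backend.h"
#include "ggml-rpc.h"

int main() {
    ggml_backend_reg_t reg = ggml_backend_rpc_add_server("localhost:50052"); // hypothetical endpoint
    if (reg == nullptr) {
        return 1; // connection failed or the server reported zero devices
    }
    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
        printf("%s: %s\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
        ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
        ggml_backend_free(backend);
    }
    return 0;
}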
add_custom_command( + OUTPUT ${_ggml_vk_target_cpp} + DEPFILE ${_ggml_vk_target_cpp}.d + COMMAND ${_ggml_vk_genshaders_cmd} + --glslc ${Vulkan_GLSLC_EXECUTABLE} + --source ${file_full} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + --target-cpp ${_ggml_vk_target_cpp} + DEPENDS ${file_full} + ${_ggml_vk_shaders_gen_sources} + vulkan-shaders-gen + COMMENT "Generate vulkan shaders for ${file}" + ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp}) + endforeach() else() message(WARNING "Vulkan not found") diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2608cbd068..3cd89c7116 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -9,8 +9,14 @@ #define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 // We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE // to avoid conflicts with applications or other libraries who might use it. +#if VK_HEADER_VERSION >= 301 namespace vk::detail { class DispatchLoaderDynamic; } -vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher(); +using vk::detail::DispatchLoaderDynamic; +#else +namespace vk { class DispatchLoaderDynamic; } +using vk::DispatchLoaderDynamic; +#endif +DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() #include @@ -387,6 +393,7 @@ struct vk_device_struct { vk::PhysicalDeviceProperties properties; std::string name; uint64_t max_memory_allocation_size; + uint64_t max_buffer_size; uint64_t suballocation_block_size; bool fp16; bool bf16; @@ -1557,6 +1564,12 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_backend_vk_free(ggml_backend_t backend); +static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) { + const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset}, + VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange}); + return range; +} + // Wait for ctx->fence to be signaled. static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { // Use waitForFences while most of the graph executes. 
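The ggml_vk_get_max_buffer_range helper above caps a descriptor range at maxStorageBufferRange rather than at the buffer size, which matters once buffers are allowed to exceed 4 GiB; a small numeric sketch of the clamp with made-up sizes:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
    const uint64_t max_range = 0xFFFFFFFFull;   // e.g. a maxStorageBufferRange of 4 GiB - 1
    const uint64_t buf_size  = 6ull << 30;      // hypothetical 6 GiB buffer
    const uint64_t offset    = 5ull << 30;      // binding starts 5 GiB into the buffer
    const uint64_t range     = std::min(buf_size - offset, max_range);
    assert(range == (1ull << 30));              // only the remaining 1 GiB is bound
    return 0;
}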
Hopefully the CPU can sleep @@ -2006,8 +2019,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); - if (size > device->max_memory_allocation_size) { - throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); + if (size > device->max_buffer_size) { + throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); } vk_buffer buf = std::make_shared(); @@ -2153,8 +2166,8 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) { buf.reset(); } -static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { - return { buf, 0, VK_WHOLE_SIZE }; +static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) { + return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) }; } static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) { @@ -2608,8 +2621,6 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t D_lsb = D ^ (D & (D-1)); uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads - GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; }; @@ -3849,17 +3860,27 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { - device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); + device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); } else if (maintenance4_support) { device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); } else { device->max_memory_allocation_size = props3.maxMemoryAllocationSize; } + const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE"); + + if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) { + device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE); + } else if (maintenance4_support) { + device->max_buffer_size = props4.maxBufferSize; + } else { + device->max_buffer_size = device->max_memory_allocation_size; + } + const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE"); if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { - device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE); + device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE); } else { // Limit batching of allocations to 1GB by default to avoid fragmentation issues device->suballocation_block_size = 1024*1024*1024; @@ -4538,9 +4559,8 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); -static vk::detail::DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; - -vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher() { +static DispatchLoaderDynamic 
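Switching the environment-override parsing above to std::stoull matters because std::stoul returns unsigned long, which is only 32 bits on LLP64 platforms such as 64-bit Windows, so overrides of 4 GiB or more would overflow there; a tiny sketch:

#include <cassert>
#include <cstdint>
#include <string>

int main() {
    // 8 GiB expressed in bytes: fits in unsigned long long everywhere,
    // but exceeds a 32-bit unsigned long (std::stoul would throw out_of_range on LLP64).
    const uint64_t v = std::stoull("8589934592");
    assert(v == 8ull * 1024 * 1024 * 1024);
    return 0;
}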
ggml_vk_default_dispatcher_instance; +DispatchLoaderDynamic & ggml_vk_default_dispatcher() { return ggml_vk_default_dispatcher_instance; } @@ -6145,9 +6165,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; if ( - (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || - (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || - (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) { + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) { GGML_ABORT("Requested preallocation size is too large"); } if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { @@ -6222,7 +6242,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } if (x_non_contig) { - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); @@ -6234,7 +6254,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -6245,7 +6265,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true); + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -6267,14 +6287,11 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; } - // No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange. 
- VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange}); - // compute ggml_vk_matmul( ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, - { d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d, split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n @@ -6441,8 +6458,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; } if ( - (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || - (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { GGML_ABORT("Requested preallocation size is too large"); } if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { @@ -6507,7 +6524,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& } GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); } if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); @@ -6516,7 +6533,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -6527,7 +6544,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true); + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -6926,8 +6943,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t y_sz_upd = y_sz * ne12 * ne13; if ( - (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || - (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { GGML_ABORT("Requested preallocation size is too large"); } if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { @@ -6994,7 
+7011,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (x_non_contig) { - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, @@ -7007,7 +7024,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -7140,8 +7157,8 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t y_sz_upd = y_sz * ne12 * ne13; if ( - (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || - (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) { + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { GGML_ABORT("Requested preallocation size is too large"); } if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { @@ -7207,7 +7224,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (x_non_contig) { GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); } if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); @@ -7216,7 +7233,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; } @@ -7452,8 +7469,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { aligned = false; } - // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads - GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; @@ -7493,7 +7508,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Reserve space for split_k temporaries. 
For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; - if (split_k_size > ctx->device->max_memory_allocation_size) { + if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { GGML_ABORT("Requested preallocation size is too large"); } if (ctx->prealloc_size_split_k < split_k_size) { @@ -7615,12 +7630,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), + ggml_vk_subbuffer(ctx, d_K, k_buf_offset), + ggml_vk_subbuffer(ctx, d_V, v_buf_offset), + ggml_vk_subbuffer(ctx, d_M, m_buf_offset), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), }, // We only use split_k when group query attention is enabled, which means // there's no more than one tile of rows (i.e. workgroups_x would have been @@ -7632,21 +7647,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, { - vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, - vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), }, pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); ctx->prealloc_split_k_need_sync = true; } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), + ggml_vk_subbuffer(ctx, d_K, k_buf_offset), + ggml_vk_subbuffer(ctx, d_V, v_buf_offset), + ggml_vk_subbuffer(ctx, d_M, m_buf_offset), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), }, pc, { workgroups_x, workgroups_y, workgroups_z }); } @@ -8355,18 +8370,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } } - uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0; - uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0; - uint64_t z_sz = use_src2 ? 
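For the split_k scratch size above (an O tile of HSV x ne1 floats plus two floats per row, per split and per batch), a worked example with hypothetical dimensions:

#include <cassert>
#include <cstdint>

int main() {
    const uint64_t HSV = 128, ne1 = 512, ne3 = 1, split_k = 4; // illustrative shapes
    const uint64_t split_k_size =
        (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3;
    assert(split_k_size == 1064960); // roughly 1.02 MiB of temporary space
    return 0;
}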
ggml_type_size(src2->type) * ne2 : 0; - uint64_t d_sz = ggml_type_size(dst->type) * ned; - vk_buffer d_D = dst_buf_ctx->dev_buffer; - // Workaround for tiny tensor inputs on ROPE - if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) { - y_sz = VK_WHOLE_SIZE; - } - GGML_ASSERT(d_D != nullptr); uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; if(!src0_uma) { @@ -8391,26 +8396,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); - if (op_supports_incontiguous) { - x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); - y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; - z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; - d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); - - if (x_buf_offset + x_sz >= d_X->size) { - x_sz = VK_WHOLE_SIZE; - } - if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { - y_sz = VK_WHOLE_SIZE; - } - if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { - z_sz = VK_WHOLE_SIZE; - } - if (d_buf_offset + d_sz >= d_D->size) { - d_sz = VK_WHOLE_SIZE; - } - } - std::array elements; // Single call if dimension 2 is contiguous @@ -8601,19 +8586,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co break; } - if (!op_supports_incontiguous) { - if (x_sz != VK_WHOLE_SIZE) { - x_sz *= ne02 * ne03; + uint64_t x_sz, y_sz, z_sz, d_sz; + + if (op_supports_incontiguous) { + x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); + y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; + d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); + + if (x_buf_offset + x_sz >= d_X->size) { + x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset); } - if (use_src1 && y_sz != VK_WHOLE_SIZE) { - y_sz *= ne12 * ne13; + if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { + y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset); } - if (use_src2 && z_sz != VK_WHOLE_SIZE) { - z_sz *= ne22 * ne23; + if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { + z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset); } - if (d_sz != VK_WHOLE_SIZE) { - d_sz *= ned2 * ned3; + if (d_buf_offset + d_sz >= d_D->size) { + d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset); } + } else { + x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03; + y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0; + z_sz = use_src2 ? 
ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0; + d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3; } if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { @@ -8623,7 +8620,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz }, - vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE }, + ggml_vk_subbuffer(ctx, d_A, a_buf_offset), }, pc, elements); } else if (op == GGML_OP_GLU) { // Empty src1 is possible in glu, but the shader needs a buffer @@ -8816,18 +8813,18 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, static_assert(MAX_PARAMETER_COUNT == 12); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE }, - vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE }, + ggml_vk_subbuffer(ctx, buf[0], offset[0]), + ggml_vk_subbuffer(ctx, buf[1], offset[1]), + ggml_vk_subbuffer(ctx, buf[2], offset[2]), + ggml_vk_subbuffer(ctx, buf[3], offset[3]), + ggml_vk_subbuffer(ctx, buf[4], offset[4]), + ggml_vk_subbuffer(ctx, buf[5], offset[5]), + ggml_vk_subbuffer(ctx, buf[6], offset[6]), + ggml_vk_subbuffer(ctx, buf[7], offset[7]), + ggml_vk_subbuffer(ctx, buf[8], offset[8]), + ggml_vk_subbuffer(ctx, buf[9], offset[9]), + ggml_vk_subbuffer(ctx, buf[10], offset[10]), + ggml_vk_subbuffer(ctx, buf[11], offset[11]), }, pc, elements); } @@ -10001,7 +9998,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_ctx_begin(ctx->device, subctx); for (size_t i = 0; i < num_it; i++) { ggml_vk_matmul( - ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), + ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k), m, n, k, k, k, m, k*m, k*n, m*n, split_k, batch, batch, batch, 1, 1, n @@ -10312,7 +10309,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // // vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); // ggml_vk_ctx_begin(ctx->device, subctx); -// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne); +// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne); // ggml_vk_ctx_end(subctx); // // auto begin = std::chrono::high_resolution_clock::now(); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp index d896f1ef0b..5084a70ed4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git 
a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp index 00cf2dd62f..3bcfe6908e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -6,8 +6,8 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp index 3ae8f0116c..495249d5f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp @@ -2,7 +2,7 @@ #extension GL_EXT_control_flow_attributes : require -#include "types.comp" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp index a1d4c240dd..7c12877671 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index dc53a401e0..c81b84452e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_control_flow_attributes : enable -#include "types.comp" +#include "types.glsl" layout(constant_id = 0) const int BLOCK_SIZE = 1024; layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp index 1e5cb8dae4..653431895e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp index 9ee2f1fae2..e404698382 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index 6567a8c54c..ca1a3ac25b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" #extension GL_EXT_control_flow_attributes : require diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp index 938c74da50..70a301488e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" 
layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 44a64ddc80..0367e80bbf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -11,7 +11,7 @@ # extension GL_KHR_shader_subgroup_shuffle : enable #endif -#include "types.comp" +#include "types.glsl" // shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j layout(binding = 0) readonly buffer A { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp index b17b4e83ee..5217e18bdd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin] layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin] diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index f476a2e3dd..9f8bfd3c18 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 978d430030..06df509525 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -1,8 +1,8 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" -#include "dequant_funcs.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" +#include "dequant_funcs.glsl" #if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) // 16 invocations needed for init_iq_shmem diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index bc2e1f2df3..b8c40eec10 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,7 +1,7 @@ #version 450 -#include "rte.comp" -#include "types.comp" +#include "rte.glsl" +#include "types.glsl" #if defined(SET_ROWS) && QUANT_K == 1 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -14,7 +14,7 @@ const uint BLOCK_SIZE = 32; layout (binding = 0) readonly buffer S {float data_s[];}; #if defined(SET_ROWS) -#include "generic_binary_head.comp" +#include "generic_binary_head.glsl" layout (binding = 1) readonly buffer C {B_TYPE data_i[];}; layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; @@ -25,7 +25,7 @@ layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; #endif #else -#include "generic_unary_head.comp" +#include "generic_unary_head.glsl" layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp index 0b8d02f58f..db6865db98 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" 
+#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp index d9345497c7..e75df66756 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp @@ -2,8 +2,8 @@ #extension GL_EXT_control_flow_attributes : enable -#include "types.comp" -#include "generic_head.comp" +#include "types.glsl" +#include "generic_head.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp index a4d3fca556..765afffa80 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 73fef4fa65..0d98f5a9d6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require #endif -#include "types.comp" +#include "types.glsl" #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 706540fd85..6a5bb4574d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,5 +1,5 @@ -#include "types.comp" +#include "types.glsl" layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl similarity index 91% rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl index 8d806435b7..addceafade 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl @@ -10,4 +10,4 @@ layout (push_constant) uniform parameter uint nel; } p; -#include "types.comp" +#include "types.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp index b604c1881a..637c95fa35 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp 
b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp index fd1e4e30d2..d1cbc5e9d0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp index 127c7b6424..78490162cd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp index a08331c40d..9b8ce0a7f8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp index 0ae9acd02a..aacf07d0f8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp index e4f42be94c..f2c20b1d2c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp index 19c7fdeefc..671c1f4a0d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp index 46d9ad15eb..8f7833eab2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp index f930852a48..a313699775 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp 
b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp index ee496e9d56..ffba5a77dd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index d4e4e6bae6..58dc2e5dfd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp index 3661f771c7..0c90be8b4e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp index 4081853272..b92b292135 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp index 2f27eee686..6b63cbe583 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 1370db3654..8b7be557e9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp index b20b805292..f1b0bac872 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp index dc59fe3b77..c495b31f17 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 
3f3b839e11..6bc04670fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp index 9cf34256e8..c8d6fcb49f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp index bd1344a88d..10844ddf78 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp @@ -1,6 +1,6 @@ #version 450 -#include "dequant_head.comp" +#include "dequant_head.glsl" layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp index 26d8bc22ad..9cef8a8ec3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -10,7 +10,7 @@ layout (push_constant) uniform parameter uint n_past; } p; -#include "types.comp" +#include "types.glsl" layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp index 9fb69c6c15..572472f8a9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp index a3941372a7..b69d4ddb09 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -1,8 +1,8 @@ #version 450 -#include "rte.comp" -#include "generic_head.comp" -#include "types.comp" +#include "rte.glsl" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/bfloat16.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/bfloat16.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp 
b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/integer_dot.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/integer_dot.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 43b906e5ed..62acbf107a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -8,8 +8,8 @@ #extension GL_KHR_shader_subgroup_shuffle : enable -#include "types.comp" -#include "flash_attn_base.comp" +#include "types.glsl" +#include "flash_attn_base.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; @@ -153,12 +153,13 @@ void main() { } if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - if (!KV_bounds_check || j * Bc + c < KV) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); } else { masksh[c][r] = float(0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index ddb1246e0b..2066a05b34 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -10,8 +10,8 @@ #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable -#include "types.comp" -#include "flash_attn_base.comp" +#include "types.glsl" +#include "flash_attn_base.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; @@ -201,11 +201,13 @@ void main() { } if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if (!KV_bounds_check || j * Bc + c < KV) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); } } @@ -356,8 +358,8 @@ void main() { } if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); float ms = 1.0f; float vs = 1.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index ab647e9bc8..910da1ab0c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -16,9 +16,9 @@ 
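(Note on the flash-attention mask changes: the host-side assert that the mask's dim1 is padded to GGML_KQ_MASK_PAD was removed in the ggml_vk_flash_attn hunk above, so the shaders can no longer rely on that padding and instead bounds-check mask loads themselves. Condensed from the flash_attn.comp and flash_attn_cm1.comp hunks above (the scalar variant stores 0 for out-of-range elements, the coopmat1 variant simply skips the add, which is equivalent):

    // Mask rows only need checking when GQA is not in use and the mask height
    // p.nem1 is not a multiple of the row tile Br.
    bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;

    // Element (r, c) of the current Br x Bc tile is read only if it lies inside
    // both the KV range and, when required, the nem1 range.
    bool in_bounds = (!KV_bounds_check || j * Bc + c < KV)
                  && (!nem1_bounds_check || i * Br + r < p.nem1);
    float m_val = in_bounds ? float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])
                            : 0.0f;

The coopmat2 variant below reaches the same effect through the tensor layout instead: when the check is needed it loads the mask with gl_CooperativeMatrixClampModeConstantNV and the layout height set to p.nem1, otherwise it keeps the Clamp mode and leaves the height unclamped when GQA is enabled.)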
#extension GL_KHR_shader_subgroup_vote : enable #extension GL_EXT_null_initializer : enable -#include "types.comp" -#include "dequant_funcs_cm2.comp" -#include "flash_attn_base.comp" +#include "types.glsl" +#include "dequant_funcs_cm2.glsl" +#include "flash_attn_base.glsl" layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; @@ -154,15 +154,31 @@ void main() { } if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); - tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); - tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - coopmat mv; + if (nem1_bounds_check) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); - coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + coopmat mv; - S += slopeMat*coopmat(mv); + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slopeMat*coopmat(mv); + } else { + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); + // Don't clamp against nem1 when GQA is enabled + uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slopeMat*coopmat(mv); + } } // Clear padding elements to -inf, so they don't contribute to rowmax diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp index f4268ed24f..e017b50368 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -10,4 +10,4 @@ float op(float a, float b) { return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp index cbd4cb36bf..759a1848fa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation // ref: https://www.johndcook.com/blog/python_erf/ @@ -24,4 +24,4 @@ float op(float a, float b) { return 0.5f * a * (1.0f + erf_approx) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp index 3a2a6897bf..c4032ab21d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" const float GELU_QUICK_COEF = 
-1.702f; @@ -8,4 +8,4 @@ float op(float a, float b) { return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp index 4cc7a68ca1..a95c2525c8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp index 5fd5a5e703..58375aba09 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp index e6e6fcfd20..bfdfe2182d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl similarity index 97% rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index 750e785753..99595fc688 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -1,8 +1,8 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.comp" -#include "utils.comp" +#include "rte.glsl" +#include "utils.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index 7ef75cd7a4..76d83041ce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index 339f905fc7..9dba437edb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ 
b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -2,9 +2,9 @@ #extension GL_EXT_control_flow_attributes : enable -#include "types.comp" -#include "generic_binary_head.comp" -#include "dequant_funcs.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" +#include "dequant_funcs.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl similarity index 95% rename from ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 51d70869d9..2168989340 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -1,6 +1,6 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.comp" +#include "rte.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp index b6a0d56454..bdf97dbb5d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp index 1da252cc66..b4dbdf3141 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp index 3afc588274..1ec315915e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index f0f19a019c..1827d647a2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -3,9 +3,8 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.comp" - -#include "types.comp" +#include "rte.glsl" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp index 9faa636ac2..4bf8b4ca04 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -4,9 +4,8 @@ #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "rte.comp" - -#include 
"types.comp" +#include "rte.glsl" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index deba8c3985..83ef2f8795 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp index d90a99aea5..b281e855cb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp index 43de19df8e..02ef1eace1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index bb429dd594..9a03925cfd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index f761391eae..450dee0408 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -11,7 +11,7 @@ #define EXPERT_COUNT 8 #endif -#include "types.comp" +#include "types.glsl" #ifndef MMQ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -32,7 +32,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 3) readonly buffer IDS {int data_ids[];}; #endif -#include "dequant_funcs.comp" +#include "dequant_funcs.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp index e4acbd4f96..4cb292380c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp index 309da0991a..0b74b33212 100644 --- 
a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp index 8d01536fa6..e424af12c5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp index c496043241..0cd906dbbf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp index 94d4b92e1e..71bd72d17e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp index f021e40476..a4b9ab1f94 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp index 3fe9dc3a41..40849c691f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 423ceb8a3d..03ed25d3bf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 
1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index e91724a28d..528f224d86 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -1,7 +1,7 @@ #version 450 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index f9cde06488..21d07d2e50 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 6c84ef3cde..9e46c89a11 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index d53d9ee0a2..d7a7f6426e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 8fb314fa0a..64293f6eca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -6,13 +6,13 @@ #define MMQ #define B_TYPE block_q8_1_x4 -#include "mul_mat_vec_base.comp" +#include "mul_mat_vec_base.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #define K_PER_ITER 8 -#include "mul_mmq_funcs.comp" +#include "mul_mmq_funcs.glsl" uint a_offset, b_offset, d_offset; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 3cb24412d5..85400ac5fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -28,7 +28,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif -#include "types.comp" +#include "types.glsl" #ifndef LOAD_VEC_A #define LOAD_VEC_A 1 @@ -195,7 +195,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif -#include "mul_mm_funcs.comp" +#include "mul_mm_funcs.glsl" void main() { #ifdef NEEDS_INIT_IQ_SHMEM diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 0e3065e014..2e04baa44e 100644 --- 
a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -18,8 +18,8 @@ #extension GL_EXT_bfloat16 : enable #endif -#include "types.comp" -#include "utils.comp" +#include "types.glsl" +#include "utils.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -71,7 +71,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #if QUANT_K > 1 #define DECODEFUNCA , dequantFuncA -#include "dequant_funcs_cm2.comp" +#include "dequant_funcs_cm2.glsl" #else #define DECODEFUNCA diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index f36add62a9..b5d761c0ba 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -20,7 +20,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif -#include "types.comp" +#include "types.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -110,7 +110,7 @@ shared u16vec2 row_ids[4096]; shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif -#include "mul_mmq_funcs.comp" +#include "mul_mmq_funcs.glsl" void main() { #ifdef NEEDS_INIT_IQ_SHMEM diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl similarity index 99% rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index cdfb230f4e..fe71eb131c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -2,7 +2,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require -#include "types.comp" +#include "types.glsl" // Each iqs value maps to a 32-bit integer diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 854a2ad818..1e8f694a72 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -8,9 +8,9 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "rte.comp" -#include "types.comp" -#include "utils.comp" +#include "rte.glsl" +#include "types.glsl" +#include "utils.glsl" layout (push_constant) uniform parameter2 { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp index 6627a50bd9..cc3ea0b760 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp index e0214fe764..1f05f922cc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include 
"generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp index 6426dedee5..1251f9cc64 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.comp" +#include "generic_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index 0d81220c71..f3c8176872 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" layout (push_constant) uniform parameter { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp index b6124411a0..d9d7166e36 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" #extension GL_EXT_shader_16bit_storage : require diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp index 145c9fbdc9..0f3c6ca871 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -17,7 +17,7 @@ layout (push_constant) uniform parameter uint ne; } p; -#include "types.comp" +#include "types.glsl" layout(constant_id = 0) const uint GROUP_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp index 0073d8f766..86be2669a1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -1,9 +1,9 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" float op(float a, float b) { return max(a, 0.0f) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp index 4f806270c7..5725cef236 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp index 1568b141de..8f4b9a8684 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp index d86279934f..87df782944 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" 
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index 41197e9301..d5b211ffaa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_binary_head.comp" -#include "types.comp" +#include "generic_binary_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp index 76009f3df6..87707fc149 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp index ba4677c293..4618b2c7e8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_binary_head.comp" -#include "types.comp" +#include "generic_binary_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #extension GL_KHR_shader_subgroup_arithmetic : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp b/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp index b9abe8dedc..68fbd0c7be 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl similarity index 97% rename from ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index 00e203e73b..50fc1f1e2d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -1,8 +1,8 @@ -#include "types.comp" +#include "types.glsl" #extension GL_EXT_shader_16bit_storage : require -#include "rte.comp" +#include "rte.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 5808710ccf..111286b498 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index 366a7b1c47..06e095bef9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git 
a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 9643bca96a..6ba9575409 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index cedacc4d14..d37d1c1043 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -1,6 +1,6 @@ #version 450 -#include "rope_head.comp" +#include "rope_head.glsl" void main() { const uint i0 = 2*gl_GlobalInvocationID.y; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/rte.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp index f10b0a02b5..35ec726a01 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" const uint num_threads = 128; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp index 5c9e5c3503..32298d43c6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp index 4d36f88e08..7d1cc6f45a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp index f9afa9b13c..e5d949ff18 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp index d7c15a1695..61f17b2f00 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp index 5f20a1ee7d..dca0d896bc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -23,7 +23,7 @@ layout 
(push_constant) uniform parameter uint has_sinks; } p; -#include "types.comp" +#include "types.glsl" layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp index 144ea58e6f..d873332eeb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -2,8 +2,8 @@ #extension GL_EXT_control_flow_attributes : enable -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp index 4bc697b9b9..70daad6c5d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp index ef43598baf..4eb56afcb1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp @@ -1,7 +1,7 @@ #version 450 -#include "types.comp" -#include "generic_unary_head.comp" +#include "types.glsl" +#include "generic_unary_head.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp index 72353cc329..bc924b520a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp @@ -2,8 +2,8 @@ #extension GL_EXT_shader_16bit_storage : require -#include "types.comp" -#include "generic_binary_head.comp" +#include "types.glsl" +#include "generic_binary_head.glsl" const uint num_threads = 256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp index 759204afaf..bc22aa7bd7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -1,6 +1,6 @@ #version 450 -#include "types.comp" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp index a28e7c6cc8..4fee433a12 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -1,9 +1,9 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" float op(float a, float b) { return a / (1.0f + exp(-a)) * b; } -#include "glu_main.comp" +#include "glu_main.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp index 970750eec0..bda9dea21c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp @@ -1,6 +1,6 @@ #version 450 -#include "glu_head.comp" +#include "glu_head.glsl" float op(float a, float b) { float xi = min(a, p.limit); @@ -11,4 +11,4 @@ float op(float a, float b) { return out_glu; } -#include "glu_main.comp" +#include "glu_main.glsl" 
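The GLU shaders touched above all follow the same composition pattern: each .comp file includes glu_head.glsl, defines a single scalar op(a, b), and then includes glu_main.glsl, so only the activation differs between reglu, swiglu and swiglu_oai. As a reference for the element-wise math that the reglu and swiglu variants compute, here is a minimal host-side C++ sketch; reglu_ref and swiglu_ref are illustrative names, not functions in the codebase.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Mirrors the op() bodies shown in reglu.comp and swiglu.comp:
    //   reglu:  max(a, 0) * b
    //   swiglu: silu(a) * b = (a / (1 + exp(-a))) * b
    static float reglu_ref(float a, float b)  { return std::max(a, 0.0f) * b; }
    static float swiglu_ref(float a, float b) { return a / (1.0f + std::exp(-a)) * b; }

    int main() {
        const float a = 1.5f, b = 2.0f;
        std::printf("reglu(%g, %g)  = %g\n", a, b, reglu_ref(a, b));
        std::printf("swiglu(%g, %g) = %g\n", a, b, swiglu_ref(a, b));
        return 0;
    }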
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp index 8a6f868f58..7b5eb413bf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -1,7 +1,7 @@ #version 450 -#include "generic_head.comp" -#include "types.comp" +#include "generic_head.glsl" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp index ce8e09442d..1605565457 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -9,7 +9,7 @@ layout (push_constant) uniform parameter uint max_period; } p; -#include "types.comp" +#include "types.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 256 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/types.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/types.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 74771def0f..154a2172d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -9,7 +9,7 @@ layout (push_constant) uniform parameter float sf0; float sf1; float sf2; float sf3; } p; -#include "types.comp" +#include "types.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp b/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/utils.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 84bb9df9a0..f0cc24ff31 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -1,5 +1,3 @@ - - #include #include #include @@ -22,6 +20,7 @@ #include #ifdef _WIN32 + #define NOMINMAX #include #include // For _mkdir on Windows #else @@ -34,13 +33,13 @@ std::mutex lock; std::vector> shader_fnames; +std::locale c_locale("C"); std::string GLSLC = "glslc"; -std::string input_dir = "vulkan-shaders"; +std::string input_filepath = ""; std::string output_dir = "/tmp"; -std::string target_hpp = "ggml-vulkan-shaders.hpp"; -std::string target_cpp = "ggml-vulkan-shaders.cpp"; -bool no_clean = false; +std::string target_hpp = ""; +std::string target_cpp = ""; const std::vector type_names = { "f32", @@ -75,6 +74,7 @@ enum MatMulIdType { }; namespace { + void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { #ifdef _WIN32 HANDLE stdout_read, stdout_write; @@ -232,16 +232,87 @@ std::string basename(const std::string &path) { return path.substr(path.find_last_of("/\\") + 1); } +std::stringstream make_generic_stringstream() { + std::stringstream ss; + ss.imbue(c_locale); + return ss; +} + +std::string read_binary_file(const std::string& path, bool may_not_exist = false) { + FILE* f = fopen(path.c_str(), "rb"); + if (!f) { + if (!may_not_exist) { + std::cerr << "Error opening file: " << path << " (" << strerror(errno) << ")\n"; + } + return {}; + } + + fseek(f, 0, SEEK_END); 
+ size_t size = ftell(f); + fseek(f, 0, SEEK_SET); + + std::string data(size, '\0'); + size_t read_size = fread(data.data(), 1, size, f); + fclose(f); + if (read_size != size) { + std::cerr << "Error reading file: " << path << " (" << strerror(errno) << ")\n"; + return {}; + } + + return data; +} + +void write_binary_file(const std::string& path, const std::string& content) { + FILE* f = fopen(path.c_str(), "wb"); + if (!f) { + std::cerr << "Error opening file for writing: " << path << " (" << strerror(errno) << ")\n"; + return; + } + + size_t write_size = fwrite(content.data(), 1, content.size(), f); + fclose(f); + if (write_size != content.size()) { + std::cerr << "Error writing file: " << path << " (" << strerror(errno) << ")\n"; + return; + } +} + +void write_file_if_changed(const std::string& path, const std::string& content) { + std::string existing = read_binary_file(path, true); + if (existing != content) { + write_binary_file(path, content); + } +} + + // variables to track number of compiles in progress static uint32_t compile_count = 0; static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; +static bool generate_dep_file = true; -void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); - std::string out_fname = join_paths(output_dir, name + ".spv"); - std::string in_path = join_paths(input_dir, in_fname); +void decrement_compile_count(uint32_t * count) { + if (count) { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + compile_count_cond.notify_all(); + } +} +using compile_count_guard = std::unique_ptr; + +compile_count_guard acquire_compile_slot() { + // wait until fewer than N compiles are in progress. + // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. + uint32_t N = std::max(1u, std::min(16u, std::thread::hardware_concurrency())); + std::unique_lock guard(compile_count_mutex); + compile_count_cond.wait(guard, [N] { return compile_count < N; }); + compile_count++; + return compile_count_guard(&compile_count, &decrement_compile_count); +} + +void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map defines, bool coopmat, bool dep_file, compile_count_guard slot) { std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 @@ -249,11 +320,17 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? 
"" : "-O"; #ifdef _WIN32 - std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""}; #else - std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname}; + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_path}; #endif + if (dep_file) { + cmd.push_back("-MD"); + cmd.push_back("-MF"); + cmd.push_back("\"" + target_cpp + ".d\""); + } + #ifdef GGML_VULKAN_SHADER_DEBUG_INFO cmd.push_back("-g"); #endif @@ -281,17 +358,23 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c return; } + if (dep_file) { + // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt + std::string dep = read_binary_file(target_cpp + ".d", true); + if (!dep.empty()) { + size_t pos = dep.find(out_path); + if (pos != std::string::npos) { + dep.replace(pos, out_path.length(), target_cpp); + } + write_binary_file(target_cpp + ".d", dep); + } + } + std::lock_guard guard(lock); - shader_fnames.push_back(std::make_pair(name, out_fname)); + shader_fnames.push_back(std::make_pair(name, out_path)); } catch (const std::exception& e) { std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; } - { - std::lock_guard guard(compile_count_mutex); - assert(compile_count > 0); - compile_count--; - } - compile_count_cond.notify_all(); } std::map merge_maps(const std::map& a, const std::map& b) { @@ -301,18 +384,24 @@ std::map merge_maps(const std::map> compiles; -void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - { - // wait until fewer than N compiles are in progress. - // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. - uint32_t N = 16; - std::unique_lock guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); - } - compile_count++; +void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? 
"" : "_fp32")); + std::string out_path = join_paths(output_dir, name + ".spv"); + + if (input_filepath == "") { + // No input source to compile, only generate header for all shaders + shader_fnames.push_back(std::pair(name, out_path)); + return; + } else if (basename(input_filepath) != source) { + // Only compile shader variants matching the input filename + return; } - compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); + + compile_count_guard slot = acquire_compile_slot(); + compiles.push_back(std::async( + string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot))); + // Don't write the same dep file from multiple processes + generate_dep_file = false; } void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { @@ -485,7 +574,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } void process_shaders() { - std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; std::map base_dict = {{"FLOAT_TYPE", "float"}}; // matmul @@ -837,11 +925,11 @@ void process_shaders() { } void write_output_files() { - FILE* hdr = fopen(target_hpp.c_str(), "w"); - FILE* src = fopen(target_cpp.c_str(), "w"); + std::stringstream hdr = make_generic_stringstream(); + std::stringstream src = make_generic_stringstream(); - fprintf(hdr, "#include \n\n"); - fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + hdr << "#include \n\n"; + src << "#include \"" << basename(target_hpp) << "\"\n\n"; std::sort(shader_fnames.begin(), shader_fnames.end()); for (const auto& pair : shader_fnames) { @@ -853,91 +941,85 @@ void write_output_files() { const std::string& path = pair.second; #endif - FILE* spv = fopen(path.c_str(), "rb"); - if (!spv) { - std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; - continue; - } + hdr << "extern const uint64_t " << name << "_len;\n"; + hdr << "extern const unsigned char " << name << "_data[];\n\n"; - fseek(spv, 0, SEEK_END); - size_t size = ftell(spv); - fseek(spv, 0, SEEK_SET); + if (input_filepath != "") { + std::string data = read_binary_file(path); + if (data.empty()) { + continue; + } - std::vector data(size); - size_t read_size = fread(data.data(), 1, size, spv); - fclose(spv); - if (read_size != size) { - std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; - continue; - } - - fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); - fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); - - fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); - for (size_t i = 0; i < size; ++i) { - fprintf(src, "0x%02x,", data[i]); - if ((i + 1) % 12 == 0) fprintf(src, "\n"); - } - fprintf(src, "\n};\n\n"); - - if (!no_clean) { - std::remove(path.c_str()); + src << "const uint64_t " << name << "_len = " << data.size() << ";\n"; + src << "const unsigned char " << name << "_data[" << data.size() << "] = {\n" << std::hex; + auto bytes = reinterpret_cast(data.data()); + for (size_t i = 0; i < data.size(); ++i) { + src << "0x" << static_cast(bytes[i]) << ","; + if ((i + 1) % 12 == 0) src << "\n"; + } + src << std::dec << "\n};\n\n"; } } std::string suffixes[2] = {"_f32", "_f16"}; - for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) { - fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op); - fprintf(hdr, "extern uint64_t 
%s_len[2][2][2][2];\n", op); - std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; - std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = "; + for (auto op : {"add", "sub", "mul", "div", "add_rms"}) { + hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; + hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; + + std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp"; + if (basename(input_filepath) != op_file) { + continue; + } + std::stringstream data = make_generic_stringstream(); + std::stringstream len = make_generic_stringstream(); + data << "const void * " << op << "_data[2][2][2][2] = "; + len << "const uint64_t " << op << "_len[2][2][2][2] = "; for (uint32_t t0 = 0; t0 < 2; ++t0) { if (t0 == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } for (uint32_t t1 = 0; t1 < 2; ++t1) { if (t1 == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } for (uint32_t t2 = 0; t2 < 2; ++t2) { if (t2 == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } for (uint32_t rte = 0; rte < 2; ++rte) { if (rte == 0) { - data += "{"; - len += "{"; + data << "{"; + len << "{"; } - data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); - len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); - data += "_data,"; - len += "_len,"; + data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + data << "_data,"; + len << "_len,"; if (rte == 1) { - data += "}, "; - len += "}, "; + data << "}, "; + len << "}, "; } } if (t2 == 1) { - data += "}, "; - len += "}, "; + data << "}, "; + len << "}, "; } } if (t1 == 1) { - data += "}, "; - len += "}, "; + data << "}, "; + len << "}, "; } } if (t0 == 1) { - data += "};\n"; - len += "};\n"; + data << "};\n"; + len << "};\n"; } } - fputs(data.c_str(), src); - fputs(len.c_str(), src); + src << data.str(); + src << len.str(); } std::vector btypes = {"f16", "f32"}; @@ -951,20 +1033,25 @@ void write_output_files() { if (btype == "q8_1" && !is_legacy_quant(tname)) { continue; } - fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str()); - fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str()); - std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n"; - std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n"; - fputs(data.c_str(), src); - fputs(len.c_str(), src); + hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; + hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n"; + if (basename(input_filepath) == "mul_mat_vec.comp") { + src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; + src << "const uint64_t 
arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; + } } } - fclose(hdr); - fclose(src); -} + if (input_filepath == "") { + write_file_if_changed(target_hpp, hdr.str()); + } + if (target_cpp != "") { + write_binary_file(target_cpp, src.str()); + } } +} // namespace + int main(int argc, char** argv) { std::map args; for (int i = 1; i < argc; ++i) { @@ -982,8 +1069,8 @@ int main(int argc, char** argv) { if (args.find("--glslc") != args.end()) { GLSLC = args["--glslc"]; // Path to glslc } - if (args.find("--input-dir") != args.end()) { - input_dir = args["--input-dir"]; // Directory containing shader sources + if (args.find("--source") != args.end()) { + input_filepath = args["--source"]; // The shader source file to compile } if (args.find("--output-dir") != args.end()) { output_dir = args["--output-dir"]; // Directory for containing SPIR-V output @@ -994,14 +1081,6 @@ int main(int argc, char** argv) { if (args.find("--target-cpp") != args.end()) { target_cpp = args["--target-cpp"]; // Path to generated cpp file } - if (args.find("--no-clean") != args.end()) { - no_clean = true; // Keep temporary SPIR-V files in output-dir after build - } - - if (!directory_exists(input_dir)) { - std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; - return EXIT_FAILURE; - } if (!directory_exists(output_dir)) { if (!create_directory(output_dir)) { diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 78a985a4d1..c6a95d5151 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -50,5 +50,13 @@ if (GGML_WEBGPU_DEBUG) target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) endif() +if (GGML_WEBGPU_CPU_PROFILE) + target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_CPU_PROFILE=1) +endif() + +if (GGML_WEBGPU_GPU_PROFILE) + target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_GPU_PROFILE=1) +endif() + target_include_directories(ggml-webgpu PRIVATE ${SHADER_OUTPUT_DIR}) target_link_libraries(ggml-webgpu PRIVATE ${DawnWebGPU_TARGET}) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 2d1b6bd518..57d085b31d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -11,10 +11,12 @@ #include +#include #include #include #include #include +#include #include #include @@ -25,12 +27,45 @@ # define WEBGPU_LOG_DEBUG(msg) ((void) 0) #endif // GGML_WEBGPU_DEBUG +#ifdef GGML_WEBGPU_CPU_PROFILE +// total timing (aggregated) +# define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now(); + +# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) \ + auto cpu_total_end_##id = std::chrono::high_resolution_clock::now(); \ + double cpu_total_time_##id = \ + std::chrono::duration(cpu_total_end_##id - cpu_total_start_##id).count(); \ + (ctx)->cpu_time_ms[#id] += cpu_total_time_##id; + +// fine-grained timing (not included in totals) +# define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now(); + +# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) \ + auto cpu_detail_end_##id = std::chrono::high_resolution_clock::now(); \ + double cpu_detail_time_##id = \ + std::chrono::duration(cpu_detail_end_##id - 
cpu_detail_start_##id).count(); \ + (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id; +#else +# define WEBGPU_CPU_PROFILE_TOTAL_START(id) +# define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx) +# define WEBGPU_CPU_PROFILE_DETAIL_START(id) +# define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx) +#endif // GGML_WEBGPU_CPU_PROFILE + +#ifdef GGML_WEBGPU_GPU_PROFILE +# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24 +# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps +#endif + /* Constants */ -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16 +#define WEBGPU_MUL_MAT_WG_SIZE 256 +#define WEBGPU_NUM_PARAM_BUFS 32u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u #define WEBGPU_WAIT_ANY_BATCH_SIZE 64 -#define WEBGPU_MUL_MAT_WG_SIZE 64 -#define WEBGPU_NUM_PARAM_BUFS 100 +#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 +// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool +#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 @@ -66,6 +101,11 @@ struct webgpu_pool_bufs { wgpu::Buffer dev_buf; }; +// The futures to wait on for a single queue submission +struct webgpu_submission_futures { + std::vector futures; +}; + // Holds a pool of parameter buffers for WebGPU operations struct webgpu_buf_pool { std::vector free; @@ -112,6 +152,83 @@ struct webgpu_buf_pool { } }; +#ifdef GGML_WEBGPU_GPU_PROFILE +struct webgpu_gpu_profile_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + wgpu::QuerySet query_set; +}; + +// Holds a pool of parameter buffers for WebGPU operations +struct webgpu_gpu_profile_buf_pool { + std::vector free; + + std::mutex mutex; + + std::condition_variable cv; + + void init(wgpu::Device device, + int num_bufs, + size_t buf_size, + wgpu::BufferUsage dev_buf_usage, + wgpu::BufferUsage host_buf_usage) { + for (int i = 0; i < num_bufs; i++) { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); + // Create a query set for 2 timestamps + wgpu::QuerySetDescriptor ts_query_set_desc = {}; + + ts_query_set_desc.type = wgpu::QueryType::Timestamp; + ts_query_set_desc.count = 2; + wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); + + free.push_back({ host_buf, dev_buf, ts_query_set }); + } + } + + webgpu_gpu_profile_bufs alloc_bufs() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !free.empty(); }); + webgpu_gpu_profile_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + void free_bufs(std::vector bufs) { + std::lock_guard lock(mutex); + free.insert(free.end(), bufs.begin(), bufs.end()); + cv.notify_all(); + } + + void cleanup() { + std::lock_guard lock(mutex); + for (auto & bufs : free) { + bufs.host_buf.Destroy(); + bufs.dev_buf.Destroy(); + bufs.query_set.Destroy(); + } + free.clear(); + } +}; +#endif + +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; +}; + +struct webgpu_command { + wgpu::CommandBuffer commands; + webgpu_pool_bufs params_bufs; + std::optional set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_gpu_profile_bufs timestamp_query_bufs; + std::string pipeline_name; +#endif +}; + // All the base objects needed to run operations on a WebGPU 
device struct webgpu_context_struct { wgpu::Instance instance; @@ -125,63 +242,52 @@ struct webgpu_context_struct { uint32_t max_wg_size_x; std::recursive_mutex mutex; + std::atomic_uint inflight_threads = 0; webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; - wgpu::ComputePipeline memset_pipeline; - wgpu::ComputePipeline mul_mat_pipeline[30][2]; - wgpu::ComputePipeline set_rows_pipeline; - wgpu::ComputePipeline get_rows_pipeline[30]; - wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; - wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type - wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace - wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace - wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split - wgpu::ComputePipeline scale_pipeline[2]; // inplace + webgpu_pipeline memset_pipeline; + webgpu_pipeline mul_mat_pipeline[30][2]; + webgpu_pipeline set_rows_pipeline; + webgpu_pipeline get_rows_pipeline[30]; + webgpu_pipeline get_rows_f32_no_vec_pipeline; + webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type + webgpu_pipeline add_pipeline[2][2]; // type, inplace + webgpu_pipeline sub_pipeline[2][2]; // type, inplace + webgpu_pipeline mul_pipeline[2][2]; // type, inplace + webgpu_pipeline div_pipeline[2][2]; // type, inplace + webgpu_pipeline rms_norm_pipeline[2]; // inplace + webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace + webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split + webgpu_pipeline scale_pipeline[2]; // inplace wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace wgpu::ComputePipeline unary_pipeline[16][2][2]; -/* wgpu::ComputePipeline abs_pipeline[2][2]; // abs - wgpu::ComputePipeline sgn_pipeline[2][2]; // sgn - wgpu::ComputePipeline neg_pipeline[2][2]; // neg - wgpu::ComputePipeline step_pipeline[2][2]; // step - wgpu::ComputePipeline tanh_pipeline[2][2]; // tanh - wgpu::ComputePipeline elu_pipeline[2][2]; // elu - wgpu::ComputePipeline relu_pipeline[2][2]; // relu - wgpu::ComputePipeline sigmoid_pipeline[2][2]; // sigmoid - wgpu::ComputePipeline gelu_pipeline[2][2]; // gelu - wgpu::ComputePipeline gelu_quick_pipeline[2][2]; // gelu_quick - wgpu::ComputePipeline silu_pipeline[2][2]; // silu (a.k.a. 
swish) - wgpu::ComputePipeline hardswish_pipeline[2][2]; // hardswish - wgpu::ComputePipeline hardsigmoid_pipeline[2][2]; // hardsigmoid - wgpu::ComputePipeline exp_pipeline[2][2]; // exp - wgpu::ComputePipeline gelu_erf_pipeline[2][2]; // gelu_erf */ - size_t memset_bytes_per_thread; // Staging buffer for reading data from the GPU wgpu::Buffer get_tensor_staging_buf; - // Command buffers which need to be submitted - std::vector staged_command_bufs; - - // Parameter buffers associated with the staged command buffers - std::vector staged_param_bufs; - // Buffers associated with set_rows operations, used to store potential errors - std::vector staged_set_row_error_bufs; - - std::vector callback_futures; - #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; #endif + +#ifdef GGML_WEBGPU_CPU_PROFILE + // Profiling: labeled CPU time in ms (total) + std::unordered_map cpu_time_ms; + // Profiling: detailed CPU time in ms + std::unordered_map cpu_detail_ms; +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + // Profiling: pool of timestamp query buffers (one per operation) + webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; +#endif }; typedef std::shared_ptr webgpu_context; @@ -217,12 +323,10 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ static void ggml_webgpu_create_pipeline(wgpu::Device & device, - wgpu::ComputePipeline & pipeline, + webgpu_pipeline & pipeline, const char * shader_code, const char * label, const std::vector & constants = {}) { - WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()"); - wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code; @@ -240,7 +344,7 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline_desc.compute.constants = constants.data(); pipeline_desc.compute.constantCount = constants.size(); } - pipeline = device.CreateComputePipeline(&pipeline_desc); + pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } static void ggml_webgpu_create_buffer(wgpu::Device & device, @@ -248,8 +352,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, size_t size, wgpu::BufferUsage usage, const char * label) { - WEBGPU_LOG_DEBUG("ggml_webgpu_create_buffer()"); - wgpu::BufferDescriptor buffer_desc; buffer_desc.size = size; buffer_desc.usage = usage; @@ -265,85 +367,35 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ // Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { - - std::lock_guard lock(ctx->mutex); - if (ctx->callback_futures.empty()) { - // no existing callbacks, wait on queue submission - ctx->instance.WaitAny( - ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); - } else { - // WebGPU implementations may limit the number of futures that can be waited on at once, - // so wait in batches (64 is what Dawn supports). 
- for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) { - size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size()); - ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX); +static void ggml_backend_webgpu_wait(webgpu_context & ctx, + std::vector & futures, + bool block = true) { + // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, + // inflight_max may be 0, meaning that we must wait on all futures. + uint64_t timeout_ms = block ? UINT64_MAX : 0; + uint inflight_threads = ctx->inflight_threads; + uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); + while (futures.size() >= inflight_max && futures.size() > 0) { + ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); + futures.erase(futures.begin()); + } + size_t i = 0; + while (i < futures.size()) { + auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); + switch (waitStatus) { + case wgpu::WaitStatus::Success: + futures.erase(futures.begin() + i); + break; + case wgpu::WaitStatus::TimedOut: + i++; + break; + case wgpu::WaitStatus::Error: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + break; + default: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + break; } - ctx->callback_futures.clear(); - } -} - -static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { - - std::lock_guard lock(ctx->mutex); - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()"); - if (ctx->staged_command_bufs.empty()) { - // Nothing to submit - return; - } - ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data()); - - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. 
- if (ctx->staged_set_row_error_bufs.size() > 0) { - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (auto & error_bufs : ctx->staged_set_row_error_bufs) { - // Copy the error buffer to the host buffer - encoder.CopyBufferToBuffer(error_bufs.dev_buf, 0, error_bufs.host_buf, 0, error_bufs.host_buf.GetSize()); - } - wgpu::CommandBuffer commands = encoder.Finish(); - ctx->queue.Submit(1, &commands); - } - - ctx->staged_command_bufs.clear(); - std::vector staged_param_bufs = std::move(ctx->staged_param_bufs); - std::vector staged_set_row_error_bufs = std::move(ctx->staged_set_row_error_bufs); - - // Free the staged parameter buffers once the submission completes - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [ctx, staged_param_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); - } - // Free the staged buffers - ctx->param_buf_pool.free_bufs(staged_param_bufs); - }); - ctx->callback_futures.push_back({ p_f }); - - // Check for errrors in SET_ROWS operations - for (auto & error_bufs : staged_set_row_error_bufs) { - wgpu::Future f = error_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, error_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, error_bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); - } else { - const uint32_t * error_data = (const uint32_t *) error_bufs.host_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->set_rows_error_buf_pool.free_bufs({ error_bufs }); - } - }); - ctx->callback_futures.push_back({ f }); } } @@ -367,7 +419,6 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and // debug statements in the shader, and then call this function after encoding the commands and submitting them. 
static void ggml_backend_webgpu_debug(webgpu_context & ctx) { - ggml_backend_webgpu_submit_queue(ctx); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); @@ -384,14 +435,85 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { } #endif -static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & ctx, - wgpu::ComputePipeline & pipeline, - std::vector params, - std::vector bind_group_entries, - uint32_t wg_x, - const char * bind_group_label = nullptr, - bool submit_and_wait = false) { - +static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector commands) { + std::vector command_buffers; + std::vector params_bufs; + std::vector set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + std::vector> pipeline_name_and_ts_bufs; +#endif + + for (const auto & command : commands) { + command_buffers.push_back(command.commands); + params_bufs.push_back(command.params_bufs); + if (command.set_rows_error_bufs) { + set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); + } + } + ctx->queue.Submit(command_buffers.size(), command_buffers.data()); + + std::vector futures; + + wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); + } + // Free the staged buffers + ctx->param_buf_pool.free_bufs({ params_bufs }); + }); + futures.push_back({ p_f }); + + for (const auto & bufs : set_rows_error_bufs) { + wgpu::Future f = bufs.host_buf.MapAsync( + wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); + } else { + const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->set_rows_error_buf_pool.free_bufs({ bufs }); + } + }); + futures.push_back({ f }); + } + +#ifdef GGML_WEBGPU_GPU_PROFILE + for (const auto & command : commands) { + auto label = command.pipeline_name; + auto ts_bufs = command.timestamp_query_bufs; + + wgpu::Future f = ts_bufs.host_buf.MapAsync( + wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); + } else { + const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange(); + // WebGPU timestamps are in ns; convert to ms + double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; + ctx->shader_gpu_time_ms[label] += elapsed_ms; + // We can't unmap in here due to WebGPU reentrancy limitations. 
+ ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); + } + }); + futures.push_back({ f }); + } +#endif + return { futures }; +} + +static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx, + webgpu_pipeline & pipeline, + std::vector params, + std::vector bind_group_entries, + uint32_t wg_x, + std::optional set_rows_error_bufs = std::nullopt) { webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); @@ -409,44 +531,58 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & .size = params_bufs.dev_buf.GetSize() }); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipeline.GetBindGroupLayout(0); + bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0); bind_group_desc.entryCount = bind_group_entries.size(); bind_group_desc.entries = bind_group_entries.data(); - if (bind_group_label) { - bind_group_desc.label = bind_group_label; - } + bind_group_desc.label = pipeline.name.c_str(); wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); + +#ifdef GGML_WEBGPU_GPU_PROFILE + // --- Profiling: GPU timestamp queries --- + // Allocate a timestamp query buffer (2 timestamps: start/end) + webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); + if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + ts_bufs.host_buf.Unmap(); + } + + wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, + .beginningOfPassWriteIndex = 0, + .endOfPassWriteIndex = 1 }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); +#else wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipeline); +#endif + pass.SetPipeline(pipeline.pipeline); pass.SetBindGroup(0, bind_group); pass.DispatchWorkgroups(wg_x, 1, 1); pass.End(); - wgpu::CommandBuffer commands = encoder.Finish(); - if (submit_and_wait) { - // Submit and wait immediately - ctx->queue.Submit(1, &commands); - ctx->instance.WaitAny(ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", message.data); - } - ctx->param_buf_pool.free_bufs({ params_bufs }); - }), - UINT64_MAX); - } else { - // Lock the context mutex when pushing to the staging vectors. - std::lock_guard lock(ctx->mutex); - // Enqueue commands and only submit if we have enough staged commands - ctx->staged_command_bufs.push_back(commands); - ctx->staged_param_bufs.push_back(params_bufs); - if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - ggml_backend_webgpu_submit_queue(ctx); - } + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Resolve the query set into the device buffer + encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); + encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); +#endif + + // If there are SET_ROWS operations in this submission, copy their error buffers to the host. 
+ if (set_rows_error_bufs) { + encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, + set_rows_error_bufs->host_buf.GetSize()); } + + wgpu::CommandBuffer commands = encoder.Finish(); + webgpu_command result = {}; + result.commands = commands; + result.params_bufs = params_bufs; + result.set_rows_error_bufs = set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + result.timestamp_query_bufs = ts_bufs; + result.pipeline_name = pipeline.name; +#endif + return result; } static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, @@ -460,7 +596,10 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, }; size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread; uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true); + + webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipeline, params, entries, wg_x); + std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }) }; + ggml_backend_webgpu_wait(ctx, futures); } /** End WebGPU Actions */ @@ -476,8 +615,48 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); - // TODO: cleanup +#ifdef GGML_WEBGPU_CPU_PROFILE + std::cout << "\n[ggml_webgpu cpu profiling summary]\n"; + double total_cpu = 0.0; + for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + total_cpu += kv.second; + } + std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n"; + std::cout << "ggml_webgpu: cpu breakdown:\n"; + for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + } + if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) { + std::cout << "ggml_webgpu: cpu detailed breakdown:\n"; + } + for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) { + double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + } +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; + double total_gpu = 0.0; + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + total_gpu += kv.second; + } + std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; + std::cout << "\nggml_webgpu: gpu breakdown:\n"; + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + } +#endif + +#if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE) + std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? 
total_gpu / total_cpu : 0.0) << "\n"; +#endif + +#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE) GGML_UNUSED(ctx); +#endif } static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -510,10 +689,47 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); } -static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Logical shapes + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; + return ggml_backend_webgpu_build(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x); + +} + + +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty tensors. 
if (ggml_is_empty(src) || ggml_is_empty(idx)) { - return; + return std::nullopt; } webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); @@ -556,13 +772,13 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; - std::lock_guard lock(ctx->mutex); - ctx->staged_set_row_error_bufs.push_back(error_bufs); - - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->set_rows_pipeline, params, entries, wg_x, error_bufs); } -static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), @@ -597,14 +813,17 @@ static void ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size; - wgpu::ComputePipeline pipeline = ctx->get_rows_pipeline[src->type]; + webgpu_pipeline pipeline = ctx->get_rows_pipeline[src->type]; if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) { pipeline = ctx->get_rows_f32_no_vec_pipeline; } - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), @@ -641,44 +860,10 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x, - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); } -static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - uint32_t ne = (uint32_t) ggml_nelements(dst); - - std::vector params = { - ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), - (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), - (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shapes - 
(uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], - (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] - }; - - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; - - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x, - ggml_op_name(dst->op)); -} - -static void ggml_webgpu_unary_op( webgpu_context & ctx, +static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst, wgpu::ComputePipeline & pipeline, @@ -718,16 +903,15 @@ static void ggml_webgpu_unary_op( webgpu_context & ctx, size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); - + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } -static void ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst, - wgpu::ComputePipeline & pipeline, - bool inplace) { +static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + webgpu_pipeline & pipeline, + bool inplace) { std::vector params = { (uint32_t) ggml_nelements(dst), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -765,10 +949,10 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -800,15 +984,14 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src), - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src)); } -static void ggml_webgpu_rope(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int has_freq_factor = (src2 != nullptr); @@ -886,13 +1069,13 @@ static void ggml_webgpu_rope(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; - size_t max_wg_size = ctx->max_wg_size_x; - 
uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + webgpu_pipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { const int split = (src1 != nullptr); std::vector params = { @@ -939,13 +1122,13 @@ static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + webgpu_pipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -980,15 +1163,14 @@ static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tens size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x, - ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x); } -static void ggml_webgpu_soft_max(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int mask_type = (src1 != nullptr) ? 
src1->type : 2; // use 2 for no mask here const int has_sink = (src2 != nullptr); @@ -1053,18 +1235,14 @@ static void ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, - ggml_nrows(dst), ggml_op_name(dst->op)); + return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, + ggml_nrows(dst)); } - - - -// Returns true if node has enqueued work into the queue, false otherwise -static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { - +// Returns the encoded command, or std::nullopt if the operation is a no-op +static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { - return false; + return std::nullopt; } WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); @@ -1081,56 +1259,46 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_RESHAPE: - return false; + return std::nullopt; case GGML_OP_CPY: case GGML_OP_CONT: - ggml_webgpu_cpy(ctx, src0, node); - break; + return ggml_webgpu_cpy(ctx, src0, node); case GGML_OP_SET_ROWS: - ggml_webgpu_set_rows(ctx, src0, src1, node); - break; + return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: - ggml_webgpu_get_rows(ctx, src0, src1, node); - break; + return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: - ggml_webgpu_mul_mat(ctx, src0, src1, node); - break; + return ggml_webgpu_mul_mat(ctx, src0, src1, node); case GGML_OP_ADD: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); } case GGML_OP_SUB: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); } case GGML_OP_MUL: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); } case GGML_OP_DIV: { int inplace = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); - break; + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); } case GGML_OP_RMS_NORM: - ggml_webgpu_rms_norm(ctx, src0, node); - break; + return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: - ggml_webgpu_rope(ctx, src0, src1, src2, node); - break; + return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: - ggml_webgpu_glu(ctx, src0, src1, node); - break; + return ggml_webgpu_glu(ctx, src0, src1, node); case GGML_OP_SCALE: - ggml_webgpu_scale(ctx, src0, node); - break; + return ggml_webgpu_scale(ctx, src0, node); + case GGML_OP_SOFT_MAX: + return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: { const ggml_unary_op UNARY_OP = ggml_get_unary_op(node); @@ -1141,9 +1309,8 @@ static bool 
ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { } default: - return false; - } - return true; + return std::nullopt; + } } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { @@ -1153,13 +1320,35 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); webgpu_context ctx = backend_ctx->webgpu_ctx; + WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); + + ctx->inflight_threads++; + + std::vector commands; + std::vector futures; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_webgpu_encode_node(ctx, cgraph->nodes[i]); + if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + commands.push_back(*cmd); + } + // compute the batch size based on the number of inflight threads + uint inflight_threads = ctx->inflight_threads; + uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), + WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); + if (commands.size() >= batch_size) { + futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); + // Process events and check for completed submissions + ctx->instance.ProcessEvents(); + ggml_backend_webgpu_wait(ctx, futures, false); + commands.clear(); + } } - - ggml_backend_webgpu_submit_queue(ctx); - ggml_backend_webgpu_wait_on_submission(ctx); - + if (!commands.empty()) { + webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands); + futures.push_back(new_futures); + } + ggml_backend_webgpu_wait(ctx, futures); + ctx->inflight_threads--; + WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); return GGML_STATUS_SUCCESS; } @@ -1185,7 +1374,6 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* GGML Backend Buffer Interface */ static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_free_buffer()"); ggml_backend_webgpu_buffer_context * ctx = static_cast(buffer->context); ctx->buffer.Destroy(); } @@ -1206,6 +1394,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe return; } + WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); @@ -1216,6 +1406,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe // This is a trick to set all bytes of a u32 to the same 1 byte value. 
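// For example, value = 0xAB gives val32 = 0xAB * 0x01010101 = 0xABABABAB, i.e. the byte repeated across all four byte positions of the u32.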
uint32_t val32 = (uint32_t) value * 0x01010101; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); + WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx); } static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -1225,6 +1416,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, size_t size) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; @@ -1247,8 +1439,17 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, remaining_size); } else { // wait for WriteBuffer to complete - ggml_backend_webgpu_wait_on_submission(webgpu_ctx); + webgpu_ctx->instance.WaitAny( + webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", + std::string(message).c_str()); + } + }), + UINT64_MAX); } + WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx); } static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, @@ -1258,7 +1459,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t size) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - + WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; wgpu::Device device = webgpu_ctx->device; @@ -1298,12 +1499,15 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, // Copy the data from the mapped range to the output buffer std::memcpy(data, mapped_range, size); webgpu_ctx->get_tensor_staging_buf.Unmap(); + WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx); } static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); + WEBGPU_CPU_PROFILE_TOTAL_START(clear); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); + WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx); } static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { @@ -1707,160 +1911,6 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { constants); } -static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - - // ABS - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0], - wgsl_abs_f32, "abs_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0], - wgsl_abs_f16, "abs_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1], - 
wgsl_abs_in_place_f32, "abs_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1], - wgsl_abs_in_place_f16, "abs_in_place_f16", constants); - - // SGN - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0], - wgsl_sgn_f32, "sgn_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0], - wgsl_sgn_f16, "sgn_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1], - wgsl_sgn_in_place_f32, "sgn_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1], - wgsl_sgn_in_place_f16, "sgn_in_place_f16", constants); - - // NEG - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0], - wgsl_neg_f32, "neg_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0], - wgsl_neg_f16, "neg_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1], - wgsl_neg_in_place_f32, "neg_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1], - wgsl_neg_in_place_f16, "neg_in_place_f16", constants); - - // STEP - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0], - wgsl_step_f32, "step_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0], - wgsl_step_f16, "step_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1], - wgsl_step_in_place_f32, "step_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1], - wgsl_step_in_place_f16, "step_in_place_f16", constants); - - // TANH - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0], - wgsl_tanh_f32, "tanh_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0], - wgsl_tanh_f16, "tanh_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1], - wgsl_tanh_in_place_f32, "tanh_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1], - wgsl_tanh_in_place_f16, "tanh_in_place_f16", constants); - - // ELU - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0], - wgsl_elu_f32, "elu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0], - wgsl_elu_f16, "elu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1], - wgsl_elu_in_place_f32, "elu_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1], - wgsl_elu_in_place_f16, 
"elu_in_place_f16", constants); - - // RELU - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0], - wgsl_relu_f32, "relu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0], - wgsl_relu_f16, "relu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1], - wgsl_relu_in_place_f32, "relu_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1], - wgsl_relu_in_place_f16, "relu_in_place_f16", constants); - - // SIGMOID - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0], - wgsl_sigmoid_f32, "sigmoid_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0], - wgsl_sigmoid_f16, "sigmoid_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1], - wgsl_sigmoid_in_place_f32, "sigmoid_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1], - wgsl_sigmoid_in_place_f16, "sigmoid_in_place_f16", constants); - - // GELU - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0], - wgsl_gelu_f32, "gelu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0], - wgsl_gelu_f16, "gelu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1], - wgsl_gelu_in_place_f32, "gelu_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1], - wgsl_gelu_in_place_f16, "gelu_in_place_f16", constants); - - // GELU_QUICK - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0], - wgsl_gelu_quick_f32, "gelu_quick_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0], - wgsl_gelu_quick_f16, "gelu_quick_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1], - wgsl_gelu_quick_in_place_f32, "gelu_quick_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1], - wgsl_gelu_quick_in_place_f16, "gelu_quick_in_place_f16", constants); - - // SILU - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0], - wgsl_silu_f32, "silu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0], - wgsl_silu_f16, "silu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1], - wgsl_silu_in_place_f32, "silu_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1], - wgsl_silu_in_place_f16, "silu_in_place_f16", constants); - - // HARDSWISH - 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0], - wgsl_hardswish_f32, "hardswish_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0], - wgsl_hardswish_f16, "hardswish_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1], - wgsl_hardswish_in_place_f32, "hardswish_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1], - wgsl_hardswish_in_place_f16, "hardswish_in_place_f16", constants); - - // HARDSIGMOID - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0], - wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0], - wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1], - wgsl_hardsigmoid_in_place_f32, "hardsigmoid_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1], - wgsl_hardsigmoid_in_place_f16, "hardsigmoid_in_place_f16", constants); - - // EXP - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0], - wgsl_exp_f32, "exp_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0], - wgsl_exp_f16, "exp_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1], - wgsl_exp_in_place_f32, "exp_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1], - wgsl_exp_in_place_f16, "exp_in_place_f16", constants); - - // GELU_ERF - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0], - wgsl_gelu_erf_f32, "gelu_erf_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0], - wgsl_gelu_erf_f16, "gelu_erf_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1], - wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1], - wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants); -} - static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); @@ -2048,6 +2098,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SCALE: supports_op = op->type == GGML_TYPE_F32; break; + case GGML_OP_SOFT_MAX: + supports_op = op->type == GGML_TYPE_F32; + break; case GGML_OP_UNARY: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && (src1 ? (src1->type == op->type) : true); @@ -2074,8 +2127,6 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const << ", src0: " << (op->src[0] ? 
ggml_type_name(op->src[0]->type) : "null") << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")); } - - return supports_op; } @@ -2117,7 +2168,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); - + WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device); ggml_backend_webgpu_reg_context * reg_ctx = static_cast(reg->context); @@ -2145,7 +2196,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization }; - wgpu::DeviceDescriptor dev_desc; +#ifdef GGML_WEBGPU_GPU_PROFILE + required_features.push_back(wgpu::FeatureName::TimestampQuery); +#endif + + wgpu::DeviceDescriptor dev_desc; dev_desc.requiredLimits = &ctx->limits; dev_desc.requiredFeatures = required_features.data(); dev_desc.requiredFeatureCount = required_features.size(); @@ -2159,8 +2214,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t dev_desc.SetUncapturedErrorCallback( [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { GGML_UNUSED(device); - GGML_LOG_ERROR("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), - std::string(message).c_str()); + GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); }); ctx->instance.WaitAny(ctx->adapter.RequestDevice( &dev_desc, wgpu::CallbackMode::AllowSpontaneous, @@ -2182,6 +2237,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Initialize buffer pool for timestamp queries (profiling) + ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, + wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); +#endif + ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); @@ -2200,25 +2264,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_rope_pipeline(ctx); ggml_webgpu_init_glu_pipeline(ctx); ggml_webgpu_init_scale_pipeline(ctx); + ggml_webgpu_init_soft_max_pipeline(ctx); ggml_webgpu_init_unary_pipeline(ctx); - -/* ggml_webgpu_init_abs_pipeline(ctx); - ggml_webgpu_init_sgn_pipeline(ctx); - ggml_webgpu_init_neg_pipeline(ctx); - ggml_webgpu_init_step_pipeline(ctx); - ggml_webgpu_init_tanh_pipeline(ctx); - ggml_webgpu_init_elu_pipeline(ctx); - ggml_webgpu_init_relu_pipeline(ctx); - ggml_webgpu_init_sigmoid_pipeline(ctx); - ggml_webgpu_init_gelu_pipeline(ctx); - ggml_webgpu_init_gelu_quick_pipeline(ctx); - ggml_webgpu_init_silu_pipeline(ctx); - ggml_webgpu_init_hardswish_pipeline(ctx); - ggml_webgpu_init_hardsigmoid_pipeline(ctx); - ggml_webgpu_init_exp_pipeline(ctx); - ggml_webgpu_init_gelu_erf_pipeline(ctx); */ - #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, 
WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), @@ -2244,8 +2292,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t /* .reg = */ reg, /* .context = */ &device_ctx, }; - + WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx); return &device; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl index 25e2185de8..141db9b39d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -870,7 +870,7 @@ struct MulMatParams { @group(0) @binding(3) var params: MulMatParams; -@compute @workgroup_size(64) +@compute @workgroup_size(256) fn main(@builtin(global_invocation_id) global_id: vec3) { let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; if (global_id.x >= total) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index 4f72bb1c85..7d99fb70de 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -84,7 +84,9 @@ fn main(@builtin(workgroup_id) wid: vec3, let i2 = i / params.ne1; let i1 = i % params.ne1; let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; - let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + let elems = (params.ne0 + wg_size - 1) / wg_size; let elems = (params.ne0 + wg_size - 1) / wg_size; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl index 64ab576c08..c74dc4cc92 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl @@ -300,6 +300,7 @@ fn main(@builtin(workgroup_id) wid: vec3, workgroupBarrier(); } let row_max = scratch[0]; + workgroupBarrier(); var sum = 0.0f; col = lid.x; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d753c9a18a..2bce1375ba 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1143,10 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSIGMOID", "EXP", "GELU_ERF", + "XIELU", }; -static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); - +static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2652,6 +2652,29 @@ struct ggml_tensor * ggml_silu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); } +// ggml_xielu + +struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU); + ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n)); + ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p)); + ggml_set_op_params_f32(result, 3, beta); + ggml_set_op_params_f32(result, 4, eps); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + // ggml_silu_back struct ggml_tensor * ggml_silu_back( diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 88ea9f32f8..9c99b90faa 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -261,6 +261,7 @@ 
class Keys: class ClipVision: IMAGE_SIZE = "clip.vision.image_size" + PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size" PATCH_SIZE = "clip.vision.patch_size" EMBEDDING_LENGTH = "clip.vision.embedding_length" FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" @@ -297,6 +298,13 @@ class Keys: class Diffusion: SHIFT_LOGITS = "diffusion.shift_logits" + class xIELU: + ALPHA_P = "xielu.alpha_p" + ALPHA_N = "xielu.alpha_n" + BETA = "xielu.beta" + EPS = "xielu.eps" + + # # recommended mapping of model tensor names for storage in gguf # @@ -399,12 +407,14 @@ class MODEL_ARCH(IntEnum): SMOLLM3 = auto() GPT_OSS = auto() LFM2 = auto() + LFM2MOE = auto() DREAM = auto() SMALLTHINKER = auto() LLADA = auto() LLADA_MOE = auto() SEED_OSS = auto() GROVEMOE = auto() + APERTUS = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -740,12 +750,14 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.SMOLLM3: "smollm3", MODEL_ARCH.GPT_OSS: "gpt-oss", MODEL_ARCH.LFM2: "lfm2", + MODEL_ARCH.LFM2MOE: "lfm2moe", MODEL_ARCH.DREAM: "dream", MODEL_ARCH.SMALLTHINKER: "smallthinker", MODEL_ARCH.LLADA: "llada", MODEL_ARCH.LLADA_MOE: "llada-moe", MODEL_ARCH.SEED_OSS: "seed_oss", MODEL_ARCH.GROVEMOE: "grovemoe", + MODEL_ARCH.APERTUS: "apertus", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -2688,6 +2700,29 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.OUTPUT, ], + MODEL_ARCH.LFM2MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.SHORTCONV_CONV, + MODEL_TENSOR.SHORTCONV_INPROJ, + MODEL_TENSOR.SHORTCONV_OUTPROJ, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.ATTN_NORM, # operator_norm + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.SMALLTHINKER: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2706,6 +2741,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.APERTUS: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.LLADA_MOE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 3152a30d7b..dfe4bfd490 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1037,6 +1037,9 @@ class GGUFWriter: def add_vision_image_size(self, value: int) -> None: self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value) + def add_vision_preproc_image_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value) + def add_vision_image_mean(self, values: Sequence[float]) -> None: self.add_array(Keys.ClipVision.IMAGE_MEAN, values) @@ -1084,6 +1087,18 @@ class GGUFWriter: def add_audio_stack_factor(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value) + def add_xielu_alpha_p(self, values: Sequence[float]): 
+ self.add_array(Keys.xIELU.ALPHA_P, values) + + def add_xielu_alpha_n(self, values: Sequence[float]): + self.add_array(Keys.xIELU.ALPHA_N, values) + + def add_xielu_beta(self, values: Sequence[float]): + self.add_array(Keys.xIELU.BETA, values) + + def add_xielu_eps(self, values: Sequence[float]): + self.add_array(Keys.xIELU.EPS, values) + # diffusion models def add_diffusion_shift_logits(self, value: bool) -> None: diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index c533b55c01..3e9a2dd8f8 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -148,6 +148,7 @@ class TensorNameMap: "model.layers.{bid}.operator_norm", # lfm2 "model.transformer.blocks.{bid}.attn_norm", # llada "layers.{bid}.input_layernorm", # qwen3-embedding + "model.layers.{bid}.attention_layernorm" # apertus ), # Attention norm 2 @@ -325,6 +326,7 @@ class TensorNameMap: "model.layers.layers.{bid}.pre_mlp_norm", # plamo2 "model.transformer.blocks.{bid}.ff_norm", # llada "layers.{bid}.post_attention_layernorm", # qwen3-embedding + "model.layers.{bid}.feedforward_layernorm", # apertus ), # Post feed-forward norm @@ -356,6 +358,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.router", # openai-moe "model.layers.{bid}.mlp.gate.wg", # hunyuan "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker + "model.layers.{bid}.feed_forward.gate", # lfm2moe ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -365,6 +368,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_EXP_PROBS_B: ( "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1 "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe + "model.layers.{bid}.feed_forward.expert_bias", # lfm2moe ), # Feed-forward up @@ -547,6 +551,7 @@ class TensorNameMap: "transformer.layers.{bid}.attn.q_norm", # openelm "model.layers.layers.{bid}.mixer.q", # plamo2 "layers.{bid}.self_attn.q_norm", # qwen3-embedding + "model.layers.{bid}.attention.query_layernorm", # apertus ), MODEL_TENSOR.ATTN_K_NORM: ( @@ -560,6 +565,7 @@ class TensorNameMap: "transformer.layers.{bid}.attn.k_norm", # openelm "model.layers.layers.{bid}.mixer.k", # plamo2 "layers.{bid}.self_attn.k_norm", # qwen3-embedding + "model.layers.{bid}.attention.key_layernorm", # apertus ), MODEL_TENSOR.ROPE_FREQS: ( diff --git a/include/llama.h b/include/llama.h index 452d9ec5bf..a0a660bff8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -296,6 +296,7 @@ extern "C" { bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) + bool no_host; // bypass host buffer allowing extra buffers to be used }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -543,6 +544,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid (like Jamba, Granite, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); @@ -791,8 +795,12 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +// for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 +// work only with partial states, such as SWA KV cache or recurrent cache (e.g. 
Mamba) +#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 + typedef uint32_t llama_state_seq_flags; LLAMA_API size_t llama_state_seq_get_size_ext( diff --git a/models/templates/Apertus-8B-Instruct.jinja b/models/templates/Apertus-8B-Instruct.jinja new file mode 100644 index 0000000000..10826ff690 --- /dev/null +++ b/models/templates/Apertus-8B-Instruct.jinja @@ -0,0 +1,327 @@ +{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%} + {%- if param_spec.type == "array" -%} + {%- if param_spec['items'] -%} + {%- if param_spec['items']['type'] == "string" -%} + {{- "string[]" }} + {%- elif param_spec['items']['type'] == "number" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "integer" -%} + {{- "number[]" }} + {%- elif param_spec['items']['type'] == "boolean" -%} + {{- "boolean[]" }} + {%- else -%} + {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%} + {%- if inner_type == "object | object" or inner_type|length > 50 -%} + {{- "any[]" }} + {%- else -%} + {{- inner_type + "[]" }} + {%- endif -%} + {%- endif -%} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- else -%} + {{- "any[]" }} + {%- if param_spec.nullable -%} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%} + {#- Handle array of types like ["object", "object"] from Union[dict, list] #} + {%- if param_spec.type | length > 1 -%} + {{- param_spec.type | join(" | ") }} + {%- else -%} + {{- param_spec.type[0] }} + {%- endif -%} + {%- elif param_spec.oneOf -%} + {#- Handle oneOf schemas - check for complex unions and fallback to any #} + {%- set has_object_variants = false -%} + {%- for variant in param_spec.oneOf -%} + {%- if variant.type == "object" -%} + {%- set has_object_variants = true -%} + {%- endif -%} + {%- endfor -%} + {%- if has_object_variants and param_spec.oneOf|length > 1 -%} + {{- "any" }} + {%- else -%} + {%- for variant in param_spec.oneOf -%} + {{- render_typescript_type(variant, required_params) -}} + {%- if variant.description %} + {{- "// " + variant.description }} + {%- endif -%} + {%- if variant.default is defined %} + {{ "// default: " + variant.default|tojson }} + {%- endif -%} + {%- if not loop.last %} + {{- " | " }} + {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- elif param_spec.type == "string" -%} + {%- if param_spec.enum -%} + {{- '"' + param_spec.enum|join('" | "') + '"' -}} + {%- else -%} + {{- "string" }} + {%- if param_spec.nullable %} + {{- " | null" }} + {%- endif -%} + {%- endif -%} + {%- elif param_spec.type == "number" -%} + {{- "number" }} + {%- elif param_spec.type == "integer" -%} + {{- "number" }} + {%- elif param_spec.type == "boolean" -%} + {{- "boolean" }} + {%- elif param_spec.type == "object" -%} + {%- if param_spec.properties -%} + {{- "{\n" }} + {%- for prop_name, prop_spec in param_spec.properties.items() -%} + {{- prop_name -}} + {%- if prop_name not in (param_spec.required or []) -%} + {{- "?" 
}} + {%- endif -%} + {{- ": " }} + {{ render_typescript_type(prop_spec, param_spec.required or []) }} + {%- if not loop.last -%} + {{-", " }} + {%- endif -%} + {%- endfor -%} + {{- "}" }} + {%- else -%} + {{- "object" }} + {%- endif -%} + {%- else -%} + {{- "any" }} + {%- endif -%} +{%- endmacro -%} + +{%- macro render_tools(tools) -%} + {%- for tool in tools %} + {{- "// " + tool.description + "\n" }} + {{- "type "+ tool.name + " = " }} + {%- if tool.parameters and tool.parameters.properties %} + {{- "(_: {\n" }} + {%- for param_name, param_spec in tool.parameters.properties.items() %} + {%- if param_spec.description %} + {{- "// " + param_spec.description + "\n" }} + {%- endif %} + {{- param_name }} + {%- if param_name not in (tool.parameters.required or []) -%} + {{- "?" }} + {%- endif -%} + {{- ": " }} + {{- render_typescript_type(param_spec, tool.parameters.required or []) }} + {%- if param_spec.default is defined -%} + {%- if param_spec.enum %} + {{- ", // default: " + param_spec.default }} + {%- elif param_spec.oneOf %} + {{- "// default: " + param_spec.default }} + {%- else %} + {{- ", // default: " + param_spec.default|tojson }} + {%- endif -%} + {%- endif -%} + {%- if not loop.last %} + {{- ",\n" }} + {%- else %} + {{- "\n" }} + {%- endif -%} + {%- endfor %} + {{- "}) => any;" }} + {%- else -%} + {{- "() => any;" }} + {%- endif -%} + {%- if not loop.last -%} + {{- "\n" }} + {%- endif -%} + {%- endfor %} +{%- endmacro -%} + +{{ bos_token }} + +{%- set system_token = '<|system_start|>' -%} +{%- set end_system_token = '<|system_end|>' -%} +{%- set developer_token = '<|developer_start|>' -%} +{%- set end_developer_token = '<|developer_end|>' -%} +{%- set user_token = '<|user_start|>' -%} +{%- set end_user_token = '<|user_end|>' -%} +{%- set assistant_token = '<|assistant_start|>' -%} +{%- set end_assistant_token = '<|assistant_end|>' -%} +{%- set inner_token = '<|inner_prefix|>' -%} +{%- set outer_token = '<|inner_suffix|>' -%} +{%- set tool_calls_token = '<|tools_prefix|>' -%} +{%- set end_tool_calls_token = '<|tools_suffix|>' -%} + +{%- set ns = namespace(in_assistant=false, in_tool=false, in_inner=false, assistant_format=none) -%} + +{%- if messages and messages[0].role == 'system' -%} + {%- if "content" in messages[0] -%} + {%- if messages[0].content is string -%} + {{ system_token + messages[0].content + end_system_token }} + {%- elif messages[0].content is mapping and "text" in messages[0].content -%} + {{ system_token + messages[0].content.text + end_system_token }} + {%- else -%} + {{- raise_exception("Invalid system message") -}} + {%- endif -%} + {%- else -%} + {{- raise_exception("Invalid system message") -}} + {%- endif -%} + {%- set loop_messages = messages[1:] -%} +{%- else -%} + {{ system_token + 'You are Apertus, a helpful assistant created by the SwissAI initiative.\nKnowledge cutoff: 2024-04\nCurrent date: ' + strftime_now('%Y-%m-%d') + end_system_token }} + {%- set loop_messages = messages -%} +{%- endif -%} + +{{ developer_token + 'Deliberation: ' }} +{%- if enable_thinking is defined and enable_thinking -%} + {{ 'enabled\n' }} +{%- else -%} + {{ 'disabled\n' }} +{%- endif -%} +{%- if tools is defined and tools -%} + {{ 'Tool Capabilities:\n' + render_tools(tools) }} +{%- else -%} + {{ 'Tool Capabilities: disabled' }} +{%- endif -%} +{{ end_developer_token }} + +{%- for message in loop_messages -%} + {%- if message.role == 'user' -%} + {%- set ns.in_inner = false -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- if 
ns.in_assistant -%} + {{ end_assistant_token }} + {%- set ns.in_assistant = false -%} + {%- endif -%} + {%- if "content" in message -%} + {{ user_token }} + {%- if message.content is string -%} + {{ message.content }} + {%- elif message.content is mapping and "parts" in message.content -%} + {%- set parts = message.content.parts -%} + {%- for part in parts -%} + {%- if part.type == "text" -%} + {{ part.text }} + {%- else -%} + {{- raise_exception("Invalid user part: " + part.type) -}} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- raise_exception("Invalid user message: " + message.role) -}} + {%- endif -%} + {{ end_user_token }} + {%- endif -%} + {%- elif message.role == 'assistant' -%} + {%- if not ns.in_assistant -%} + {{ assistant_token }} + {%- set ns.in_assistant = true -%} + {%- endif -%} + {%- if "content" in message and message.content is not none -%} + {%- if message.content is string and (ns.assistant_format is none or ns.assistant_format == "string") -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- set ns.assistant_format = "string" -%} + {{ message.content }} + {%- elif message.content is mapping and "blocks" in message.content and (ns.assistant_format is none or ns.assistant_format == "mapping") -%} + {%- set ns.assistant_format = "mapping" -%} + {%- set blocks = message.content.blocks -%} + {%- for block in blocks -%} + {%- if block.type == 'thoughts' -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- if not ns.in_inner -%} + {%- set ns.in_inner = true -%} + {{ inner_token }} + {%- endif -%} + {{ block.text }} + {%- elif block.type == 'tool_calls' -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- if ns.in_inner and not loop.first and block.calls|length == 1 and block.calls[0].name == 'display_answers' -%} + {%- set ns.in_inner = false -%} + {{ outer_token }} + {%- endif -%} + {{ tool_calls_token + '[' }} + {%- for tool_call in block.calls -%} + {{- '{"' + tool_call.name + '": ' + tool_call.arguments + '}' }} + {%- if not loop.last -%} + {{- ", " }} + {%- endif -%} + {%- endfor -%} + {{ ']' + end_tool_calls_token }} + {%- elif block.type == 'tool_outputs' -%} + {%- if ns.in_tool -%} + {{- raise_exception("Cannot have both tool outputs as separate messages and tool outputs as blocks") -}} + {%- endif -%} + {{ '[' }} + {%- for tool_output in block.outputs -%} + {{- tool_output.output }} + {%- if not loop.last -%} + {{- ", " }} + {%- endif -%} + {%- endfor -%} + {{- ']' }} + {%- elif block.type == 'response' -%} + {%- if ns.in_tool -%} + {{ ']' }} + {%- set ns.in_tool = false -%} + {%- endif -%} + {%- if (not loop.first and ns.in_inner) or (ns.in_assistant and ns.in_inner) -%} + {%- set ns.in_inner = false -%} + {{ outer_token }} + {%- endif -%} + {{ block.text }} + {%- else -%} + {{- raise_exception("Invalid assistant block type: " + block.type) -}} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- raise_exception("Invalid assistant content '" + message.content + "', expected " + ns.assistant_format) -}} + {%- endif -%} + {%- elif "tool_calls" not in message -%} + {{- raise_exception("Invalid assistant message " + message) -}} + {%- endif -%} + {%- if "tool_calls" in message and message.tool_calls -%} + {{ tool_calls_token + '[' }} + {%- for tool_call in message.tool_calls -%} + {%- if tool_call.type == 'function' -%} + {%- set function = tool_call.function -%} + {{- '{"' + function.name + '": ' + function.arguments + '}' }} + {%- if 
not loop.last -%} + {{- ", " }} + {%- endif -%} + {%- else -%} + {{- raise_exception("Invalid tool call type: " + tool_call.type) -}} + {%- endif -%} + {%- endfor -%} + {{ ']' + end_tool_calls_token }} + {%- endif -%} + {%- elif message.role == 'tool' -%} + {%- if not ns.in_assistant -%} + {{- raise_exception("Tool message outside of assistant") -}} + {%- endif -%} + {%- if not ns.in_tool -%} + {{ '[' }} + {%- set ns.in_tool = true -%} + {%- else -%} + {{ ", "}} + {%- endif -%} + {{ message.content }} + {%- else -%} + {{- raise_exception("Invalid message role") -}} + {%- endif -%} +{%- endfor -%} +{%- if ns.in_tool -%} + {{ ']' }} +{%- endif -%} +{%- if add_generation_prompt -%} + {{ assistant_token }} +{%- endif -%} \ No newline at end of file diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4e8d54c419..45f0d0e2cb 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -93,12 +93,14 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, + { LLM_ARCH_LFM2MOE, "lfm2moe" }, { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_LLADA, "llada" }, { LLM_ARCH_LLADA_MOE, "llada-moe" }, { LLM_ARCH_SEED_OSS, "seed_oss" }, { LLM_ARCH_GROVEMOE, "grovemoe" }, + { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -256,6 +258,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" }, { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" }, + { LLM_KV_XIELU_ALPHA_N, "xielu.alpha_n" }, + { LLM_KV_XIELU_ALPHA_P, "xielu.alpha_p" }, + { LLM_KV_XIELU_BETA, "xielu.beta" }, + { LLM_KV_XIELU_EPS, "xielu.eps" }, + // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, @@ -2098,6 +2105,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, } }, + { + LLM_ARCH_LFM2MOE, + { + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, + { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, + { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + } + }, { LLM_ARCH_SMALLTHINKER, { @@ -2119,6 +2152,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } }, }, + { + LLM_ARCH_APERTUS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + 
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_DREAM, { @@ -2468,6 +2520,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: return true; default: diff --git a/src/llama-arch.h b/src/llama-arch.h index b5c6f3d76a..507fe5f379 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -97,12 +97,14 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, + LLM_ARCH_LFM2MOE, LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, LLM_ARCH_LLADA, LLM_ARCH_LLADA_MOE, LLM_ARCH_SEED_OSS, LLM_ARCH_GROVEMOE, + LLM_ARCH_APERTUS, LLM_ARCH_UNKNOWN, }; @@ -260,6 +262,11 @@ enum llm_kv { LLM_KV_SHORTCONV_L_CACHE, + LLM_KV_XIELU_ALPHA_N, + LLM_KV_XIELU_ALPHA_P, + LLM_KV_XIELU_BETA, + LLM_KV_XIELU_EPS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 66e6c6a38f..956c4e085e 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -590,7 +590,7 @@ int32_t llm_chat_apply_template( ss << message->content << "<|end_of_text|>\n"; } if (add_ass) { - ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + ss << "<|start_of_role|>assistant<|end_of_role|>"; } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 0fe4b56942..f29b23eeff 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -42,7 +42,7 @@ struct llama_hparams { uint32_t n_embd; uint32_t n_embd_features = 0; uint32_t n_layer; - int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head @@ -169,6 +169,12 @@ struct llama_hparams { uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; + // xIELU + std::array xielu_alpha_n; + std::array xielu_alpha_p; + std::array xielu_beta; + std::array xielu_eps; + // needed by encoder-decoder models (e.g. 
T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index 827302e6d2..facba1d004 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -220,7 +220,7 @@ bool llama_kv_cache_iswa::get_can_shift() const { } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { - if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { kv_base->state_write(io, seq_id, flags); } @@ -228,7 +228,7 @@ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id } void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { - if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { kv_base->state_read(io, seq_id, flags); } diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index abf652483c..dfb8439e01 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -175,17 +177,17 @@ std::map llama_memory_hybrid::memory_breakdo } void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { - GGML_UNUSED(flags); - - mem_attn->state_write(io, seq_id); - mem_recr->state_write(io, seq_id); + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_write(io, seq_id, flags); + } + mem_recr->state_write(io, seq_id, flags); } void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { - GGML_UNUSED(flags); - - mem_attn->state_read(io, seq_id); - mem_recr->state_read(io, seq_id); + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_read(io, seq_id, flags); + } + mem_recr->state_read(io, seq_id, flags); } llama_kv_cache * llama_memory_hybrid::get_mem_attn() const { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 44645fcdd2..9402d9cb8d 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) { } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (tail_id >= 0) { const auto & cell = cells[tail_id]; // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); return false; } // invalidate tails which will be cleared @@ -167,6 +169,7 @@ bool 
llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } else { // seq_id is negative, then the range should include everything or nothing if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n"); return false; } } @@ -379,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 8182a9adf5..aa3a65f87a 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -465,6 +465,8 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2470f87850..03c2f49d78 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -114,6 +114,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; @@ -310,7 +311,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -331,11 +332,13 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // generally, this will be done using the first device in the list // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); - if (buft) { - buft_list.emplace_back(dev, buft); - break; + if (!no_host) { + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } } } @@ -512,9 +515,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); + std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); + 
std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -1084,7 +1091,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; default: type = LLM_TYPE_UNKNOWN; - } + } + + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; case LLM_ARCH_GPT2: { @@ -1985,14 +1996,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { for (uint32_t il = 0; il < hparams.n_layer; ++il) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } + hparams.n_layer_dense_lead = hparams.n_layer; switch (hparams.n_ff()) { case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; case 8192: type = LLM_TYPE_1_2B; break; case 10752: type = LLM_TYPE_2_6B; break; - default: type = LLM_TYPE_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_LFM2MOE: + { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + + type = LLM_TYPE_8B_A1B; + } break; case LLM_ARCH_SMALLTHINKER: { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -2029,6 +2055,19 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_APERTUS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2062,7 +2101,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? 
"true" : "false"); // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -3392,17 +3431,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PLAMO2: { + // mamba parameters const uint32_t d_conv = hparams.ssm_d_conv; const uint32_t d_state = hparams.ssm_d_state; const uint32_t num_heads = hparams.ssm_dt_rank; const uint32_t intermediate_size = hparams.ssm_d_inner; - const uint32_t head_dim = intermediate_size / num_heads; - const uint32_t qk_dim = head_dim; - const uint32_t v_dim = head_dim; - const int64_t num_attention_heads = hparams.n_head(); - const int64_t q_num_heads = num_attention_heads; const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + // attention parameters + const uint32_t qk_dim = hparams.n_embd_head_k; + const uint32_t v_dim = hparams.n_embd_head_v; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -3436,6 +3475,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); } else { + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t q_num_heads = num_attention_heads; const int64_t num_key_value_heads = hparams.n_head_kv(i); const int64_t k_num_heads = num_key_value_heads; const int64_t v_num_heads = num_key_value_heads; @@ -3444,8 +3485,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t v_proj_dim = v_num_heads * v_dim; layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); } @@ -4825,11 +4866,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = 
create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); } } } @@ -5787,6 +5830,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); @@ -5798,11 +5842,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - // ffn is same for transformer and conv layers + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -5907,6 +5963,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); } } break; + case LLM_ARCH_APERTUS: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // Q and K layernorms for Apertus + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6241,7 +6339,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } - if (arch == LLM_ARCH_SMALLTHINKER) { + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } @@ -7776,6 +7874,8 @@ struct llm_build_bert : public llm_graph_context { } if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -7785,6 +7885,8 @@ struct llm_build_bert : public llm_graph_context { } if (model.layers[il].attn_k_norm) { + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, @@ -8167,6 +8269,9 @@ struct llm_build_mpt : public llm_graph_context { // Q/K Layernorm if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -17609,6 +17714,7 @@ private: const int64_t n_embd_head_q = hparams.n_embd_head_k; const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head = hparams.n_head(il); int32_t n_head_kv = hparams.n_head_kv(il); const int64_t q_offset = 0; @@ -18525,6 +18631,8 @@ struct llm_build_lfm2 : public 
llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const bool is_moe_layer = il >= static_cast(hparams.n_layer_dense_lead); + auto * prev_cur = cur; cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); @@ -18539,7 +18647,16 @@ struct llm_build_lfm2 : public llm_graph_context { } cur = ggml_add(ctx0, prev_cur, cur); - cur = ggml_add(ctx0, cur, build_feed_forward(cur, il)); + + auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(ffn_norm_out, "model.layers.{}.ffn_norm", il); + + ggml_tensor * ffn_out = is_moe_layer ? + build_moe_feed_forward(ffn_norm_out, il) : + build_dense_feed_forward(ffn_norm_out, il); + cb(ffn_norm_out, "model.layers.{}.ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); } cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); @@ -18554,23 +18671,32 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - ggml_tensor * build_feed_forward(ggml_tensor * cur, - int il) const { - cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "model.layers.{}.ffn_norm", il); + ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, + int il) const { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + static_cast(hparams.expert_gating_func), + il); + } + ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, + int il) const { GGML_ASSERT(!model.layers[il].ffn_up_b); GGML_ASSERT(!model.layers[il].ffn_gate_b); GGML_ASSERT(!model.layers[il].ffn_down_b); - cur = build_ffn(cur, + return build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "model.layers.{}.feed_forward.w2", il); - - return cur; } ggml_tensor * build_attn_block(ggml_tensor * cur, @@ -19090,6 +19216,141 @@ struct llm_build_grovemoe : public llm_graph_context { } }; +struct llm_build_apertus : public llm_graph_context { + llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + cb(Vcur, "Vcur_pos", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network with xIELU activation + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // Up projection + ggml_tensor * up = build_lora_mm(model.layers[il].ffn_up, cur); + cb(up, "ffn_up", il); + + float alpha_n_val = hparams.xielu_alpha_n[il]; + float alpha_p_val = hparams.xielu_alpha_p[il]; + float beta_val = hparams.xielu_beta[il]; + float eps_val = hparams.xielu_eps[il]; + + // Apply xIELU activation + ggml_tensor * activated = ggml_xielu(ctx0, up, alpha_n_val, alpha_p_val, beta_val, eps_val); + cb(activated, "ffn_xielu", il); + + // Down projection + cur = build_lora_mm(model.layers[il].ffn_down, activated); + cb(cur, "ffn_down", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -19605,6 +19866,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { llm = std::make_unique(*this, 
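Note on the activation used in the Apertus feed-forward above: the per-layer xielu_alpha_n / xielu_alpha_p / xielu_beta / xielu_eps values loaded from the GGUF feed straight into ggml_xielu(), replacing the usual gated-SiLU pair with a single up-projection, xIELU, down-projection path. As a rough point of reference, a scalar sketch of the piecewise xIELU form (an assumption based on the published definition, not a copy of the ggml kernel, which works on whole tensors and may parameterize alpha/eps differently):

    #include <algorithm>
    #include <cmath>

    // Illustrative only: assumed xIELU form
    //   x > 0 :  alpha_p * x^2 + beta * x
    //   x <= 0:  alpha_n * (expm1(x) - x) + beta * x
    static float xielu_ref(float x, float alpha_n, float alpha_p, float beta, float eps) {
        if (x > 0.0f) {
            return alpha_p * x * x + beta * x;
        }
        // expm1 keeps value and slope continuous at x == 0; eps is assumed here
        // to act as a small clamp that keeps the exponential numerically tame.
        return alpha_n * (std::expm1(std::min(x, eps)) - x) + beta * x;
    }

The quadratic positive branch and the expm1-based negative branch meet with matching value and slope at zero, which is what lets a plain up/activation/down MLP stand in for the usual gate+up pair in this architecture.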
params); } break; @@ -19620,6 +19882,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_APERTUS: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -19651,6 +19917,7 @@ llama_model_params llama_model_default_params() { /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, + /*.no_host =*/ false, }; return result; @@ -19822,10 +20089,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: case LLM_ARCH_SEED_OSS: case LLM_ARCH_GROVEMOE: + case LLM_ARCH_APERTUS: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -19936,6 +20205,10 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } +bool llama_model_is_hybrid(const llama_model * model) { + return llm_arch_is_hybrid(model->arch); +} + bool llama_model_is_diffusion(const llama_model * model) { return llm_arch_is_diffusion(model->arch); } diff --git a/src/llama-model.h b/src/llama-model.h index d73ce96932..20b59d952b 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -107,6 +107,7 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_106B_A12B, // GLM-4.5-Air @@ -380,6 +381,12 @@ struct llama_layer { // openai-moe struct ggml_tensor * attn_sinks = nullptr; + // xIELU activation parameters for Apertus + struct ggml_tensor * ffn_act_alpha_n = nullptr; + struct ggml_tensor * ffn_act_alpha_p = nullptr; + struct ggml_tensor * ffn_act_beta = nullptr; + struct ggml_tensor * ffn_act_eps = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index da938af03b..f965752a84 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -347,6 +347,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: case LLAMA_VOCAB_PRE_TYPE_TRILLION: + case LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -1961,6 +1962,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "trillion") { pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-docling") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING; + clean_spaces = false; } else if ( tokenizer_pre == "bailingmoe" || tokenizer_pre == "llada-moe") { diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 0d2f28c36c..5e468675e4 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -8,46 +8,47 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - 
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, }; struct LLM_KV; diff --git a/tests/test-alloc.cpp b/tests/test-alloc.cpp index 2eb7724731..95e09c97b0 100644 --- a/tests/test-alloc.cpp +++ b/tests/test-alloc.cpp @@ -548,6 +548,41 @@ static void test_buffer_size_zero() { GGML_ASSERT(backend_b.context->allocated_total() == 0); } +// Test re-using gallocr for a different graph. The new graph has the same +// total size, but one of the chunks is larger, so reallocation is required. 
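A bit of arithmetic behind the assertions in the test that follows: both graphs total 40 bytes (24 + 16 in the first, 20 + 20 in the second, with the add result reusing input memory, as the allocated_total() == 40 checks imply). Assuming the first argument to dummy_backend_init caps each buffer chunk at 32 bytes, the first graph ends up with a 24-byte and a 16-byte chunk; the second graph's 20-byte tensors no longer fit the 16-byte chunk, so ggml_gallocr_alloc_graph must reallocate rather than reuse the existing chunks, which is exactly the path being exercised.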
+static void test_reallocation() { + dummy_backend backend = dummy_backend_init(32, /*align*/ 4); + ggml_gallocr_ptr galloc; + { + auto [ctx, graph, ctx_ptr] = make_context(); + ggml_tensor * x[4]; + x[0] = make_input_with_size(ctx, 24); + x[1] = make_input_with_size(ctx, 16); + x[2] = ggml_view_1d(ctx, x[0], 4, 0); + x[3] = ggml_add(ctx, x[2], x[1]); + assign_names(ctx); + + galloc = allocate_graph(graph, x[3], &backend.buffer_type); + check_all_allocated(graph); + GGML_ASSERT(backend.context->allocated_total() == 40); + } + { + auto [ctx, graph, ctx_ptr] = make_context(); + ggml_tensor * x[3]; + x[0] = make_input_with_size(ctx, 20); + x[1] = make_input_with_size(ctx, 20); + x[2] = ggml_add(ctx, x[0], x[1]); + assign_names(ctx); + ggml_set_output(x[2]); + ggml_build_forward_expand(graph, x[2]); + + bool result = ggml_gallocr_alloc_graph(galloc.get(), graph); + GGML_ASSERT(result); + check_all_allocated(graph); + GGML_ASSERT(backend.context->allocated_total() == 40); + } +} + static void run(const char * name, void (*f)()) { printf("%s ", name); fflush(stdout); @@ -568,5 +603,6 @@ int main() { run("test_prefer_already_allocated_memory", test_prefer_already_allocated_memory); run("test_multiple_buffer_types", test_multiple_buffer_types); run("test_buffer_size_zero", test_buffer_size_zero); + run("test_reallocation", test_reallocation); return 0; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6b751db114..2fa16b497a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -131,6 +131,50 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } } +// generate an F16 mask where certain blocks are randomly masked with -INF value +static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { + GGML_ASSERT(tensor->type == GGML_TYPE_F16); + + GGML_TENSOR_LOCALS( int32_t, ne, tensor, ne); + + std::vector data_f32(ne0*ne1*ne2*ne3); + std::vector data_f16(ne0*ne1*ne2*ne3); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(min, max); + + for (size_t i = 0; i < data_f32.size(); i++) { + data_f32[i] = dis(gen); + } + + // block size + const int blck0 = 128; + const int blck1 = 64; + + // number of INF blocks + const int n_inf_blocks = 0.1*(ne0*ne1*ne2*ne3)/(blck0*blck1); + + for (int b = 0; b < n_inf_blocks; b++) { + const int p3 = (rd() % ne3); + const int p2 = (rd() % ne2); + const int p1 = (rd() % ne1); + const int p0 = (rd() % ne0); + + for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) { + const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0; + + for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) { + data_f32[idx + i0] = -INFINITY; + } + } + } + + ggml_fp32_to_fp16_row(data_f32.data(), data_f16.data(), ne0*ne1*ne2*ne3); + + ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t)); +} + static std::vector tensor_to_float(const ggml_tensor * t) { std::vector tv; tv.reserve(ggml_nelements(t)); @@ -3796,7 +3840,7 @@ struct test_soft_max : public test_case { if (inplace) { out = ggml_soft_max_ext_inplace(ctx, a, mask, scale, max_bias); } else { - out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias); + out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias); } ggml_soft_max_add_sinks(out, sinks); ggml_set_name(out, "out"); @@ -5111,6 +5155,8 @@ struct test_flash_attn_ext : public test_case { if (strcmp(t->name, "s") == 0) { // make the sink values more noticable in order to trigger a test failure when the 
implementation is wrong init_tensor_uniform(t, -10.0f, 10.0f); + } else if (strcmp(t->name, "m") == 0) { + init_tensor_kq_mask(t); } else { init_tensor_uniform(t); } @@ -6727,7 +6773,8 @@ static std::vector> make_test_cases_eval() { if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes for (int nr2 : { 1, 4, 16 }) { if (nr2 == 16 && hsk != 128) continue; - for (int kv : { 512, 1024, }) { + //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) { + for (int kv : { 113, 512, 1024, }) { if (nr2 != 1 && kv != 512) continue; for (int nb : { 1, 3, 32, 35, }) { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { diff --git a/tests/test-barrier.cpp b/tests/test-barrier.cpp index d85bf912b2..04c27761dc 100644 --- a/tests/test-barrier.cpp +++ b/tests/test-barrier.cpp @@ -1,6 +1,5 @@ #include "ggml.h" #include "ggml-cpu.h" -#include "ggml-backend.h" #include #include @@ -8,12 +7,13 @@ #include #include #include +#include #define MAX_NARGS 2 int main(int argc, char *argv[]) { - int n_threads = 4; + int n_threads = std::max(1, std::min(4, (int) std::thread::hardware_concurrency())); int n_rounds = 100; if (argc > 1) { diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index b863367db6..a5382ae3a3 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -214,7 +214,7 @@ int main(void) { { /* .name= */ "ibm-granite/granite-3.0-8b-instruct", /* .template_str= */ "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}", - /* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n", + /* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another 
question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>", /* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>", }, { diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index ce0f4b0a2a..52e23b5ac6 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -411,6 +411,7 @@ const common_chat_msg message_assist_thoughts_unparsed_md = simple_assis const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}"); const common_chat_msg message_assist_thoughts_unparsed_r7b = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?"); +const common_chat_msg message_assist_thoughts_unparsed_magistral = simple_assist_msg("[THINK]raisonnement[/THINK]Rรฉponse"); const common_chat_msg message_assist_thoughts = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"); const common_chat_msg message_assist_thoughts_unopened_unparsed = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); const common_chat_msg message_assist_thoughts_no_content = simple_assist_msg("", "I'm\nthinking"); @@ -745,6 +746,17 @@ static void test_template_output_parsers() { tmpls.get(), end_tokens, message_assist_call_id, tools, "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]"); } + { + assert_msg_equals( + simple_assist_msg("Rรฉponse", "raisonnement"), + common_chat_parse( + message_assist_thoughts_unparsed_magistral.content, + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_MAGISTRAL, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + })); + } { auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja"); std::vector end_tokens{ "<|im_end|>" }; @@ -2054,6 +2066,79 @@ static void test_template_output_parsers() { /* .parse_tool_calls = */ true, })); } + { + auto tmpls = read_templates("models/templates/Apertus-8B-Instruct.jinja"); + std::vector end_tokens{ "<|assistant_end|>" }; + + assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); + assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_APERTUS})); + + // Test parsing content with thinking + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "<|inner_prefix|>I'm\nthinking<|inner_suffix|>Hello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_APERTUS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + + // Test parsing tool calls + assert_msg_equals(message_assist_call, + common_chat_parse( + "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_APERTUS})); + + // Test parsing tool calls with thinking + assert_msg_equals(message_assist_call_thoughts, + common_chat_parse( + 
"<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_APERTUS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + })); + + // Test tool calls with extra content + assert_msg_equals(message_assist_call_content, + common_chat_parse( + "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_APERTUS} + )); + + // Test tool calls with extra content AND thinking + assert_msg_equals(message_assist_call_thoughts_content, + common_chat_parse( + "<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_APERTUS, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK + })); + + // Test template generation for regular content + test_templates(tmpls.get(), end_tokens, message_assist, tools, + "Hello, world!\nWhat's up?", + /* expect_grammar_triggered= */ false); + + // Test template generation for tool calls + test_templates(tmpls.get(), end_tokens, message_assist_call, tools, + "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", + /* expect_grammar_triggered= */ true + ); + + assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get())); + } + } static void test_msg_diffs_compute() { diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 275ba367c0..0de07b9811 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -168,7 +168,7 @@ static std::vector parse_devices_arg(const std::string & val return devices; } -static std::vector register_rpc_device_list(const std::string & servers) { +static void register_rpc_server_list(const std::string & servers) { auto rpc_servers = string_split(servers, ','); if (rpc_servers.empty()) { throw std::invalid_argument("no RPC servers specified"); @@ -179,36 +179,15 @@ static std::vector register_rpc_device_list(const std::strin throw std::invalid_argument("failed to find RPC backend"); } - using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint); - auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); - if (!ggml_backend_rpc_add_device_fn) { - throw std::invalid_argument("failed to find RPC device add function"); + using add_rpc_server_fn = ggml_backend_reg_t (*)(const char * endpoint); + auto * ggml_backend_rpc_add_server_fn = (add_rpc_server_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server"); + if (!ggml_backend_rpc_add_server_fn) { + throw std::invalid_argument("failed to find RPC add server function"); } - - static std::unordered_set registered; - std::vector devices; for (const auto & server : rpc_servers) { - ggml_backend_dev_t dev = nullptr; - - std::string name = string_format("RPC[%s]", server.c_str()); - - if (registered.find(server) != registered.end()) { - dev = ggml_backend_dev_by_name(name.c_str()); - } - - if (!dev) { - dev = ggml_backend_rpc_add_device_fn(server.c_str()); - if (!dev) { - throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str())); - } - ggml_backend_device_register(dev); - registered.insert(server); - } - - devices.push_back(dev); + auto reg = 
ggml_backend_rpc_add_server_fn(server.c_str()); + ggml_backend_register(reg); } - - return devices; } static std::string devices_to_string(const std::vector & devices) { @@ -357,6 +336,7 @@ struct cmd_params { std::vector use_mmap; std::vector embeddings; std::vector no_op_offload; + std::vector no_host; ggml_numa_strategy numa; int reps; ggml_sched_priority prio; @@ -394,6 +374,7 @@ static const cmd_params cmd_params_defaults = { /* use_mmap */ { true }, /* embeddings */ { false }, /* no_op_offload */ { false }, + /* no_host */ { false }, /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, /* prio */ GGML_SCHED_PRIO_NORMAL, @@ -474,6 +455,8 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -ot --override-tensor =;...\n"); printf(" (default: disabled)\n"); printf(" -nopo, --no-op-offload <0|1> (default: 0)\n"); + printf(" --no-host <0|1> (default: %s)\n", + join(cmd_params_defaults.no_host, ",").c_str()); printf("\n"); printf( "Multiple values can be given for each parameter by separating them with ','\n" @@ -714,7 +697,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } try { - register_rpc_device_list(argv[i]); + register_rpc_server_list(argv[i]); } catch (const std::exception & e) { fprintf(stderr, "error: %s\n", e.what()); invalid_param = true; @@ -803,6 +786,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.no_op_offload.insert(params.no_op_offload.end(), p.begin(), p.end()); + } else if (arg == "--no-host") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = string_split(argv[i], split_delim); + params.no_host.insert(params.no_host.end(), p.begin(), p.end()); } else if (arg == "-ts" || arg == "--tensor-split") { if (++i >= argc) { invalid_param = true; @@ -1024,6 +1014,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.no_op_offload.empty()) { params.no_op_offload = cmd_params_defaults.no_op_offload; } + if (params.no_host.empty()) { + params.no_host = cmd_params_defaults.no_host; + } if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } @@ -1065,6 +1058,7 @@ struct cmd_params_instance { bool use_mmap; bool embeddings; bool no_op_offload; + bool no_host; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -1077,6 +1071,7 @@ struct cmd_params_instance { mparams.main_gpu = main_gpu; mparams.tensor_split = tensor_split.data(); mparams.use_mmap = use_mmap; + mparams.no_host = no_host; if (n_cpu_moe <= 0) { if (tensor_buft_overrides.empty()) { @@ -1122,6 +1117,7 @@ struct cmd_params_instance { split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap && tensor_split == other.tensor_split && devices == other.devices && + no_host == other.no_host && vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides); } @@ -1157,6 +1153,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & ts : params.tensor_split) for (const auto & ot : params.tensor_buft_overrides) for (const auto & mmp : params.use_mmap) + for (const auto & noh : params.no_host) for (const auto & embd : params.embeddings) for (const auto & nopo : params.no_op_offload) for (const auto & nb : params.n_batch) @@ -1199,6 +1196,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .no_op_offload= */ nopo, + /* .no_host = */ noh, }; 
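Because --no-host follows llama-bench's usual comma-separated multi-value convention, a single run can compare pinned host buffers (normally used for CPU-side weights when a GPU backend is present) against plain CPU buffers, e.g. (illustrative invocation):

    llama-bench -m model.gguf --no-host 0,1

Each value becomes its own test row, and a "noh" column is added to the markdown output whenever the flag is given more than one value or deviates from its default.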
instances.push_back(instance); } @@ -1232,6 +1230,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .no_op_offload= */ nopo, + /* .no_host = */ noh, }; instances.push_back(instance); } @@ -1265,6 +1264,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .no_op_offload= */ nopo, + /* .no_host = */ noh, }; instances.push_back(instance); } @@ -1303,6 +1303,7 @@ struct test { bool use_mmap; bool embeddings; bool no_op_offload; + bool no_host; int n_prompt; int n_gen; int n_depth; @@ -1339,6 +1340,7 @@ struct test { use_mmap = inst.use_mmap; embeddings = inst.embeddings; no_op_offload = inst.no_op_offload; + no_host = inst.no_host; n_prompt = inst.n_prompt; n_gen = inst.n_gen; n_depth = inst.n_depth; @@ -1368,13 +1370,23 @@ struct test { static std::string get_backend() { std::vector backends; + bool rpc_used = false; for (size_t i = 0; i < ggml_backend_reg_count(); i++) { auto * reg = ggml_backend_reg_get(i); std::string name = ggml_backend_reg_name(reg); - if (name != "CPU") { - backends.push_back(ggml_backend_reg_name(reg)); + if (string_starts_with(name, "RPC")) { + if (ggml_backend_reg_dev_count(reg) > 0) { + rpc_used = true; + } + } else { + if (name != "CPU") { + backends.push_back(ggml_backend_reg_name(reg)); + } } } + if (rpc_used) { + backends.push_back("RPC"); + } return backends.empty() ? "CPU" : join(backends, ","); } @@ -1386,8 +1398,8 @@ struct test { "type_k", "type_v", "n_gpu_layers", "n_cpu_moe", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "devices", "tensor_split", "tensor_buft_overrides", "use_mmap", "embeddings", "no_op_offload", - "n_prompt", "n_gen", "n_depth", "test_time", "avg_ns", - "stddev_ns", "avg_ts", "stddev_ts" + "no_host", "n_prompt", "n_gen", "n_depth", "test_time", + "avg_ns", "stddev_ns", "avg_ts", "stddev_ts" }; return fields; } @@ -1402,7 +1414,7 @@ struct test { return INT; } if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" || - field == "use_mmap" || field == "embeddings") { + field == "use_mmap" || field == "embeddings" || field == "no_host") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1477,6 +1489,7 @@ struct test { std::to_string(use_mmap), std::to_string(embeddings), std::to_string(no_op_offload), + std::to_string(no_host), std::to_string(n_prompt), std::to_string(n_gen), std::to_string(n_depth), @@ -1665,6 +1678,9 @@ struct markdown_printer : public printer { if (field == "no_op_offload") { return 4; } + if (field == "no_host") { + return 4; + } int width = std::max((int) field.length(), 10); @@ -1699,6 +1715,9 @@ struct markdown_printer : public printer { if (field == "no_op_offload") { return "nopo"; } + if (field == "no_host") { + return "noh"; + } if (field == "devices") { return "dev"; } @@ -1779,6 +1798,9 @@ struct markdown_printer : public printer { if (params.no_op_offload.size() > 1 || params.no_op_offload != cmd_params_defaults.no_op_offload) { fields.emplace_back("no_op_offload"); } + if (params.no_host.size() > 1 || params.no_host != cmd_params_defaults.no_host) { + fields.emplace_back("no_host"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 664b0c9ac6..7a7523851c 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -31,6 +31,7 @@ // vision-specific #define KEY_IMAGE_SIZE "clip.vision.image_size" +#define 
KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 210ecc883f..98e68af27a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -170,7 +170,9 @@ struct clip_hparams { int32_t projection_dim; int32_t n_head; int32_t n_layer; - int32_t proj_scale_factor = 0; // idefics3 + // idefics3 + int32_t preproc_image_size = 0; + int32_t proj_scale_factor = 0; float image_mean[3]; float image_std[3]; @@ -2250,6 +2252,7 @@ struct clip_model_loader { if (is_vision) { get_u32(KEY_IMAGE_SIZE, hparams.image_size); + get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.preproc_image_size, false); get_u32(KEY_PATCH_SIZE, hparams.patch_size); get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy @@ -3551,10 +3554,51 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); return true; - } - else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE + } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) { + // The refined size has two steps: + // 1. Resize w/ aspect-ratio preserving such that the longer side is + // the preprocessor longest size + // 2. Resize w/out preserving aspect ratio such that both sides are + // multiples of image_size (always rounding up) + // + // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737 + const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio( + original_size, params.image_size, params.preproc_image_size); + + llava_uhd::slice_instructions instructions; + instructions.overview_size = clip_image_size{params.image_size, params.image_size}; + instructions.refined_size = refined_size; + instructions.grid_size = clip_image_size{ + static_cast(std::ceil(static_cast(refined_size.width) / params.image_size)), + static_cast(std::ceil(static_cast(refined_size.height) / params.image_size)), + }; + for (int y = 0; y < refined_size.height; y += params.image_size) { + for (int x = 0; x < refined_size.width; x += params.image_size) { + instructions.slices.push_back(llava_uhd::slice_coordinates{ + /* x */x, + /* y */y, + /* size */clip_image_size{ + std::min(params.image_size, refined_size.width - x), + std::min(params.image_size, refined_size.height - y) + } + }); + } + } + auto imgs = llava_uhd::slice_image(img, instructions); + + // cast and normalize to f32 + for (size_t i = 0; i < imgs.size(); ++i) { + // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + + res_imgs->grid_x = instructions.grid_size.width; + res_imgs->grid_y = instructions.grid_size.height; + return true; + } else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3 - || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution ) { clip_image_u8 resized_image; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index cd022c5e24..4d487581ae 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -76,7 
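To make the two-step refined-size computation above concrete, with purely illustrative values image_size = 512 and preproc_image_size = 2048: a 3000x1350 input is first resized with its aspect ratio preserved so the longer side becomes 2048 (roughly 2048x922), then both sides are rounded up to multiples of 512, giving a refined size of 2048x1024. That produces a 4x2 grid of 512x512 slices plus the single 512x512 overview image; real checkpoints set their own image_size and preproc_image_size in the GGUF metadata.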
+76,7 @@ enum mtmd_slice_tmpl { MTMD_SLICE_TMPL_MINICPMV_2_5, MTMD_SLICE_TMPL_MINICPMV_2_6, MTMD_SLICE_TMPL_LLAMA4, - // TODO @ngxson : add support for idefics (SmolVLM) + MTMD_SLICE_TMPL_IDEFICS3, }; const char * mtmd_default_marker() { @@ -114,19 +114,22 @@ struct mtmd_context { // for llava-uhd style models, we need special tokens in-between slices // minicpmv calls them "slices", llama 4 calls them "tiles" mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE; - llama_token tok_ov_img_start = LLAMA_TOKEN_NULL; // overview image - llama_token tok_ov_img_end = LLAMA_TOKEN_NULL; // overview image - llama_token tok_slices_start = LLAMA_TOKEN_NULL; // start of all slices - llama_token tok_slices_end = LLAMA_TOKEN_NULL; // end of all slices - llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start - llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice end - llama_token tok_sli_img_mid = LLAMA_TOKEN_NULL; // between 2 slices - llama_token tok_row_end = LLAMA_TOKEN_NULL; // end of row + std::vector tok_ov_img_start; // overview image + std::vector tok_ov_img_end; // overview image + std::vector tok_slices_start; // start of all slices + std::vector tok_slices_end; // end of all slices + std::vector tok_sli_img_start; // single slice start + std::vector tok_sli_img_end; // single slice end + std::vector tok_sli_img_mid; // between 2 slices + std::vector tok_row_end; // end of row bool tok_row_end_trail = false; bool ov_img_first = false; bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE + // string template for slice image delimiters with row/col (idefics3) + std::string sli_img_start_tmpl; + // for whisper, we pre-calculate the mel filter bank whisper_preprocessor::whisper_filters w_filters; @@ -197,13 +200,13 @@ struct mtmd_context { // minicpmv 2.5 format: // (overview) (slice) (slice) \n ... slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_5; - tok_ov_img_start = lookup_token(""); - tok_ov_img_end = lookup_token(""); - tok_slices_start = lookup_token(""); - tok_slices_end = lookup_token(""); + tok_ov_img_start = {lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_slices_start = {lookup_token("")}; + tok_slices_end = {lookup_token("")}; tok_sli_img_start = tok_ov_img_start; tok_sli_img_end = tok_ov_img_end; - tok_row_end = lookup_token("\n"); + tok_row_end = {lookup_token("\n")}; tok_row_end_trail = false; // no trailing end-of-row token ov_img_first = true; @@ -211,11 +214,11 @@ struct mtmd_context { // minicpmv 2.6 format: // (overview) (slice) (slice) \n ... 
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6; - tok_ov_img_start = lookup_token(""); - tok_ov_img_end = lookup_token(""); - tok_sli_img_start = lookup_token(""); - tok_sli_img_end = lookup_token(""); - tok_row_end = lookup_token("\n"); + tok_ov_img_start = {lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_sli_img_start = {lookup_token("")}; + tok_sli_img_end = {lookup_token("")}; + tok_row_end = {lookup_token("\n")}; tok_row_end_trail = false; // no trailing end-of-row token ov_img_first = true; @@ -230,9 +233,9 @@ struct mtmd_context { // <|image|> (overview) <-- overview image is last // <|image_end|> slice_tmpl = MTMD_SLICE_TMPL_LLAMA4; - tok_ov_img_start = lookup_token("<|image|>"); - tok_sli_img_mid = lookup_token("<|tile_x_separator|>"); - tok_row_end = lookup_token("<|tile_y_separator|>"); + tok_ov_img_start = {lookup_token("<|image|>")}; + tok_sli_img_mid = {lookup_token("<|tile_x_separator|>")}; + tok_row_end = {lookup_token("<|tile_y_separator|>")}; tok_row_end_trail = true; // add trailing end-of-row token ov_img_first = false; // overview image is last } @@ -245,8 +248,11 @@ struct mtmd_context { } else if (proj == PROJECTOR_TYPE_IDEFICS3) { // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215 - img_beg = ""; - img_end = ""; + slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3; + tok_ov_img_start = {lookup_token("\n\n"), lookup_token(""), lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_row_end = {lookup_token("\n")}; + sli_img_start_tmpl = ""; } else if (proj == PROJECTOR_TYPE_PIXTRAL) { // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md @@ -504,6 +510,7 @@ struct mtmd_tokenizer { ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 + || ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3 ) { const int n_col = batch_f32.grid_x; const int n_row = batch_f32.grid_y; @@ -517,53 +524,45 @@ struct mtmd_tokenizer { // add overview image (first) if (ctx->ov_img_first) { - if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_start}); - } + add_text(ctx->tok_ov_img_start); cur.entries.emplace_back(std::move(ov_chunk)); - if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_end}); - } + add_text(ctx->tok_ov_img_end); } // add slices (or tiles) if (!chunks.empty()) { GGML_ASSERT((int)chunks.size() == n_row * n_col); - if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_slices_start}); - } + add_text(ctx->tok_slices_start); for (int y = 0; y < n_row; y++) { for (int x = 0; x < n_col; x++) { const bool is_last_in_row = (x == n_col - 1); - if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_start}); + if (!ctx->tok_sli_img_start.empty()) { + add_text(ctx->tok_sli_img_start); + } else if (!ctx->sli_img_start_tmpl.empty()) { + // If using a template to preceed a slice image + const size_t sz = std::snprintf(nullptr, 0, ctx->sli_img_start_tmpl.c_str(), y+1, x+1) + 1; + std::unique_ptr buf(new char[sz]); + std::snprintf(buf.get(), sz, ctx->sli_img_start_tmpl.c_str(), y+1, x+1); + add_text(std::string(buf.get(), buf.get() + sz - 1), true); } cur.entries.emplace_back(std::move(chunks[y * n_col + x])); - if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_end}); - } - if (!is_last_in_row && 
ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_mid}); + add_text(ctx->tok_sli_img_end); + if (!is_last_in_row) { + add_text(ctx->tok_sli_img_mid); } } - if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_row_end}); + if ((y != n_row - 1 || ctx->tok_row_end_trail)) { + add_text(ctx->tok_row_end); } } - if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_slices_end}); - } + add_text(ctx->tok_slices_end); } // add overview image (last) if (!ctx->ov_img_first) { - if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_start}); - } + add_text(ctx->tok_ov_img_start); cur.entries.emplace_back(std::move(ov_chunk)); - if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_end}); - } + add_text(ctx->tok_ov_img_end); } } else { @@ -780,7 +779,9 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); bool ok = false; - if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) { + if (clip_is_llava(ctx_clip) + || clip_is_minicpmv(ctx_clip) + || clip_is_glm(ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; for (size_t i = 0; i < entries.size(); i++) { diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index c64be03630..dbdf7656a6 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -69,6 +69,7 @@ add_test_vision "ggml-org/InternVL2_5-1B-GGUF:Q8_0" add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0" add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0" +add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" diff --git a/tools/rpc/README.md b/tools/rpc/README.md index 561f19fda6..afbb302f4b 100644 --- a/tools/rpc/README.md +++ b/tools/rpc/README.md @@ -4,7 +4,7 @@ > This example and the RPC backend are currently in a proof-of-concept development stage. As such, the functionality is fragile and > insecure. **Never run the RPC server on an open network or in a sensitive environment!** -The `rpc-server` allows running `ggml` backend on a remote host. +The `rpc-server` allows exposing `ggml` devices on a remote host. The RPC backend communicates with one or several instances of `rpc-server` and offloads computations to them. 
This can be used for distributed LLM inference with `llama.cpp` in the following way: @@ -14,28 +14,34 @@ flowchart TD rpcb<-->|TCP|srvb rpcb<-.->|TCP|srvn subgraph hostn[Host N] - srvn[rpc-server]<-.->backend3["Backend (CUDA,Metal,etc.)"] + srvn[rpc-server]<-.->dev4["CUDA0"] + srvn[rpc-server]<-.->dev5["CPU"] end subgraph hostb[Host B] - srvb[rpc-server]<-->backend2["Backend (CUDA,Metal,etc.)"] + srvb[rpc-server]<-->dev3["Metal"] end subgraph hosta[Host A] - srva[rpc-server]<-->backend["Backend (CUDA,Metal,etc.)"] + srva[rpc-server]<-->dev["CUDA0"] + srva[rpc-server]<-->dev2["CUDA1"] end subgraph host[Main Host] - local["Backend (CUDA,Metal,etc.)"]<-->ggml[llama-cli] + local["Local devices"]<-->ggml[llama-cli] ggml[llama-cli]<-->rpcb[RPC backend] end style hostn stroke:#66,stroke-width:2px,stroke-dasharray: 5 5 + classDef devcls fill:#5B9BD5 + class local,dev,dev2,dev3,dev4,dev5 devcls ``` -Each host can run a different backend, e.g. one with CUDA and another with Metal. -You can also run multiple `rpc-server` instances on the same host, each with a different backend. +By default, `rpc-server` exposes all available accelerator devices on the host. +If there are no accelerators, it exposes a single `CPU` device. ## Usage -On each host, build the corresponding backend with `cmake` and add `-DGGML_RPC=ON` to the build options. -For example, to build the CUDA backend with RPC support: +### Remote hosts + +On each remote host, build the backends for each accelerator by adding `-DGGML_RPC=ON` to the build options. +For example, to build the `rpc-server` with support for CUDA accelerators: ```bash mkdir build-rpc-cuda @@ -44,33 +50,38 @@ cmake .. -DGGML_CUDA=ON -DGGML_RPC=ON cmake --build . --config Release ``` -Then, start the `rpc-server` with the backend: +When started, the `rpc-server` will detect and expose all available `CUDA` devices: ```bash -$ bin/rpc-server -p 50052 -create_backend: using CUDA backend -ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no -ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes +$ bin/rpc-server +ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no +ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no ggml_cuda_init: found 1 CUDA devices: - Device 0: NVIDIA T1200 Laptop GPU, compute capability 7.5, VMM: yes -Starting RPC server on 0.0.0.0:50052 + Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes +Starting RPC server v3.0.0 + endpoint : 127.0.0.1:50052 + local cache : n/a +Devices: + CUDA0: NVIDIA GeForce RTX 5090 (32109 MiB, 31588 MiB free) ``` -When using the CUDA backend, you can specify the device with the `CUDA_VISIBLE_DEVICES` environment variable, e.g.: +You can control the set of exposed CUDA devices with the `CUDA_VISIBLE_DEVICES` environment variable or the `--device` command line option. The following two commands have the same effect: ```bash $ CUDA_VISIBLE_DEVICES=0 bin/rpc-server -p 50052 +$ bin/rpc-server --device CUDA0 -p 50052 ``` -This way you can run multiple `rpc-server` instances on the same host, each with a different CUDA device. +### Main host -On the main host build `llama.cpp` for the local backend and add `-DGGML_RPC=ON` to the build options. -Finally, when running `llama-cli`, use the `--rpc` option to specify the host and port of each `rpc-server`: +On the main host build `llama.cpp` with the backends for the local devices and add `-DGGML_RPC=ON` to the build options. 
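For example, a main host that itself has a CUDA device could be built like this — a sketch mirroring the remote-host build above, not part of this PR; substitute the backend flag that matches your local hardware, and the directory name is only illustrative:

```bash
mkdir build-rpc-main
cd build-rpc-main
cmake .. -DGGML_CUDA=ON -DGGML_RPC=ON
cmake --build . --config Release
```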
+Finally, when running `llama-cli` or `llama-server`, use the `--rpc` option to specify the host and port of each `rpc-server`: ```bash -$ bin/llama-cli -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name is" --repeat-penalty 1.0 -n 64 --rpc 192.168.88.10:50052,192.168.88.11:50052 -ngl 99 +$ llama-cli -hf ggml-org/gemma-3-1b-it-GGUF -ngl 99 --rpc 192.168.88.10:50052,192.168.88.11:50052 ``` -This way you can offload model layers to both local and remote devices. +By default, llama.cpp distributes model weights and the KV cache across all available devices -- both local and remote -- in proportion to each device's available memory. +You can override this behavior with the `--tensor-split` option and set custom proportions when splitting tensor data across devices. ### Local cache @@ -83,3 +94,11 @@ $ bin/rpc-server -c ``` By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable. + +### Troubleshooting + +Use the `GGML_RPC_DEBUG` environment variable to enable debug messages from `rpc-server`: +```bash +$ GGML_RPC_DEBUG=1 bin/rpc-server +``` + diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp index dc8e077f34..0885156127 100644 --- a/tools/rpc/rpc-server.cpp +++ b/tools/rpc/rpc-server.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace fs = std::filesystem; @@ -131,24 +132,24 @@ static std::string fs_get_cache_directory() { } struct rpc_server_params { - std::string host = "127.0.0.1"; - int port = 50052; - size_t backend_mem = 0; - bool use_cache = false; - int n_threads = std::max(1U, std::thread::hardware_concurrency()/2); - std::string device; + std::string host = "127.0.0.1"; + int port = 50052; + bool use_cache = false; + int n_threads = std::max(1U, std::thread::hardware_concurrency()/2); + std::vector devices; + std::vector dev_mem; }; static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { fprintf(stderr, "Usage: %s [options]\n\n", argv[0]); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n", params.n_threads); - fprintf(stderr, " -d DEV, --device device to use\n"); - fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str()); - fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port); - fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n"); - fprintf(stderr, " -c, --cache enable local file cache\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads); + fprintf(stderr, " -d, --device comma-separated list of devices\n"); + fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str()); + fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port); + fprintf(stderr, " -m, --mem memory size for each device (in MB)\n"); + fprintf(stderr, " -c, --cache enable local file cache\n"); fprintf(stderr, "\n"); } @@ -174,17 +175,17 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & if (++i >= argc) { return false; } - params.device = argv[i]; - if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) { - fprintf(stderr, "error: unknown device: %s\n", params.device.c_str()); - fprintf(stderr, "available devices:\n"); - for 
(size_t i = 0; i < ggml_backend_dev_count(); i++) { - auto * dev = ggml_backend_dev_get(i); - size_t free, total; - ggml_backend_dev_memory(dev, &free, &total); - printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + const std::regex regex{ R"([,/]+)" }; + std::string dev_str = argv[i]; + std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1); + std::sregex_token_iterator end; + for ( ; iter != end; ++iter) { + try { + params.devices.push_back(*iter); + } catch (const std::exception & ) { + fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str()); + return false; } - return false; } } else if (arg == "-p" || arg == "--port") { if (++i >= argc) { @@ -200,7 +201,19 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & if (++i >= argc) { return false; } - params.backend_mem = std::stoul(argv[i]) * 1024 * 1024; + const std::regex regex{ R"([,/]+)" }; + std::string mem_str = argv[i]; + std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1); + std::sregex_token_iterator end; + for ( ; iter != end; ++iter) { + try { + size_t mem = std::stoul(*iter) * 1024 * 1024; + params.dev_mem.push_back(mem); + } catch (const std::exception & ) { + fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str()); + return false; + } + } } else if (arg == "-h" || arg == "--help") { print_usage(argc, argv, params); exit(0); @@ -213,45 +226,46 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & return true; } -static ggml_backend_t create_backend(const rpc_server_params & params) { - ggml_backend_t backend = nullptr; +static std::vector get_devices(const rpc_server_params & params) { + std::vector devices; + if (!params.devices.empty()) { + for (auto device : params.devices) { + ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str()); + if (dev) { + devices.push_back(dev); + } else { + fprintf(stderr, "error: unknown device: %s\n", device.c_str()); + fprintf(stderr, "available devices:\n"); + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + auto * dev = ggml_backend_dev_get(i); + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + } + return {}; + } + } + } - if (!params.device.empty()) { - ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str()); + // Try non-CPU devices first + if (devices.empty()) { + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) { + devices.push_back(dev); + } + } + } + + // If there are no accelerators, fallback to CPU device + if (devices.empty()) { + ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); if (dev) { - backend = ggml_backend_dev_init(dev, nullptr); - if (!backend) { - fprintf(stderr, "Failed to create backend for device %s\n", params.device.c_str()); - return nullptr; - } + devices.push_back(dev); } } - if (!backend) { - backend = ggml_backend_init_best(); - } - - if (backend) { - fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend)); - - // set the number of threads - ggml_backend_dev_t dev = ggml_backend_get_device(backend); - ggml_backend_reg_t reg = dev ? 
ggml_backend_dev_backend_reg(dev) : nullptr; - if (reg) { - auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); - if (ggml_backend_set_n_threads_fn) { - ggml_backend_set_n_threads_fn(backend, params.n_threads); - } - } - } - - return backend; -} - -static void get_backend_memory(ggml_backend_t backend, size_t * free_mem, size_t * total_mem) { - ggml_backend_dev_t dev = ggml_backend_get_device(backend); - GGML_ASSERT(dev != nullptr); - ggml_backend_dev_memory(dev, free_mem, total_mem); + return devices; } int main(int argc, char * argv[]) { @@ -273,18 +287,23 @@ int main(int argc, char * argv[]) { fprintf(stderr, "\n"); } - ggml_backend_t backend = create_backend(params); - if (!backend) { - fprintf(stderr, "Failed to create backend\n"); + auto devices = get_devices(params); + if (devices.empty()) { + fprintf(stderr, "No devices found\n"); return 1; } std::string endpoint = params.host + ":" + std::to_string(params.port); - size_t free_mem, total_mem; - if (params.backend_mem > 0) { - free_mem = params.backend_mem; - total_mem = params.backend_mem; - } else { - get_backend_memory(backend, &free_mem, &total_mem); + std::vector free_mem, total_mem; + for (size_t i = 0; i < devices.size(); i++) { + if (i < params.dev_mem.size()) { + free_mem.push_back(params.dev_mem[i]); + total_mem.push_back(params.dev_mem[i]); + } else { + size_t free, total; + ggml_backend_dev_memory(devices[i], &free, &total); + free_mem.push_back(free); + total_mem.push_back(total); + } } const char * cache_dir = nullptr; std::string cache_dir_str; @@ -309,8 +328,7 @@ int main(int argc, char * argv[]) { return 1; } - start_server_fn(backend, endpoint.c_str(), cache_dir, free_mem, total_mem); - - ggml_backend_free(backend); + start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), + devices.data(), free_mem.data(), total_mem.data()); return 0; } diff --git a/tools/run/run.cpp b/tools/run/run.cpp index 772d66c921..b90a7253c4 100644 --- a/tools/run/run.cpp +++ b/tools/run/run.cpp @@ -9,6 +9,7 @@ #include #if defined(_WIN32) +# define WIN32_LEAN_AND_MEAN # ifndef NOMINMAX # define NOMINMAX # endif @@ -22,6 +23,8 @@ #if defined(LLAMA_USE_CURL) # include +#else +# include "http.h" #endif #include @@ -397,7 +400,6 @@ class File { # endif }; -#ifdef LLAMA_USE_CURL class HttpClient { public: int init(const std::string & url, const std::vector & headers, const std::string & output_file, @@ -428,6 +430,8 @@ class HttpClient { return 0; } +#ifdef LLAMA_USE_CURL + ~HttpClient() { if (chunk) { curl_slist_free_all(chunk); @@ -532,6 +536,117 @@ class HttpClient { return curl_easy_perform(curl); } +#else // LLAMA_USE_CURL is not defined + +#define curl_off_t long long // temporary hack + + private: + // this is a direct translation of the cURL download() above + int download(const std::string & url, const std::vector & headers_vec, const std::string & output_file, + const bool progress, std::string * response_str = nullptr) { + try { + auto [cli, url_parts] = common_http_client(url); + + httplib::Headers headers; + for (const auto & h : headers_vec) { + size_t pos = h.find(':'); + if (pos != std::string::npos) { + headers.emplace(h.substr(0, pos), h.substr(pos + 2)); + } + } + + File out; + if (!output_file.empty()) { + if (!out.open(output_file, "ab")) { + printe("Failed to open file for writing\n"); + return 1; + } + if (out.lock()) { + printe("Failed to exclusively lock file\n"); + return 1; + } + } + + size_t resume_offset = 0; 
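+            // Note on the resume logic below: if a partial file already exists on disk, the
+            // download resumes from its current size by sending an HTTP "Range" header. The
+            // response handler then checks for status 206 (Partial Content); any other status
+            // means the server ignored the range request, so the file is reopened with "wb"
+            // and the transfer restarts from the beginning.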
+ if (!output_file.empty() && std::filesystem::exists(output_file)) { + resume_offset = std::filesystem::file_size(output_file); + if (resume_offset > 0) { + headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-"); + } + } + + progress_data data; + data.file_size = resume_offset; + + long long total_size = 0; + long long received_this_session = 0; + + auto response_handler = + [&](const httplib::Response & response) { + if (resume_offset > 0 && response.status != 206) { + printe("\nServer does not support resuming. Restarting download.\n"); + out.file = freopen(output_file.c_str(), "wb", out.file); + if (!out.file) { + return false; + } + data.file_size = 0; + } + if (progress) { + if (response.has_header("Content-Length")) { + total_size = std::stoll(response.get_header_value("Content-Length")); + } else if (response.has_header("Content-Range")) { + auto range = response.get_header_value("Content-Range"); + auto slash = range.find('/'); + if (slash != std::string::npos) { + total_size = std::stoll(range.substr(slash + 1)); + } + } + } + return true; + }; + + auto content_receiver = + [&](const char * chunk, size_t length) { + if (out.file && fwrite(chunk, 1, length, out.file) != length) { + return false; + } + if (response_str) { + response_str->append(chunk, length); + } + received_this_session += length; + + if (progress && total_size > 0) { + update_progress(&data, total_size, received_this_session, 0, 0); + } + return true; + }; + + auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver); + + if (data.printed) { + printe("\n"); + } + + if (!res) { + auto err = res.error(); + printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str()); + return 1; + } + + if (res->status >= 400) { + printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status); + return 1; + } + + } catch (const std::exception & e) { + printe("HTTP request failed: %s\n", e.what()); + return 1; + } + return 0; + } + +#endif // LLAMA_USE_CURL + static std::string human_readable_time(double seconds) { int hrs = static_cast(seconds) / 3600; int mins = (static_cast(seconds) % 3600) / 60; @@ -644,8 +759,8 @@ class HttpClient { str->append(static_cast(ptr), size * nmemb); return size * nmemb; } + }; -#endif class LlamaData { public: @@ -673,7 +788,6 @@ class LlamaData { } private: -#ifdef LLAMA_USE_CURL int download(const std::string & url, const std::string & output_file, const bool progress, const std::vector & headers = {}, std::string * response_str = nullptr) { HttpClient http; @@ -683,14 +797,6 @@ class LlamaData { return 0; } -#else - int download(const std::string &, const std::string &, const bool, const std::vector & = {}, - std::string * = nullptr) { - printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); - - return 1; - } -#endif // Helper function to handle model tag extraction and URL construction std::pair extract_model_and_tag(std::string & model, const std::string & base_url) { diff --git a/tools/server/README.md b/tools/server/README.md index 9f7ab229f7..e23b122ab1 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -393,7 +393,7 @@ node index.js ### GET `/health`: Returns health check result -This endpoint is public (no API key check). +This endpoint is public (no API key check). `/v1/health` also works. **Response format** @@ -1045,6 +1045,7 @@ Available metrics: - `llamacpp:kv_cache_tokens`: KV-cache tokens. 
- `llamacpp:requests_processing`: Number of requests processing. - `llamacpp:requests_deferred`: Number of requests deferred. +- `llamacpp:n_past_max`: High watermark of the context size observed. ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 4575db67a1..8d57b4a167 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 6062904a8c..307653764c 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -764,7 +764,7 @@ struct completion_token_output { } }; -struct swa_checkpoint { +struct ctx_checkpoint { llama_pos pos_min; llama_pos pos_max; @@ -1460,7 +1460,7 @@ struct server_slot { std::vector generated_token_probs; - std::vector swa_checkpoints; + std::vector ctx_checkpoints; bool has_next_token = true; bool has_new_line = false; @@ -3541,7 +3541,11 @@ struct server_context { slot.n_past = 0; } - const auto n_swa = llama_model_n_swa(model); + // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 + const auto n_swa = std::max(1, llama_model_n_swa(model)); + + // the largest pos_min required for a checkpoint to be useful + const auto pos_min_thold = std::max(0, slot.n_past - n_swa); if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); @@ -3550,66 +3554,62 @@ struct server_context { GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); } - const auto pos_min_thold = std::max(0, slot.n_past - n_swa); - if (pos_min > pos_min_thold) { SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa); - // search for a SWA checkpoint + // search for a context checkpoint const auto it = std::find_if( - slot.swa_checkpoints.rbegin(), - slot.swa_checkpoints.rend(), + slot.ctx_checkpoints.rbegin(), + slot.ctx_checkpoints.rend(), [&](const auto & cur) { - return cur.pos_min <= pos_min_thold; + // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] + return cur.pos_min < pos_min_thold; } ); - bool do_reset = it == slot.swa_checkpoints.rend(); + bool do_reset = it == slot.ctx_checkpoints.rend(); + //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? 
"true" : "false"); if (!do_reset) { - // restore the checkpoint - const size_t swa_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY); + // restore the context checkpoint + const size_t ctx_checkpoint_size = it->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - if (n != swa_size) { - SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024); + if (n != ctx_checkpoint_size) { + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); do_reset = true; + //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { - slot.n_past = std::min(slot.n_past, it->pos_max); - - SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024); + slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max)); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024); } } if (do_reset) { - SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n", + SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - slot.n_past = 0; - slot.swa_checkpoints.clear(); } } } - if (n_swa > 0) { - const auto pos_min_thold = std::max(0, slot.n_past - n_swa); - + { // erase any checkpoints with pos_min > pos_min_thold - for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) { - const auto & cur = slot.swa_checkpoints[i]; + for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) { + const auto & cur = slot.ctx_checkpoints[i]; if (cur.pos_min > pos_min_thold) { - slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i); - - SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); + slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i); } } } } + // [TAG_PROMPT_LOGITS] if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); - + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens); slot.n_past--; + SLT_WRN(slot, "n_past was set to %d\n", slot.n_past); } slot.n_prompt_tokens_cache = slot.n_past; @@ -3623,9 +3623,9 @@ struct server_context { } } - // keep only the common part + // truncate any tokens that are beyond n_past for this slot if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) { - // could not partially delete (likely using a non-Transformer model) + SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", 
slot.n_past); llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); // there is no common part left @@ -3633,7 +3633,7 @@ struct server_context { slot.n_prompt_tokens_cache = 0; } - SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); + SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past); // remove the non-common part from the cache slot.cache_tokens.keep_first(slot.n_past); @@ -3854,37 +3854,38 @@ struct server_context { // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; - // make a checkpoint with the SWA memory - // checkpoints are needed only if we are not using "--swa-full" - if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) { - if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) { - { - const auto & cur = slot.swa_checkpoints.back(); + // make a checkpoint of the parts of the memory that cannot be rolled back. + // checkpoints are created only if: + // - the model uses SWA and we are not using `swa_full` + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + const bool do_checkpoint = + (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) || + (llama_model_n_swa(model) > 0 && !params_base.swa_full); - SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - } + if (do_checkpoint && params_base.n_ctx_checkpoints > 0) { + while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.ctx_checkpoints.front(); + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - slot.swa_checkpoints.erase(slot.swa_checkpoints.begin()); + slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin()); } - const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY); + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{ + auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{ /*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id), /*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id), - /*.data = */ std::vector(swa_size), + /*.data = */ std::vector(checkpoint_size), }); - llama_state_seq_get_data_ext(ctx, cur.data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY); + llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - float size_total = 0.0f; - for (const auto & checkpoint : slot.swa_checkpoints) { - size_total += (float) checkpoint.data.size() / 1024 / 1024; - } - - SLT_WRN(slot, "SWA checkpoint create, pos_min = %d, pos_max = %d, size = %.3f MiB, total = %d/%d (%.3f MiB)\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024, (int) slot.swa_checkpoints.size(), params_base.n_swa_checkpoints, size_total); + SLT_WRN(slot, "saved context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + (int) slot.ctx_checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); } } else if (slot.state 
!= SLOT_STATE_GENERATING) { continue; // continue loop of slots @@ -4183,6 +4184,7 @@ int main(int argc, char ** argv) { auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { static const std::unordered_set public_endpoints = { "/health", + "/v1/health", "/models", "/v1/models", "/api/tags" @@ -5231,6 +5233,7 @@ int main(int argc, char ** argv) { // register API routes svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check) + svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check) svr->Get (params.api_prefix + "/metrics", handle_metrics); svr->Get (params.api_prefix + "/props", handle_props); svr->Post(params.api_prefix + "/props", handle_props_change); diff --git a/tools/server/webui/scripts/dev.sh b/tools/server/webui/scripts/dev.sh index e0e8b26e9a..2bda8f22c8 100644 --- a/tools/server/webui/scripts/dev.sh +++ b/tools/server/webui/scripts/dev.sh @@ -1,5 +1,14 @@ #!/bin/bash +# Development script for llama.cpp webui +# +# This script starts the webui development servers (Storybook and Vite). +# Note: You need to start llama-server separately. +# +# Usage: +# bash scripts/dev.sh +# npm run dev + cd ../../../ # Check and install git hooks if missing @@ -28,76 +37,19 @@ check_and_install_hooks() { # Install git hooks if needed check_and_install_hooks -# Check if llama-server binary already exists -if [ ! -f "build/bin/llama-server" ]; then - echo "Building llama-server..." - cmake -B build && cmake --build build --config Release -t llama-server -else - echo "llama-server binary already exists, skipping build." -fi - -# Start llama-server and capture output -echo "Starting llama-server..." -mkfifo server_output.pipe -build/bin/llama-server -hf ggml-org/gpt-oss-20b-GGUF --jinja -c 0 --no-webui > server_output.pipe 2>&1 & -SERVER_PID=$! - -# Function to wait for server to be ready -wait_for_server() { - echo "Waiting for llama-server to be ready..." - local max_wait=60 - local start_time=$(date +%s) - - # Read server output in background and look for the ready message - ( - while IFS= read -r line; do - echo "๐Ÿ” Server: $line" - if [[ "$line" == *"server is listening on http://127.0.0.1:8080 - starting the main loop"* ]]; then - echo "โœ… llama-server is ready!" - echo "READY" > server_ready.flag - break - fi - done < server_output.pipe - ) & - - # Wait for ready flag or timeout - while [ ! -f server_ready.flag ]; do - local current_time=$(date +%s) - local elapsed=$((current_time - start_time)) - - if [ $elapsed -ge $max_wait ]; then - echo "โŒ Server failed to start within $max_wait seconds" - rm -f server_ready.flag - return 1 - fi - - sleep 1 - done - - rm -f server_ready.flag - return 0 -} - # Cleanup function cleanup() { echo "๐Ÿงน Cleaning up..." - kill $SERVER_PID 2>/dev/null - rm -f server_output.pipe server_ready.flag exit } # Set up signal handlers trap cleanup SIGINT SIGTERM -# Wait for server to be ready -if wait_for_server; then - echo "๐Ÿš€ Starting development servers..." - cd tools/server/webui - storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 & - - # Wait for all background processes - wait -else - echo "โŒ Failed to start development environment" - cleanup -fi +echo "๐Ÿš€ Starting development servers..." 
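+# If you need a local llama-server for the webui to talk to, start it yourself in a
+# separate terminal, for example with the command the old version of this script used:
+#   build/bin/llama-server -hf ggml-org/gpt-oss-20b-GGUF --jinja -c 0 --no-webui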
+echo "๐Ÿ“ Note: Make sure to start llama-server separately if needed" +cd tools/server/webui +storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 & + +# Wait for all background processes +wait diff --git a/tools/server/webui/src/app.css b/tools/server/webui/src/app.css index c743199361..2ca1536409 100644 --- a/tools/server/webui/src/app.css +++ b/tools/server/webui/src/app.css @@ -37,8 +37,8 @@ --sidebar-accent-foreground: oklch(0.205 0 0); --sidebar-border: oklch(0.922 0 0); --sidebar-ring: oklch(0.708 0 0); - --code-background: oklch(0.225 0 0); - --code-foreground: oklch(0.875 0 0); + --code-background: oklch(0.975 0 0); + --code-foreground: oklch(0.145 0 0); --layer-popover: 1000000; } @@ -74,6 +74,8 @@ --sidebar-accent-foreground: oklch(0.985 0 0); --sidebar-border: oklch(1 0 0 / 10%); --sidebar-ring: oklch(0.556 0 0); + --code-background: oklch(0.225 0 0); + --code-foreground: oklch(0.875 0 0); } @theme inline { diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte index 013b77cbbe..ad3ffa3792 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte @@ -3,12 +3,14 @@ import { useProcessingState } from '$lib/hooks/use-processing-state.svelte'; import { isLoading } from '$lib/stores/chat.svelte'; import { fade } from 'svelte/transition'; - import { Check, X } from '@lucide/svelte'; + import { Check, Copy, Package, X } from '@lucide/svelte'; import { Button } from '$lib/components/ui/button'; import { Checkbox } from '$lib/components/ui/checkbox'; import { INPUT_CLASSES } from '$lib/constants/input-classes'; import ChatMessageActions from './ChatMessageActions.svelte'; import Label from '$lib/components/ui/label/label.svelte'; + import { config } from '$lib/stores/settings.svelte'; + import { copyToClipboard } from '$lib/utils/copy'; interface Props { class?: string; @@ -136,6 +138,23 @@ {/if} + {#if config().showModelInfo && message.model} + + + + Model used: + + + + {/if} + {#if message.timestamp && !isEditing} import { goto } from '$app/navigation'; import { page } from '$app/state'; - import { ChatSidebarConversationItem } from '$lib/components/app'; + import { Trash2 } from '@lucide/svelte'; + import { ChatSidebarConversationItem, ConfirmationDialog } from '$lib/components/app'; import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte'; import * as Sidebar from '$lib/components/ui/sidebar'; + import * as AlertDialog from '$lib/components/ui/alert-dialog'; + import Input from '$lib/components/ui/input/input.svelte'; import { conversations, deleteConversation, @@ -16,6 +19,10 @@ let currentChatId = $derived(page.params.id); let isSearchModeActive = $state(false); let searchQuery = $state(''); + let showDeleteDialog = $state(false); + let showEditDialog = $state(false); + let selectedConversation = $state(null); + let editedName = $state(''); let filteredConversations = $derived.by(() => { if (searchQuery.trim().length > 0) { @@ -27,12 +34,41 @@ return conversations(); }); - async function editConversation(id: string, name: string) { - await updateConversationName(id, name); + async function handleDeleteConversation(id: string) { + const conversation = conversations().find((conv) => conv.id === id); + if (conversation) { + selectedConversation = conversation; + showDeleteDialog = true; + } } - 
async function handleDeleteConversation(id: string) { - await deleteConversation(id); + async function handleEditConversation(id: string) { + const conversation = conversations().find((conv) => conv.id === id); + if (conversation) { + selectedConversation = conversation; + editedName = conversation.name; + showEditDialog = true; + } + } + + function handleConfirmDelete() { + if (selectedConversation) { + showDeleteDialog = false; + + setTimeout(() => { + deleteConversation(selectedConversation.id); + selectedConversation = null; + }, 100); // Wait for animation to finish + } + } + + function handleConfirmEdit() { + if (!editedName.trim() || !selectedConversation) return; + + showEditDialog = false; + + updateConversationName(selectedConversation.id, editedName); + selectedConversation = null; } export function handleMobileSidebarItemClick() { @@ -98,7 +134,7 @@ {handleMobileSidebarItemClick} isActive={currentChatId === conversation.id} onSelect={selectConversation} - onEdit={editConversation} + onEdit={handleEditConversation} onDelete={handleDeleteConversation} /> @@ -119,7 +155,53 @@ -
-

Conversations are stored locally in your browser.

-
+
+ + { + showDeleteDialog = false; + selectedConversation = null; + }} +/> + + + + + Edit Conversation Name + + { + if (e.key === 'Enter') { + e.preventDefault(); + handleConfirmEdit(); + } + }} + placeholder="Enter a new name" + type="text" + bind:value={editedName} + /> + + + + { + showEditDialog = false; + selectedConversation = null; + }}>Cancel + Save + + + diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarActions.svelte index 30d1f9d4b7..e91673e98b 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSidebar/ChatSidebarActions.svelte @@ -1,8 +1,9 @@ +