diff --git a/.devops/musa.Dockerfile b/.devops/musa.Dockerfile index b0c86dccd5..ec44b22914 100644 --- a/.devops/musa.Dockerfile +++ b/.devops/musa.Dockerfile @@ -1,6 +1,6 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG MUSA_VERSION=rc4.2.0 +ARG MUSA_VERSION=rc4.3.0 # Target the MUSA build image ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION}-amd64 diff --git a/.github/workflows/build-amd.yml b/.github/workflows/build-amd.yml new file mode 100644 index 0000000000..b6fe8de865 --- /dev/null +++ b/.github/workflows/build-amd.yml @@ -0,0 +1,52 @@ +name: CI (AMD) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build-amd.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.comp' + ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ggml-ci-x64-amd-vulkan: + runs-on: [self-hosted, Linux, X64, AMD] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + + ggml-ci-x64-amd-rocm: + runs-on: [self-hosted, Linux, X64, AMD] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Test + id: ggml-ci + run: | + amd-smi static + GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index 04ad187d35..2b101876c5 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -141,97 +141,6 @@ jobs: # cmake --build build --config Release -j $(nproc) - ubuntu-24-ppc64el-cpu-cross: - runs-on: ubuntu-24.04 - - steps: - - uses: actions/checkout@v4 - - name: Setup PowerPC64le - run: | - sudo dpkg --add-architecture ppc64el - - # Add arch-specific repositories for non-amd64 architectures - cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list - deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe - deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe - deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe - deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe - EOF - - sudo apt-get update || true ;# Prevent failure due to missing URLs. 
- - sudo apt-get install -y --no-install-recommends \ - build-essential \ - gcc-14-powerpc64le-linux-gnu \ - g++-14-powerpc64le-linux-gnu - - - name: Build - run: | - cmake -B build -DLLAMA_CURL=OFF \ - -DCMAKE_BUILD_TYPE=Release \ - -DGGML_OPENMP=OFF \ - -DLLAMA_BUILD_EXAMPLES=ON \ - -DLLAMA_BUILD_TOOLS=ON \ - -DLLAMA_BUILD_TESTS=OFF \ - -DCMAKE_SYSTEM_NAME=Linux \ - -DCMAKE_SYSTEM_PROCESSOR=ppc64 \ - -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \ - -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \ - -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \ - -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH - - cmake --build build --config Release -j $(nproc) - - # ubuntu-24-ppc64el-vulkan-cross: - # runs-on: ubuntu-24.04 - - # steps: - # - uses: actions/checkout@v4 - # - name: Setup PowerPC64le - # run: | - # sudo dpkg --add-architecture ppc64el - - # # Add arch-specific repositories for non-amd64 architectures - # cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list - # deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe - # deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe - # deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe - # deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe - # EOF - - # sudo apt-get update || true ;# Prevent failure due to missing URLs. - - # sudo apt-get install -y --no-install-recommends \ - # build-essential \ - # glslc \ - # gcc-14-powerpc64le-linux-gnu \ - # g++-14-powerpc64le-linux-gnu \ - # libvulkan-dev:ppc64el - - # - name: Build - # run: | - # cmake -B build -DLLAMA_CURL=OFF \ - # -DCMAKE_BUILD_TYPE=Release \ - # -DGGML_VULKAN=ON \ - # -DGGML_OPENMP=OFF \ - # -DLLAMA_BUILD_EXAMPLES=ON \ - # -DLLAMA_BUILD_TOOLS=ON \ - # -DLLAMA_BUILD_TESTS=OFF \ - # -DCMAKE_SYSTEM_NAME=Linux \ - # -DCMAKE_SYSTEM_PROCESSOR=ppc64 \ - # -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \ - # -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \ - # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ - # -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \ - # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ - # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ - # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH - - # cmake --build build --config Release -j $(nproc) - debian-13-loongarch64-cpu-cross: runs-on: ubuntu-24.04 container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 @@ -344,3 +253,47 @@ jobs: -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH cmake --build build --config Release -j $(nproc) + + ubuntu-24-riscv64-cpu-spacemit-ime-cross: + runs-on: ubuntu-24.04 + + env: + SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" + SPACEMIT_IME_TOOLCHAIN_PATH: "spacemit-toolchain-linux-glibc-x86_64" + + steps: + - uses: actions/checkout@v4 + + - name: Cache Toolchain + uses: actions/cache@v4 + id: cache-spacemit-ime-cross-toolchain + with: + path: ./${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + + - name: Setup Toolchain + if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true' + run: | + wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz + rm -rf ${{ 
env.SPACEMIT_IME_TOOLCHAIN_PATH }} + mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} --strip-components=1 + rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz + + - name: Build + run: | + export RISCV_ROOT_PATH=${PWD}/${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DGGML_CPU_RISCV64_SPACEMIT=ON \ + -DGGML_RVV=ON \ + -DGGML_RV_ZFH=ON \ + -DGGML_RV_ZICBOP=ON \ + -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \ + -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake + + cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build-riscv-native.yml b/.github/workflows/build-riscv-native.yml index acad316602..a3a0b0d663 100644 --- a/.github/workflows/build-riscv-native.yml +++ b/.github/workflows/build-riscv-native.yml @@ -58,3 +58,63 @@ jobs: -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH cmake --build build --config Release -j $(nproc) + + # debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2 + # runs-on: [self-hosted, RISCV64] + + # steps: + # - name: Install prerequisites + # run: | + # sudo apt-get update || true + # sudo apt-get install -y libatomic1 + # - uses: actions/checkout@v4 + # - name: Setup Riscv + # run: | + # sudo apt-get update || true + # sudo apt-get install -y --no-install-recommends \ + # build-essential \ + # gcc-14-riscv64-linux-gnu \ + # g++-14-riscv64-linux-gnu \ + # ccache \ + # cmake + # sudo apt-get upgrade binutils -y + + # - name: Setup ccache + # run: | + # mkdir -p $HOME/.ccache + # ccache -M 5G -d $HOME/.ccache + # export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log + # export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" + # echo "$GITHUB_WORKSPACE" + # echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV + # echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV + # echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV + # echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV + + # - name: Build + # run: | + # cmake -B build \ + # -DLLAMA_CURL=OFF \ + # -DCMAKE_BUILD_TYPE=Release \ + # -DGGML_OPENMP=OFF \ + # -DLLAMA_BUILD_EXAMPLES=ON \ + # -DLLAMA_BUILD_TOOLS=ON \ + # -DLLAMA_BUILD_TESTS=OFF \ + # -DCMAKE_SYSTEM_NAME=Linux \ + # -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + # -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + # -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + # -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + # -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + # -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \ + # -DGGML_RVV=ON \ + # -DGGML_RV_ZFH=ON \ + # -DGGML_RV_ZICBOP=ON \ + # -DGGML_CPU_RISCV64_SPACEMIT=ON \ + # -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 + + # cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4f70232b17..410552813a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -192,6 +192,10 @@ jobs: os: ubuntu-22.04 - build: 'arm64' os: ubuntu-22.04-arm + - build: 's390x' + os: ubuntu-24.04-s390x + - build: 'ppc64le' + os: ubuntu-24.04-ppc64le runs-on: ${{ matrix.os }} @@ -203,14 +207,31 @@ jobs: - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: - 
key: ubuntu-cpu-cmake + key: ubuntu-cpu-cmake-${{ matrix.build }} evict-old-files: 1d - - name: Dependencies - id: depends + - name: Build Dependencies + id: build_depends run: | sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev + sudo apt-get install -y --no-install-recommends \ + python3 python3-pip python3-dev \ + libjpeg-dev build-essential libcurl4-openssl-dev \ + git-lfs + + - name: Python Dependencies + id: python_depends + run: | + python3 -m pip install --upgrade pip + pip3 install ./gguf-py + + - name: Swap Endianness + id: endianness + if: ${{ matrix.build == 's390x' }} + run: | + for f in models/*.gguf; do + echo YES | python3 gguf-py/gguf/scripts/gguf_convert_endian.py $f big + done - name: Build id: cmake_build @@ -228,6 +249,7 @@ jobs: - name: Test llama2c conversion id: llama2c_test + if: ${{ matrix.build != 's390x' }} run: | cd build echo "Fetch tokenizer" @@ -237,6 +259,15 @@ jobs: ./bin/llama-convert-llama2c-to-ggml --copy-vocab-from-model ./tok512.bin --llama2c-model stories260K.bin --llama2c-output-model stories260K.gguf ./bin/llama-cli -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256 + - name: Test llama2c (s390x) + id: llama2c_test_s390x + if: ${{ matrix.build == 's390x' }} + run: | + cd build + echo "Fetch llama2c big-endian model" + wget https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories260K-be.gguf + ./bin/llama-cli -m stories260K-be.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256 + ubuntu-latest-cmake-sanitizer: runs-on: ubuntu-latest @@ -475,7 +506,7 @@ jobs: ubuntu-22-cmake-musa: runs-on: ubuntu-22.04 - container: mthreads/musa:rc4.2.0-devel-ubuntu22.04-amd64 + container: mthreads/musa:rc4.3.0-devel-ubuntu22.04-amd64 steps: - name: Clone @@ -1191,11 +1222,12 @@ jobs: - name: Clone uses: actions/checkout@v4 - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: android-build - evict-old-files: 1d + # Disabled due to size (400MB) and always 0 cache hits + # - name: ccache + # uses: ggml-org/ccache-action@v1.2.16 + # with: + # key: android-build + # evict-old-files: 1d - name: Set up JDK uses: actions/setup-java@v3 @@ -1430,34 +1462,6 @@ jobs: run: | bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp -# ggml-ci-x64-amd-vulkan: -# runs-on: [self-hosted, Linux, X64, AMD] -# -# steps: -# - name: Clone -# id: checkout -# uses: actions/checkout@v4 -# -# - name: Test -# id: ggml-ci -# run: | -# vulkaninfo --summary -# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp -# -# ggml-ci-x64-amd-rocm: -# runs-on: [self-hosted, Linux, X64, AMD] -# -# steps: -# - name: Clone -# id: checkout -# uses: actions/checkout@v4 -# -# - name: Test -# id: ggml-ci -# run: | -# amd-smi static -# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - ggml-ci-mac-metal: runs-on: [self-hosted, macOS, ARM64] diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 542621b077..f73a2bc9f4 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -28,7 +28,7 @@ jobs: push_to_registry: name: Push Docker image to Docker Hub - runs-on: ubuntu-22.04 + runs-on: ${{ matrix.config.runs_on }} env: COMMIT_SHA: ${{ github.sha }} strategy: @@ -39,12 +39,12 @@ jobs: # Note: the arm64 images are failing, which prevents the amd64 images from being built # https://github.com/ggml-org/llama.cpp/issues/11888 #- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, 
light: true, server: true, free_disk_space: false } - - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } - - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } - - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true } - - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true } - - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } - - { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false } + - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" } + - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" } + - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" } + - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" } + - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04" } + - { tag: "s390x", dockerfile: ".devops/s390x.Dockerfile", platforms: "linux/s390x", full: true, light: true, server: true, free_disk_space: false, runs_on: "ubuntu-22.04-s390x" } # Note: the rocm images are failing due to a compiler error and are disabled until this is fixed to allow the workflow to complete #- {tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: true } steps: @@ -54,6 +54,7 @@ jobs: fetch-depth: 0 # preserve git history, so we can determine the build number - name: Set up QEMU + if: ${{ matrix.config.tag != 's390x' }} uses: docker/setup-qemu-action@v3 with: image: tonistiigi/binfmt:qemu-v7.0.0-28 @@ -68,22 +69,19 @@ jobs: username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Determine tag name + - name: Determine source tag name + id: srctag + uses: ./.github/actions/get-tag-name + env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + + - name: Determine image tag name id: tag shell: bash run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case REPO_NAME="${{ github.event.repository.name }}" - # determine tag name postfix (build number, commit hash) - if [[ "${{ env.GITHUB_BRANCH_NAME }}" == "master" ]]; then - TAG_POSTFIX="-b${BUILD_NUMBER}" - else - SAFE_NAME=$(echo "${{ env.GITHUB_BRANCH_NAME }}" | tr '/' '-') - TAG_POSTFIX="-${SAFE_NAME}-${SHORT_HASH}" - fi # list all tags possible if [[ "${{ matrix.config.tag }}" == "cpu" ]]; then TYPE="" @@ -91,17 +89,19 @@ jobs: TYPE="-${{ matrix.config.tag }}" fi PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:" - 
FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}${TAG_POSTFIX}" - LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}${TAG_POSTFIX}" - SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}${TAG_POSTFIX}" + CACHETAGS="${PREFIX}buildcache${TYPE}" + FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}" + LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}" + SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}" + echo "cache_output_tags=$CACHETAGS" >> $GITHUB_OUTPUT echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT + echo "cache_output_tags=$CACHETAGS" # print out for debugging echo "full_output_tags=$FULLTAGS" # print out for debugging echo "light_output_tags=$LIGHTTAGS" # print out for debugging echo "server_output_tags=$SERVERTAGS" # print out for debugging env: - GITHUB_BRANCH_NAME: ${{ github.head_ref || github.ref_name }} GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}' - name: Free Disk Space (Ubuntu) @@ -134,11 +134,14 @@ jobs: target: full provenance: false # using github experimental cache - cache-from: type=gha - cache-to: type=gha,mode=max + #cache-from: type=gha + #cache-to: type=gha,mode=max # return to this if the experimental github cache is having issues #cache-to: type=local,dest=/tmp/.buildx-cache #cache-from: type=local,src=/tmp/.buildx-cache + # using registry cache (no storage limit) + cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }} + cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max - name: Build and push Light Docker image (tagged + versioned) if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }} @@ -153,11 +156,14 @@ jobs: target: light provenance: false # using github experimental cache - cache-from: type=gha - cache-to: type=gha,mode=max + #cache-from: type=gha + #cache-to: type=gha,mode=max # return to this if the experimental github cache is having issues #cache-to: type=local,dest=/tmp/.buildx-cache #cache-from: type=local,src=/tmp/.buildx-cache + # using registry cache (no storage limit) + cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }} + cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max - name: Build and push Server Docker image (tagged + versioned) if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }} @@ -172,8 +178,37 @@ jobs: target: server provenance: false # using github experimental cache - cache-from: type=gha - cache-to: type=gha,mode=max + #cache-from: type=gha + #cache-to: type=gha,mode=max # return to this if the experimental github cache is having issues #cache-to: type=local,dest=/tmp/.buildx-cache #cache-from: type=local,src=/tmp/.buildx-cache + # using registry cache (no storage limit) + cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }} + cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max + + create_tag: + name: Create and push git tag + runs-on: ubuntu-22.04 + permissions: + contents: write + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Determine source tag name + id: srctag + uses: 
./.github/actions/get-tag-name + env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + + - name: Create and push git tag + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + git tag ${{ steps.srctag.outputs.name }} || exit 0 + git push origin ${{ steps.srctag.outputs.name }} || exit 0 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f461456edf..f4eae5da11 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -150,7 +150,7 @@ jobs: - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: - key: ubuntu-cpu-cmake + key: ubuntu-cpu-cmake-${{ matrix.build }} evict-old-files: 1d - name: Dependencies diff --git a/.gitignore b/.gitignore index e1ecdb7a81..c7d0009785 100644 --- a/.gitignore +++ b/.gitignore @@ -149,6 +149,6 @@ poetry.toml /run-chat.sh .ccache/ -# Code Workspace +# IDE *.code-workspace - +.windsurf/ diff --git a/.windsurf/rules/css-architecture.md b/.windsurf/rules/css-architecture.md deleted file mode 100644 index 10a183585c..0000000000 --- a/.windsurf/rules/css-architecture.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -trigger: manual ---- - -#### Tailwind & CSS - -- We are using Tailwind v4 which uses oklch colors so we now want to refer to the CSS vars directly, without wrapping it with any color function like `hsla/hsl`, `rgba` etc. diff --git a/.windsurf/rules/sveltekit-architecture.md b/.windsurf/rules/sveltekit-architecture.md deleted file mode 100644 index c2c0aa07ce..0000000000 --- a/.windsurf/rules/sveltekit-architecture.md +++ /dev/null @@ -1,48 +0,0 @@ ---- -trigger: manual ---- - -# Coding rules - -## Svelte & SvelteKit - -### Services vs Stores Separation Pattern - -#### `lib/services/` - Pure Business Logic - -- **Purpose**: Stateless business logic and external communication -- **Contains**: - - API calls to external services (ApiService) - - Pure business logic functions (ChatService, etc.) 
-- **Rules**: - - NO Svelte runes ($state, $derived, $effect) - - NO reactive state management - - Pure functions and classes only - - Can import types but not stores - - Focus on "how" - implementation details - -#### `lib/stores/` - Reactive State Management - -- **Purpose**: Svelte-specific reactive state with runes -- **Contains**: - - Reactive state classes with $state, $derived, $effect - - Database operations (DatabaseStore) - - UI-focused state management - - Store orchestration logic -- **Rules**: - - USE Svelte runes for reactivity - - Import and use services for business logic - - NO direct database operations - - NO direct API calls (use services) - - Focus on "what" - reactive state for UI - -#### Enforcement - -- Services should be testable without Svelte -- Stores should leverage Svelte's reactivity system -- Clear separation: services handle data, stores handle state -- Services can be reused across multiple stores - -#### Misc - -- Always use `let` for $derived state variables diff --git a/.windsurf/rules/tests.md b/.windsurf/rules/tests.md deleted file mode 100644 index e388fa9544..0000000000 --- a/.windsurf/rules/tests.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -trigger: manual ---- - -# Automated Tests - -## General rules - -- NEVER include any test code in the production code - we should always have it in a separate dedicated files diff --git a/.windsurf/rules/typescript-architecture.md b/.windsurf/rules/typescript-architecture.md deleted file mode 100644 index a61ff6b98f..0000000000 --- a/.windsurf/rules/typescript-architecture.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -trigger: manual ---- - -## TypeScript - -- Add JSDocs for functions diff --git a/CMakeLists.txt b/CMakeLists.txt index 4720e1f1a2..4bf8b2789a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,7 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_ # 3rd party libs option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON) +option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" OFF) option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) # Required for relocatable CMake package diff --git a/CODEOWNERS b/CODEOWNERS index afdc11ed04..89b84ce850 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -14,6 +14,7 @@ /common/build-info.* @ggerganov /common/common.* @ggerganov /common/console.* @ggerganov +/common/http.* @angt /common/llguidance.* @ggerganov /common/log.* @ggerganov /common/sampling.* @ggerganov @@ -50,6 +51,7 @@ /ggml/src/ggml-blas/ @slaren /ggml/src/ggml-common.h @ggerganov @slaren /ggml/src/ggml-cpu/ @ggerganov @slaren +/ggml/src/ggml-cpu/spacemit/ @alex-spacemit /ggml/src/ggml-cuda/common.cuh @slaren /ggml/src/ggml-cuda/fattn* @JohannesGaessler /ggml/src/ggml-cuda/ggml-cuda.cu @slaren @@ -59,8 +61,10 @@ /ggml/src/ggml-cuda/mmvq.* @JohannesGaessler /ggml/src/ggml-impl.h @ggerganov @slaren /ggml/src/ggml-metal/ @ggerganov +/ggml/src/ggml-opencl/ @lhez @max-krasnyansky /ggml/src/ggml-opt.cpp @JohannesGaessler /ggml/src/ggml-quants.* @ggerganov +/ggml/src/ggml-rpc/ @rgerganov /ggml/src/ggml-threading.* @ggerganov @slaren /ggml/src/ggml-vulkan/ @0cc4m /ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM @@ -89,6 +93,7 @@ /tools/mtmd/ @ngxson /tools/perplexity/ @ggerganov /tools/quantize/ @ggerganov +/tools/rpc/ @rgerganov /tools/run/ @ericcurtin /tools/server/* @ngxson @ggerganov @ericcurtin # no subdir /tools/server/webui/ @allozaur @@ -103,4 +108,5 @@ /LICENSE @ggerganov /README.md @ggerganov 
/SECURITY.md @ggerganov +/build-xcframework.sh @danbev requirements*.txt @CISC diff --git a/build-xcframework.sh b/build-xcframework.sh index f813984db9..796f8016ca 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -422,6 +422,7 @@ echo "Building for iOS devices..." cmake -B build-ios-device -G Xcode \ "${COMMON_CMAKE_ARGS[@]}" \ -DCMAKE_OSX_DEPLOYMENT_TARGET=${IOS_MIN_OS_VERSION} \ + -DCMAKE_SYSTEM_NAME=iOS \ -DCMAKE_OSX_SYSROOT=iphoneos \ -DCMAKE_OSX_ARCHITECTURES="arm64" \ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphoneos \ diff --git a/ci/README-MUSA.md b/ci/README-MUSA.md index bb6ca2a3dc..c5e24c5d9e 100644 --- a/ci/README-MUSA.md +++ b/ci/README-MUSA.md @@ -21,7 +21,7 @@ docker run --privileged -it \ -v $HOME/llama.cpp/ci-cache:/ci-cache \ -v $HOME/llama.cpp/ci-results:/ci-results \ -v $PWD:/ws -w /ws \ - mthreads/musa:rc4.2.0-devel-ubuntu22.04-amd64 + mthreads/musa:rc4.3.0-devel-ubuntu22.04-amd64 ``` Inside the container, execute the following commands: diff --git a/ci/run.sh b/ci/run.sh index 68cbfdf2f5..b0af51723b 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -114,6 +114,7 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then # arm 9 and newer enables sve by default, adjust these flags depending on the cpu used CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm" fi + ## helpers # download a file if it does not exist or if it is outdated diff --git a/cmake/riscv64-spacemit-linux-gnu-gcc.cmake b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake new file mode 100644 index 0000000000..08fdbf5063 --- /dev/null +++ b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake @@ -0,0 +1,29 @@ +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) +set(CMAKE_SYSTEM_VERSION 1) + +if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)") + message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}") +else() + set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple") + if (DEFINED ENV{RISCV_ROOT_PATH}) + file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) + else() + message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") + endif() + + set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain") + set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc) + set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++) + set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip) + set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu") + set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot") +endif() + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) +set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}") +set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CXX_FLAGS}") +set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic") diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 0ae4d698f0..fe290bf8fd 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -56,6 +56,7 @@ add_library(${TARGET} STATIC common.h console.cpp console.h + http.h json-partial.cpp json-partial.h json-schema-to-grammar.cpp @@ -87,7 +88,43 @@ if (LLAMA_CURL) target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL) include_directories(${CURL_INCLUDE_DIRS}) set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES}) -endif () +endif() + +if (LLAMA_OPENSSL) + find_package(OpenSSL) + if 
(OpenSSL_FOUND)
+        include(CheckCSourceCompiles)
+        set(SAVED_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
+        set(CMAKE_REQUIRED_INCLUDES ${OPENSSL_INCLUDE_DIR})
+        check_c_source_compiles("
+            #include <openssl/opensslv.h>
+            #if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER)
+            #    if OPENSSL_VERSION_NUMBER < 0x1010107f
+            #        error bad version
+            #    endif
+            #else
+            #    if OPENSSL_VERSION_NUMBER < 0x30000000L
+            #        error bad version
+            #    endif
+            #endif
+            int main() { return 0; }
+        " OPENSSL_VERSION_SUPPORTED)
+        set(CMAKE_REQUIRED_INCLUDES ${SAVED_CMAKE_REQUIRED_INCLUDES})
+        if (OPENSSL_VERSION_SUPPORTED)
+            message(STATUS "OpenSSL found: ${OPENSSL_VERSION}")
+            target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT)
+            target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto)
+            if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
+                target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
+                find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
+                find_library(SECURITY_FRAMEWORK Security REQUIRED)
+                target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
+            endif()
+        endif()
+    else()
+        message(STATUS "OpenSSL not found, SSL support disabled")
+    endif()
+endif()

 if (LLAMA_LLGUIDANCE)
     include(ExternalProject)
diff --git a/common/arg.cpp b/common/arg.cpp
index b37e21631b..690e51c12f 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -32,11 +32,11 @@
 #include <thread>
 #include <vector>
 
-//#define LLAMA_USE_CURL
-
 #if defined(LLAMA_USE_CURL)
 #include <curl/curl.h>
 #include <curl/easy.h>
+#else
+#include "http.h"
 #endif
 
 #ifdef __linux__
@@ -52,6 +52,13 @@
 #endif
 #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
 
+// isatty
+#if defined(_WIN32)
+#include <io.h>
+#else
+#include <unistd.h>
+#endif
+
 using json = nlohmann::ordered_json;
 
 std::initializer_list<enum llama_example> mmproj_examples = {
@@ -98,6 +105,14 @@ static void write_file(const std::string & fname, const std::string & content) {
     }
 }
 
+static bool is_output_a_tty() {
+#if defined(_WIN32)
+    return _isatty(_fileno(stdout));
+#else
+    return isatty(1);
+#endif
+}
+
 common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
     this->examples = std::move(examples);
     return *this;
@@ -215,12 +230,55 @@ struct common_hf_file_res {
     std::string mmprojFile;
 };
 
-#ifdef LLAMA_USE_CURL
-
-bool common_has_curl() {
-    return true;
+static void write_etag(const std::string & path, const std::string & etag) {
+    const std::string etag_path = path + ".etag";
+    write_file(etag_path, etag);
+    LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
 }
 
+static std::string read_etag(const std::string & path) {
+    std::string none;
+    const std::string etag_path = path + ".etag";
+
+    if (std::filesystem::exists(etag_path)) {
+        std::ifstream etag_in(etag_path);
+        if (!etag_in) {
+            LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
+            return none;
+        }
+        std::string etag;
+        std::getline(etag_in, etag);
+        return etag;
+    }
+
+    // no etag file, but maybe there is an old .json
+    // remove this code later
+    const std::string metadata_path = path + ".json";
+
+    if (std::filesystem::exists(metadata_path)) {
+        std::ifstream metadata_in(metadata_path);
+        try {
+            nlohmann::json metadata_json;
+            metadata_in >> metadata_json;
+            LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
+                    metadata_json.dump().c_str());
+            if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
+                std::string etag = metadata_json.at("etag");
+                write_etag(path, etag);
+                if 
(!std::filesystem::remove(metadata_path)) { + LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str()); + } + return etag; + } + } catch (const nlohmann::json::exception & e) { + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + } + } + return none; +} + +#ifdef LLAMA_USE_CURL + // // CURL utils // @@ -371,36 +429,15 @@ static bool common_download_head(CURL * curl, static bool common_download_file_single_online(const std::string & url, const std::string & path, const std::string & bearer_token) { - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; static const int max_attempts = 3; static const int retry_delay_seconds = 2; for (int i = 0; i < max_attempts; ++i) { - nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead - std::string etag; - std::string last_modified; + std::string etag; // Check if the file already exists locally const auto file_exists = std::filesystem::exists(path); if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). - std::ifstream metadata_in(metadata_path); - if (metadata_in.good()) { - try { - metadata_in >> metadata; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), - metadata.dump().c_str()); - if (metadata.contains("etag") && metadata.at("etag").is_string()) { - etag = metadata.at("etag"); - } - if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { - last_modified = metadata.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - } - } - // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again) + etag = read_etag(path); } else { LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } @@ -438,11 +475,6 @@ static bool common_download_file_single_online(const std::string & url, headers.etag.c_str()); should_download = true; should_download_from_scratch = true; - } else if (!last_modified.empty() && last_modified != headers.last_modified) { - LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, - last_modified.c_str(), headers.last_modified.c_str()); - should_download = true; - should_download_from_scratch = true; } } @@ -473,15 +505,9 @@ static bool common_download_file_single_online(const std::string & url, } } } - - // Write the updated JSON metadata file. 
-        metadata.update({
-            { "url", url },
-            { "etag", headers.etag },
-            { "lastModified", headers.last_modified }
-        });
-        write_file(metadata_path, metadata.dump(4));
-        LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
+        if (head_request_ok) {
+            write_etag(path, headers.etag);
+        }
 
         // start the download
         LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
@@ -568,21 +594,238 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 
 #else
 
-bool common_has_curl() {
-    return false;
-}
-
-static bool common_download_file_single_online(const std::string &, const std::string &, const std::string &) {
-    LOG_ERR("error: built without CURL, cannot download model from internet\n");
-    return false;
-}
-
-std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params &) {
-    if (!url.empty()) {
-        throw std::runtime_error("error: built without CURL, cannot download model from the internet");
+static void print_progress(size_t current, size_t total) {
+    if (!is_output_a_tty()) {
+        return;
     }
-    return {};
+    if (!total) {
+        return;
+    }
+
+    size_t width = 50;
+    size_t pct   = (100 * current) / total;
+    size_t pos   = (width * current) / total;
+
+    std::cout << "["
+              << std::string(pos, '=')
+              << (pos < width ? ">" : "")
+              << std::string(width - pos, ' ')
+              << "] " << std::setw(3) << pct << "% ("
+              << current / (1024 * 1024) << " MB / "
+              << total / (1024 * 1024) << " MB)\r";
+    std::cout.flush();
+}
+
+static bool common_pull_file(httplib::Client & cli,
+                             const std::string & resolve_path,
+                             const std::string & path_tmp,
+                             bool supports_ranges,
+                             size_t existing_size,
+                             size_t & total_size) {
+    std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
+    if (!ofs.is_open()) {
+        LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
+        return false;
+    }
+
+    httplib::Headers headers;
+    if (supports_ranges && existing_size > 0) {
+        headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
+    }
+
+    std::atomic<size_t> downloaded{existing_size};
+
+    auto res = cli.Get(resolve_path, headers,
+        [&](const httplib::Response &response) {
+            if (existing_size > 0 && response.status != 206) {
+                LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", __func__, response.status);
+                return false;
+            }
+            if (existing_size == 0 && response.status != 200) {
+                LOG_WRN("%s: download received non-successful status code: %d\n", __func__, response.status);
+                return false;
+            }
+            if (total_size == 0 && response.has_header("Content-Length")) {
+                try {
+                    size_t content_length = std::stoull(response.get_header_value("Content-Length"));
+                    total_size = existing_size + content_length;
+                } catch (const std::exception &e) {
+                    LOG_WRN("%s: invalid Content-Length header: %s\n", __func__, e.what());
+                }
+            }
+            return true;
+        },
+        [&](const char *data, size_t len) {
+            ofs.write(data, len);
+            if (!ofs) {
+                LOG_ERR("%s: error writing to file: %s\n", __func__, path_tmp.c_str());
+                return false;
+            }
+            downloaded += len;
+            print_progress(downloaded, total_size);
+            return true;
+        },
+        nullptr
+    );
+
+    std::cout << "\n";
+
+    if (!res) {
+        LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? 
res->status : -1); + return false; + } + + return true; +} + +// download one single file from remote URL to local path +static bool common_download_file_single_online(const std::string & url, + const std::string & path, + const std::string & bearer_token) { + static const int max_attempts = 3; + static const int retry_delay_seconds = 2; + + auto [cli, parts] = common_http_client(url); + + httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}}; + if (!bearer_token.empty()) { + default_headers.insert({"Authorization", "Bearer " + bearer_token}); + } + cli.set_default_headers(default_headers); + + const bool file_exists = std::filesystem::exists(path); + + std::string last_etag; + if (file_exists) { + last_etag = read_etag(path); + } else { + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); + } + + for (int i = 0; i < max_attempts; ++i) { + auto head = cli.Head(parts.path); + bool head_ok = head && head->status >= 200 && head->status < 300; + if (!head_ok) { + LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1); + if (file_exists) { + LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str()); + return true; + } + } + + std::string etag; + if (head_ok && head->has_header("ETag")) { + etag = head->get_header_value("ETag"); + } + + size_t total_size = 0; + if (head_ok && head->has_header("Content-Length")) { + try { + total_size = std::stoull(head->get_header_value("Content-Length")); + } catch (const std::exception& e) { + LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what()); + } + } + + bool supports_ranges = false; + if (head_ok && head->has_header("Accept-Ranges")) { + supports_ranges = head->get_header_value("Accept-Ranges") != "none"; + } + + bool should_download_from_scratch = false; + if (!last_etag.empty() && !etag.empty() && last_etag != etag) { + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, + last_etag.c_str(), etag.c_str()); + should_download_from_scratch = true; + } + + if (file_exists) { + if (!should_download_from_scratch) { + LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); + return true; + } + LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + if (remove(path.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + return false; + } + } + + const std::string path_temporary = path + ".downloadInProgress"; + size_t existing_size = 0; + + if (std::filesystem::exists(path_temporary)) { + if (supports_ranges && !should_download_from_scratch) { + existing_size = std::filesystem::file_size(path_temporary); + } else if (remove(path_temporary.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); + return false; + } + } + + // start the download + LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", + __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); + const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); + if (!was_pull_successful) { + if (i + 1 < max_attempts) { + const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000; + LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay); + std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); + } else { + LOG_ERR("%s: download failed after %d 
attempts\n", __func__, max_attempts); + } + continue; + } + + if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + return false; + } + if (!etag.empty()) { + write_etag(path, etag); + } + break; + } + + return true; +} + +std::pair> common_remote_get_content(const std::string & url, + const common_remote_params & params) { + auto [cli, parts] = common_http_client(url); + + httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; + for (const auto & header : params.headers) { + size_t pos = header.find(':'); + if (pos != std::string::npos) { + headers.emplace(header.substr(0, pos), header.substr(pos + 1)); + } else { + headers.emplace(header, ""); + } + } + + if (params.timeout > 0) { + cli.set_read_timeout(params.timeout, 0); + cli.set_write_timeout(params.timeout, 0); + } + + std::vector buf; + auto res = cli.Get(parts.path, headers, + [&](const char *data, size_t len) { + buf.insert(buf.end(), data, data + len); + return params.max_size == 0 || + buf.size() <= static_cast(params.max_size); + }, + nullptr + ); + + if (!res) { + throw std::runtime_error("error: cannot make GET request"); + } + + return { res->status, std::move(buf) }; } #endif // LLAMA_USE_CURL diff --git a/common/arg.h b/common/arg.h index 70bea100fd..77997c4ef3 100644 --- a/common/arg.h +++ b/common/arg.h @@ -78,7 +78,6 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e // function to be used by test-arg-parser common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); -bool common_has_curl(); struct common_remote_params { std::vector headers; diff --git a/common/chat.cpp b/common/chat.cpp index ce53f89f78..e2bacdcf52 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1616,17 +1616,36 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp ); }); - auto recipient_in_role = builder.add_rule("recipient_in_role", - "\"<|start|>assistant\"? \" to=functions.\" ( " + - string_join(tool_rules_recipient_in_role, " | ") + " )" - ); - auto recipient_in_channel = builder.add_rule("recipient_in_channel", channel + " \" to=functions.\" ( " + string_join(tool_rules_recipient_in_channel, " | ") + " )" ); - builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel); + if (data.grammar_lazy) { + auto recipient_in_role = builder.add_rule("recipient_in_role", + "\"<|start|>assistant\"? \" to=functions.\" ( " + + string_join(tool_rules_recipient_in_role, " | ") + " )" + ); + + builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel); + } else { + auto not_end = builder.add_rule("not-end", + "[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]"); + auto analysis = builder.add_rule("analysis", + "\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\""); + auto commentary = builder.add_rule("commentary", + "\"<|channel|>commentary<|message|>\" ( " + not_end + " )* \"<|end|>\""); + + auto recipient_in_role = builder.add_rule("recipient_in_role", + "\" to=functions.\" ( " + string_join(tool_rules_recipient_in_role, " | ") + " )" + ); + + builder.add_rule("root", + "( " + analysis + " \"<|start|>assistant\" )? " + + "( " + commentary + " \"<|start|>assistant\" )? 
" + + "( " + recipient_in_role + " | " + recipient_in_channel + " )" + ); + } // Trigger on tool calls that appear in the commentary channel data.grammar_triggers.push_back({ diff --git a/common/common.cpp b/common/common.cpp index 6ebd934154..c1e736c44c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -51,6 +51,11 @@ #include #endif +#if defined(__linux__) +#include +#include +#endif + #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -865,8 +870,20 @@ std::string fs_get_cache_directory() { #if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); - } else { + } else if (std::getenv("HOME")) { cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } else { +#if defined(__linux__) + /* no $HOME is defined, fallback to getpwuid */ + struct passwd *pw = getpwuid(getuid()); + if ((!pw) || (!pw->pw_dir)) { + throw std::runtime_error("Failed to find $HOME directory"); + } + + cache_directory = std::string(pw->pw_dir) + std::string("/.cache/"); +#else /* defined(__linux__) */ + throw std::runtime_error("Failed to find $HOME directory"); +#endif /* defined(__linux__) */ } #elif defined(__APPLE__) cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); diff --git a/common/http.h b/common/http.h new file mode 100644 index 0000000000..8e29787dcc --- /dev/null +++ b/common/http.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +struct common_http_url { + std::string scheme; + std::string user; + std::string password; + std::string host; + std::string path; +}; + +static common_http_url common_http_parse_url(const std::string & url) { + common_http_url parts; + auto scheme_end = url.find("://"); + + if (scheme_end == std::string::npos) { + throw std::runtime_error("invalid URL: no scheme"); + } + parts.scheme = url.substr(0, scheme_end); + + if (parts.scheme != "http" && parts.scheme != "https") { + throw std::runtime_error("unsupported URL scheme: " + parts.scheme); + } + + auto rest = url.substr(scheme_end + 3); + auto at_pos = rest.find('@'); + + if (at_pos != std::string::npos) { + auto auth = rest.substr(0, at_pos); + auto colon_pos = auth.find(':'); + if (colon_pos != std::string::npos) { + parts.user = auth.substr(0, colon_pos); + parts.password = auth.substr(colon_pos + 1); + } else { + parts.user = auth; + } + rest = rest.substr(at_pos + 1); + } + + auto slash_pos = rest.find('/'); + + if (slash_pos != std::string::npos) { + parts.host = rest.substr(0, slash_pos); + parts.path = rest.substr(slash_pos); + } else { + parts.host = rest; + parts.path = "/"; + } + return parts; +} + +static std::pair common_http_client(const std::string & url) { + common_http_url parts = common_http_parse_url(url); + + if (parts.host.empty()) { + throw std::runtime_error("error: invalid URL format"); + } + + httplib::Client cli(parts.scheme + "://" + parts.host); + + if (!parts.user.empty()) { + cli.set_basic_auth(parts.user, parts.password); + } + + cli.set_follow_location(true); + + return { std::move(cli), std::move(parts) }; +} + +static std::string common_http_show_masked_url(const common_http_url & parts) { + return parts.scheme + "://" + (parts.user.empty() ? 
"" : "****:****@") + parts.host + parts.path; +} diff --git a/docs/build-riscv64-spacemit.md b/docs/build-riscv64-spacemit.md new file mode 100644 index 0000000000..eaa6532546 --- /dev/null +++ b/docs/build-riscv64-spacemit.md @@ -0,0 +1,89 @@ +> [!IMPORTANT] +> This build documentation is specific only to RISC-V SpacemiT SOCs. + +## Build llama.cpp locally (for riscv64) + +1. Prepare Toolchain For RISCV +~~~ +wget https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz +~~~ + +2. Build +Below is the build script: it requires utilizing RISC-V vector instructions for acceleration. Ensure the `GGML_CPU_RISCV64_SPACEMIT` compilation option is enabled. The currently supported optimization version is `RISCV64_SPACEMIT_IME1`, corresponding to the `RISCV64_SPACEMIT_IME_SPEC` compilation option. Compiler configurations are defined in the `riscv64-spacemit-linux-gnu-gcc.cmake` file. Please ensure you have installed the RISC-V compiler and set the environment variable via `export RISCV_ROOT_PATH={your_compiler_path}`. +```bash + +cmake -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_CPU_RISCV64_SPACEMIT=ON \ + -DLLAMA_CURL=OFF \ + -DGGML_RVV=ON \ + -DGGML_RV_ZFH=ON \ + -DGGML_RV_ZICBOP=ON \ + -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \ + -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \ + -DCMAKE_INSTALL_PREFIX=build/installed + +cmake --build build --parallel $(nproc) --config Release + +pushd build +make install +popd +``` + +## Simulation +You can use QEMU to perform emulation on non-RISC-V architectures. + +1. Download QEMU +~~~ +wget https://archive.spacemit.com/spacemit-ai/qemu/jdsk-qemu-v0.0.14.tar.gz +~~~ + +2. Run Simulation +After build your llama.cpp, you can run the executable file via QEMU for simulation, for example: +~~~ +export QEMU_ROOT_PATH={your QEMU file path} +export RISCV_ROOT_PATH_IME1={your RISC-V compiler path} + +${QEMU_ROOT_PATH}/bin/qemu-riscv64 -L ${RISCV_ROOT_PATH_IME1}/sysroot -cpu max,vlen=256,elen=64,vext_spec=v1.0 ${PWD}/build/bin/llama-cli -m ${PWD}/models/Qwen2.5-0.5B-Instruct-Q4_0.gguf -t 1 +~~~ +## Performance +#### Quantization Support For Matrix +~~~ +model name : Spacemit(R) X60 +isa : rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintpause_zihpm_zfh_zfhmin_zca_zcd_zba_zbb_zbc_zbs_zkt_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_zvkt_sscofpmf_sstc_svinval_svnapot_svpbmt +mmu : sv39 +uarch : spacemit,x60 +mvendorid : 0x710 +marchid : 0x8000000058000001 +~~~ + +Q4_0 +| Model | Size | Params | backend | threads | test | t/s | +| -----------| -------- | ------ | ------- | ------- | ---- |------| +Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | pp512|64.12 ± 0.26| +Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | tg128|10.03 ± 0.01| +Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | pp512|24.16 ± 0.02| +Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | tg128|3.83 ± 0.06| +Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | pp512|12.08 ± 0.02| +Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | tg128|2.23 ± 0.02| + +Q4_1 +| Model | Size | Params | backend | threads | test | t/s | +| -----------| -------- | ------ | ------- | ------- | ---- |------| +Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | pp512|62.07 ± 0.12| +Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | tg128|9.91 ± 0.01| +Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | pp512|22.95 ± 0.25| +Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | tg128|4.01 ± 0.15| +Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | pp512|11.55 ± 0.16| +Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | 
tg128|2.25 ± 0.04| + + +Q4_K +| Model | Size | Params | backend | threads | test | t/s | +| -----------| -------- | ------ | ------- | ------- | ---- |------| +Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | pp512|9.29 ± 0.05| +Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | tg128|5.67 ± 0.04| +Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | pp512|10.38 ± 0.10| +Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | tg128|3.17 ± 0.08| +Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | pp512|4.23 ± 0.04| +Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00| diff --git a/docs/docker.md b/docs/docker.md index 543a51f75c..bfabf2425a 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -110,7 +110,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment The defaults are: -- `MUSA_VERSION` set to `rc4.2.0` +- `MUSA_VERSION` set to `rc4.3.0` The resulting images, are essentially the same as the non-MUSA images: diff --git a/examples/eval-callback/CMakeLists.txt b/examples/eval-callback/CMakeLists.txt index 95915ed91c..c514e4317e 100644 --- a/examples/eval-callback/CMakeLists.txt +++ b/examples/eval-callback/CMakeLists.txt @@ -5,6 +5,11 @@ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) set(TEST_TARGET test-eval-callback) -add_test(NAME ${TEST_TARGET} - COMMAND llama-eval-callback --hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf --model stories260K.gguf --prompt hello --seed 42 -ngl 0) +if(NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + add_test(NAME ${TEST_TARGET} + COMMAND llama-eval-callback --hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf --model stories260K.gguf --prompt hello --seed 42 -ngl 0) +else() + add_test(NAME ${TEST_TARGET} + COMMAND llama-eval-callback --hf-repo ggml-org/models --hf-file tinyllamas/stories260K-be.gguf --model stories260K-be.gguf --prompt hello --seed 42 -ngl 0) +endif() set_property(TEST ${TEST_TARGET} PROPERTY LABELS eval-callback curl) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index b56096b45e..56420587a9 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 0) -set(GGML_VERSION_DEV "-dev") # "-dev" for development, "" for releases +set(GGML_VERSION_PATCH 4) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) @@ -26,8 +25,8 @@ if(GIT_EXE) ) endif() -# Build the version string with optional -dev suffix and dirty flag -set(GGML_VERSION "${GGML_VERSION_BASE}${GGML_VERSION_DEV}") +# Build the version string with optional dirty flag +set(GGML_VERSION "${GGML_VERSION_BASE}") if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0) set(GGML_VERSION "${GGML_VERSION}-dirty") endif() @@ -177,7 +176,7 @@ set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC") if (MINGW) - set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows version") + set(GGML_WIN_VER "0xA00" CACHE STRING "ggml: Windows version") endif() # ggml core diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 36b23dc6d0..5028a9cebf 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -237,6 +237,8 @@ #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 +// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726 +#define GGML_ROPE_TYPE_NORMAL 0 #define 
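For context: the non-zero rope types above are bit flags rather than sequential values (`GGML_ROPE_TYPE_VISION` is `GGML_ROPE_TYPE_MROPE | 16`), and ggml code tests them with a bitwise AND, which the newly added `GGML_ROPE_TYPE_NORMAL 0` cannot express. That is presumably what the TODO about converting these defines to an enum refers to. A minimal sketch of the flag-testing pattern (illustrative, not code from this patch):

```cpp
#define GGML_ROPE_TYPE_NORMAL 0
#define GGML_ROPE_TYPE_NEOX   2
#define GGML_ROPE_TYPE_MROPE  8
#define GGML_ROPE_TYPE_VISION 24 // GGML_ROPE_TYPE_MROPE | 16

// the non-zero types are detected by their flag bit ...
static bool rope_is_neox (int mode) { return (mode & GGML_ROPE_TYPE_NEOX)  != 0; }
static bool rope_is_mrope(int mode) { return (mode & GGML_ROPE_TYPE_MROPE) != 0; }

// ... while NORMAL, being 0, can only be matched by exact comparison
static bool rope_is_normal(int mode) { return mode == GGML_ROPE_TYPE_NORMAL; }
```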
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 7002cb07e0..136afec748 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -135,6 +135,10 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
     return p;
 }
 
+static const char * dl_error() {
+    return "";
+}
+
 #else
 
 using dl_handle = void;
@@ -155,6 +159,11 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
     return dlsym(handle, name);
 }
 
+static const char * dl_error() {
+    const char *rslt = dlerror();
+    return rslt != nullptr ? rslt : "";
+}
+
 #endif
 
 using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
@@ -240,7 +249,7 @@ struct ggml_backend_registry {
         dl_handle_ptr handle { dl_load_library(path) };
         if (!handle) {
             if (!silent) {
-                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str());
+                GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(path).c_str(), dl_error());
             }
             return nullptr;
         }
@@ -530,7 +539,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
             if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
                 dl_handle_ptr handle { dl_load_library(entry) };
                 if (!handle && !silent) {
-                    GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str());
+                    GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(entry.path()).c_str(), dl_error());
                 }
                 if (handle) {
                     auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
diff --git a/ggml/src/ggml-blas/CMakeLists.txt b/ggml/src/ggml-blas/CMakeLists.txt
index 76064c3fd1..60ce4b1e02 100644
--- a/ggml/src/ggml-blas/CMakeLists.txt
+++ b/ggml/src/ggml-blas/CMakeLists.txt
@@ -74,7 +74,7 @@ if (BLAS_FOUND)
 
         target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})
 
-        if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
+        if ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
             add_compile_definitions(GGML_BLAS_USE_MKL)
         endif()
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index 3699057507..42041b717a 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -439,6 +439,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ggml-cpu/arch/riscv/quants.c
             ggml-cpu/arch/riscv/repack.cpp
             )
+        if (GGML_CPU_RISCV64_SPACEMIT)
+            target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC})
+            list(APPEND GGML_CPU_SOURCES
+                ggml-cpu/spacemit/ime.cpp
+                ggml-cpu/spacemit/ime.h
+                ggml-cpu/spacemit/ime1_kernels.cpp
+                ggml-cpu/spacemit/ime_kernels.h
+                )
+        endif()
         set(MARCH_STR "rv64gc")
         if (GGML_RV_ZFH)
             string(APPEND MARCH_STR "_zfh")
@@ -504,9 +513,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.13.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.14.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "d82a8de939d9814621a5ba23907bdac1")
+        set(KLEIDIAI_ARCHIVE_MD5  "45e110675d93f99f82c23a1afcca76bc")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
@@ -583,6 +592,7 @@ 
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 373408a9c0..edfd791390 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -160,7 +160,6 @@ #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K -#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index dc1bba3a3e..19d225a483 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -75,7 +75,8 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i for (int j = 0; j < 8; j++) { const float32x4_t v = vec_mul(srcv[j], vec_splats(id)); - const int32x4_t vi = vec_signed(v); + /* Uses non-default rounding for vec_signed or vec_round */ + const int32x4_t vi = vec_signed(__builtin_s390_vfisb(v, 4, 1)); y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 1] = vec_extract(vi, 1); @@ -122,7 +123,8 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i for (int j = 0; j < 8; j++) { const float32x4_t v = vec_mul(srcv[j], vec_splats(id)); - const int32x4_t vi = vec_signed(v); + /* Uses non-default rounding for vec_signed or vec_round */ + const int32x4_t vi = vec_signed(__builtin_s390_vfisb(v, 4, 1)); y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 1] = vec_extract(vi, 1); @@ -260,6 +262,101 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const int qk = QK_MXFP4; + const int nb = n / qk; + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0.0f; + +#if defined(__VXE__) || defined(__VXE2__) + const int8x16_t v_k = vec_xl(0, kvalues_mxfp4); + const uint8x16_t v_m = vec_splats((const uint8_t)0x0F); + + float32x4_t v_acc = vec_splats(0.0f); + + #pragma GCC unroll 8 + for (; ib + 1 < nb; ib += 2) { + const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + 
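For reference alongside the VXE path added here: per 32-value block the kernel computes sum(lut[nibble] * q8), scaled by the block's shared E8M0 exponent and the Q8_0 delta. A minimal scalar sketch of that block computation, with standalone types; the LUT values mirror ggml's `kvalues_mxfp4` table and the half-scale decode `2^(e-128)` is assumed to compensate the doubled LUT entries:

```cpp
// Scalar reference for one MXFP4 x Q8_0 block pair -- a sketch, not the ggml API.
#include <cmath>
#include <cstdint>
#include <cstdio>

// e2m1 code points, pre-doubled as in ggml's kvalues_mxfp4 table
static const int8_t kvalues_mxfp4[16] = {0,1,2,3,4,6,8,12,0,-1,-2,-3,-4,-6,-8,-12};

// qs: 16 bytes = 32 nibbles; e: shared E8M0 exponent; q8/d8: the Q8_0 block
static float dot_mxfp4_q8_0_block(const uint8_t qs[16], uint8_t e,
                                  const int8_t q8[32], float d8) {
    int32_t acc = 0;
    for (int j = 0; j < 16; ++j) {
        acc += kvalues_mxfp4[qs[j] & 0x0F] * q8[j];       // low nibbles -> first half
        acc += kvalues_mxfp4[qs[j] >> 4]  * q8[j + 16];   // high nibbles -> second half
    }
    // assumed half-scale decode: 2^(e-127) / 2 offsets the doubled LUT values
    return acc * ldexpf(1.0f, (int) e - 128) * d8;
}

int main() {
    uint8_t qs[16] = {0x11, 0};  // one low nibble = 1, one high nibble = 1
    int8_t  q8[32] = {3, 0};
    q8[16] = 5;
    std::printf("%f\n", dot_mxfp4_q8_0_block(qs, 127, q8, 0.5f)); // (3 + 5) * 0.5 * 0.5
}
```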
const uint8x16_t v_x0 = vec_xl(0, x0->qs); + const uint8x16_t v_x1 = vec_xl(0, x1->qs); + + int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m); + int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4); + int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m); + int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4); + + v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l); + v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h); + v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l); + v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h); + + const int8x16_t v_y0l = vec_xl(0, y0->qs); + const int8x16_t v_y0h = vec_xl(QK8_0/2, y0->qs); + const int8x16_t v_y1l = vec_xl(0, y1->qs); + const int8x16_t v_y1h = vec_xl(QK8_0/2, y1->qs); + + const int32x4_t v_xy0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0l), v_x0h, v_y0h); + const int32x4_t v_xy1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y1l), v_x1h, v_y1h); + + const float32x4_t v_xy0f = vec_float(v_xy0); + const float32x4_t v_xy1f = vec_float(v_xy1); + + const float32x4_t v_d0 = vec_splats(GGML_E8M0_TO_FP32_HALF(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d)); + const float32x4_t v_d1 = vec_splats(GGML_E8M0_TO_FP32_HALF(x1->e) * GGML_CPU_FP16_TO_FP32(y1->d)); + + v_acc = vec_madd(v_xy0f, v_d0, v_acc); + v_acc = vec_madd(v_xy1f, v_d1, v_acc); + } + + for (; ib < nb; ++ib) { + const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + + const uint8x16_t v_x = vec_xl(0, x0->qs); + + int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m); + int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4); + + v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl); + v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh); + + const int8x16_t v_yl = vec_xl(0, y0->qs); + const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs); + + const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); + const float32x4_t v_xyf = vec_float(v_xy); + + const float32x4_t v_d = vec_splats(GGML_E8M0_TO_FP32_HALF(x0->e) * GGML_CPU_FP16_TO_FP32(y0->d)); + v_acc = vec_madd(v_xyf, v_d, v_acc); + } + + sumf = vec_hsum_f32x4(v_acc); + *s = sumf; +#else + UNUSED(x); + UNUSED(y); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -636,7 +733,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint8x16_t q3h[4]; uint8x16_t q3b[2]; int8x16_t q3bytes[4]; - int8x16_t q8bytes[4]; + int8x16_t q8bytes[8]; uint8x16_t qhbits[2]; float sum = 0; diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 81a314e4d6..3191faaa4c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -18,6 +18,10 @@ # include "kleidiai/kleidiai.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ime.h" +#endif + #if defined(_WIN32) # define WIN32_LEAN_AND_MEAN # ifndef NOMINMAX @@ -45,6 +49,12 @@ std::vector & ggml_backend_cpu_get_extra_buffer_type } #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type()); + } +#endif + #ifdef GGML_USE_CPU_KLEIDIAI if (ggml_backend_cpu_kleidiai_buffer_type()) { bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 
8694ee15d3..44691e5dfd 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -87,15 +87,38 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { return tensor->ne[dim]; } +template <typename Ret, typename Variant, typename... Args, std::size_t... Is> +constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) { + using V = std::remove_reference_t<Variant>; + return (std::is_invocable_r_v< + Ret, + std::variant_alternative_t<Is, V>, + Args...> || ...); +} + +template <typename Variant, typename Ret, typename... Args> +constexpr bool variant_any_invocable_v = + variant_any_invocable_impl<Ret, Variant, Args...>( + std::make_index_sequence< + std::variant_size_v<std::remove_reference_t<Variant>>>{}); + template <typename Ret, typename Variant, typename... Args> -static Ret variant_call(const Variant & var, Args&&... args) { - return std::visit([&](auto&& func) -> Ret { - if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) { - return func(std::forward<Args>(args)...); - } else { - throw std::runtime_error("Invalid function type in variant_call"); - } - }, var); +static inline Ret variant_call(Variant && var, Args&&... args) { + static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>, + "No alternative in Variant is invocable with the provided arguments and return type."); + + return std::visit( + [&](auto && f) -> Ret { + using F = std::decay_t<decltype(f)>; + if constexpr (std::is_invocable_r_v<Ret, F, Args...>) { + return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...); + } else { + GGML_ABORT("Invalid function type in variant_call"); + GGML_UNREACHABLE(); + } + }, + std::forward<Variant>(var) + ); } namespace ggml::cpu::kleidiai { @@ -138,7 +161,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (kernels->rhs_type == GGML_TYPE_Q4_0) { size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr); } else if (kernels->rhs_type == GGML_TYPE_F16) { - size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) + + const int64_t lhs_batch_size0 = op->src[1]->ne[2]; + const int64_t rhs_batch_size0 = op->src[0]->ne[2]; + const int64_t r = lhs_batch_size0 / rhs_batch_size0; + size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) + variant_call<size_t>(kernels->rhs_info.packed_size, n, k) + k * n * sizeof(float) + n * sizeof(float); } else { @@ -148,7 +174,6 @@ class tensor_traits : public ggml::cpu::tensor_traits { return true; } - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { if (dst->op == GGML_OP_MUL_MAT) { if (dst->src[0]->type == GGML_TYPE_Q4_0) { @@ -165,8 +190,6 @@ class tensor_traits : public ggml::cpu::tensor_traits { } bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) { - static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; - const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -175,7 +198,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); GGML_ASSERT(kernels); - bool is_gemv = src1->ne[1] == 1; + const bool is_gemv = src1->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? 
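The rewritten helper dispatches through `std::visit` and now rejects, at compile time, calls for which no variant alternative matches. A minimal self-contained model of the same dispatch pattern (the `fn_a`/`fn_b` alternatives are made up for illustration, not the kleidiai kernel-function tables):

```cpp
// Sketch of variant-based kernel dispatch, assuming C++17.
#include <cstdio>
#include <functional>
#include <type_traits>
#include <variant>

using fn_a = int (*)(int);
using fn_b = int (*)(int, int);

template <typename Ret, typename Variant, typename... Args>
static Ret variant_call(Variant && var, Args &&... args) {
    return std::visit([&](auto && f) -> Ret {
        if constexpr (std::is_invocable_r_v<Ret, decltype(f), Args...>) {
            return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
        } else {
            throw std::bad_variant_access{}; // active alternative has the wrong arity
        }
    }, std::forward<Variant>(var));
}

int main() {
    std::variant<fn_a, fn_b> v = +[](int x, int y) { return x + y; };
    std::printf("%d\n", variant_call<int>(v, 2, 3)); // dispatches to fn_b -> 5
}
```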
&kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); @@ -185,27 +208,30 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t lhs_batch_size0 = ne12; const int64_t rhs_batch_size0 = ne02; - const int64_t batch_size = rhs_batch_size0; + const int64_t batch_size = lhs_batch_size0; + GGML_ASSERT(rhs_batch_size0 > 0); + GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0); const int64_t r = lhs_batch_size0 / rhs_batch_size0; - const int64_t m = ne11 * r; - const int64_t n = ne01; - const int64_t k = ne00; + const int64_t m_group = ne11; + const int64_t m = m_group; + const int64_t n = ne01; + const int64_t k = ne00; const size_t lhs_stride = src1->nb[1]; const size_t rhs_stride = src0->nb[1]; const size_t dst_stride = dst->nb[1]; - const int64_t mr = static_cast(kernel->get_mr()); - const int64_t nr = static_cast(kernel->get_nr()); - const int64_t kr = static_cast(kernel->get_kr()); - const int64_t sr = static_cast(kernel->get_sr()); + const int64_t mr = (int64_t) kernel->get_mr(); + const int64_t nr = (int64_t) kernel->get_nr(); + const int64_t kr = (int64_t) kernel->get_kr(); + const int64_t sr = (int64_t) kernel->get_sr(); - const size_t lhs_packed_size = variant_call(lhs_info->packed_size, m, k, mr, kr, sr); - const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, n, k); - const size_t kxn_size = k * n * sizeof(float); - const size_t bias_size = n * sizeof(float); + const size_t lhs_packed_size = variant_call(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, (size_t)n, (size_t)k); + const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float); + const size_t bias_size = (size_t)n * sizeof(float); const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; GGML_ASSERT(wsize_required <= params->wsize); @@ -216,82 +242,102 @@ class tensor_traits : public ggml::cpu::tensor_traits { uint8_t * bias = rhs_kxn + kxn_size; for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - const uint8_t * lhs_batch = static_cast(src1->data) + batch_idx * m * lhs_stride; - const uint8_t * rhs_batch = static_cast(src0->data) + batch_idx * n * rhs_stride; - uint8_t * dst_batch = static_cast(dst->data) + batch_idx * m * dst_stride; + const int64_t rhs_batch_idx = batch_idx / r; + const uint8_t * rhs_batch_base = static_cast(src0->data) + rhs_batch_idx * src0->nb[2]; + uint8_t * dst_batch_base = static_cast(dst->data) + batch_idx * dst->nb[2]; - // LHS packing + // LHS packing (threaded over m, honoring mr alignment and KV groups) { const int64_t m_roundup_mr = kai_roundup(m, mr); const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth); if (ith < num_threads) { - const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr); + const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr); const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0; - const int64_t m_start = ith * num_m_per_thread0; - const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; + const int64_t m_start = ith * num_m_per_thread0; + const int64_t m_count = (ith == num_threads - 1) ? 
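The batching change above iterates over LHS batches and maps every group of `r` of them onto one shared RHS (weight) batch via `rhs_batch_idx = batch_idx / r`, which is why `lhs_batch_size0 % rhs_batch_size0 == 0` is asserted. The indexing in isolation, with made-up batch counts:

```cpp
// Sketch of the batch-broadcast indexing: r LHS batches reuse one RHS batch.
#include <cassert>
#include <cstdio>

int main() {
    const long lhs_batches = 8, rhs_batches = 2;  // ne12, ne02 (example values)
    assert(lhs_batches % rhs_batches == 0);
    const long r = lhs_batches / rhs_batches;     // group size, 4 here
    for (long b = 0; b < lhs_batches; ++b) {
        std::printf("lhs batch %ld -> rhs batch %ld\n", b, b / r);
    }
}
```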
num_m_per_threadN_1 : num_m_per_thread0; - const size_t lhs_offset = variant_call(kernels->gemm.get_lhs_offset, m_start, lhs_stride); - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, mr, kr, sr); + // Base packed offset (aligned) and per-row stride in bytes + const size_t base_packed_off = variant_call( + lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t next_block_off = variant_call( + lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr; - const void * src_ptr = static_cast(lhs_batch) + lhs_offset; - void * dst_ptr = static_cast(lhs_packed) + lhs_packed_offset; + int64_t remaining = m_count; + int64_t cur = m_start; - variant_call(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); + while (remaining > 0) { + const int64_t row_in_group = cur; + const int64_t avail = m_group - row_in_group; + const int64_t take = std::min(avail, remaining); + + const uint8_t * lhs_batch_base = static_cast(src1->data) + batch_idx * src1->nb[2]; + const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride; + const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; + void * dst_ptr = lhs_packed + dst_off; + + variant_call(lhs_info->pack_func, + (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr, + /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr); + + cur += take; + remaining -= take; + } } } - // RHS packing - if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) { - // First thread to reach this point handles RHS packing - memset(bias, 0, n * sizeof(float)); - transpose_f32kxn_f16nxk(n, k, reinterpret_cast(rhs_kxn), - reinterpret_cast(rhs_batch), rhs_stride); + // RHS packing (single thread), then synchronize + if (ith == 0) { + memset(bias, 0, (size_t)n * sizeof(float)); + transpose_f32kxn_f16nxk((size_t)n, (size_t)k, + reinterpret_cast(rhs_kxn), + reinterpret_cast(rhs_batch_base), + rhs_stride); - variant_call(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float), - rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); + variant_call(kernels->rhs_info.pack_func, + /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr, + /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)), + rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr); } ggml_barrier(params->threadpool); - first_to_arrive.clear(std::memory_order_release); - - // Perform the matmul + // Matmul (threaded over n) { - const int64_t m_to_process = m; - const int64_t m_start = 0; - - const int64_t n_step = static_cast(kernel->get_n_step()); - int64_t num_threads = KAI_MIN(n / n_step, nth); - if (num_threads <= 0) { - num_threads = 1; + const int64_t n_step = (int64_t) kernel->get_n_step(); + int64_t num_threads_n = KAI_MIN(n / n_step, nth); + if (num_threads_n <= 0) { + num_threads_n = 1; } - if (ith < num_threads) { - const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step); - const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0; + if (ith < num_threads_n) { + const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step); + const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0; const int64_t n_start = ith * num_n_per_thread0; - const int64_t n_to_process = (ith == num_threads - 1) ? 
num_n_per_threadN_1 : num_n_per_thread0; + const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0; - const size_t lhs_packed_offset = variant_call(kernel->get_lhs_offset, m_start, k); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k); - const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride); + // LHS packed base at row 0 (consistent with packing above) + const size_t lhs_packed_offset0 = variant_call( + lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k); + const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride); - const void * lhs_ptr = lhs_packed + lhs_packed_offset; + const void * lhs_ptr = lhs_packed + lhs_packed_offset0; const void * rhs_ptr = rhs_packed + rhs_packed_offset; - float * dst_ptr = reinterpret_cast(dst_batch + dst_offset); + float * dst_ptr = reinterpret_cast(dst_batch_base + dst_offset); - variant_call(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + variant_call(kernel->run_kernel, + (size_t)m, (size_t)n_to_process, (size_t)k, + lhs_ptr, rhs_ptr, + dst_ptr, dst_stride, sizeof(float), + -FLT_MAX, FLT_MAX); } } if (batch_idx != batch_size - 1) { - // This barrier is necessary when the batch size is larger than 1. While processing a batch, - // the work data buffer (params->wdata) is used as temporary storage which means that only - // a single batch can be processed at any given time. No barrier is needed for the last - // batch since GGML inserts a barrier between the execution of every operator. 
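That comment states the key invariant: `params->wdata` is a single shared scratch buffer, so no thread may start writing batch i+1 into it while any thread can still read batch i. A runnable sketch of the same discipline using C++20 `std::barrier` rather than ggml's threadpool barrier:

```cpp
// Why a per-batch rendezvous is needed when one scratch buffer is reused.
#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int nth = 4, batches = 3;
    std::vector<int> scratch(nth);            // shared work buffer, reused per batch
    std::barrier sync(nth);
    auto worker = [&](int ith) {
        for (int b = 0; b < batches; ++b) {
            scratch[ith] = b * 100 + ith;     // write this batch's intermediate data
            sync.arrive_and_wait();           // all writes for batch b are done
            int acc = 0;
            for (int v : scratch) acc += v;   // safe to read everyone's data
            if (ith == 0) std::printf("batch %d sum %d\n", b, acc);
            sync.arrive_and_wait();           // all reads done; buffer may be reused
        }
    };
    std::vector<std::thread> ts;
    for (int i = 0; i < nth; ++i) ts.emplace_back(worker, i);
    for (auto & t : ts) t.join();
}
```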
ggml_barrier(params->threadpool); } } diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp new file mode 100644 index 0000000000..54d3dece0e --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -0,0 +1,1024 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "ime.h" + +#include "ggml-backend-impl.h" +#include "ggml-common.h" +#include "ggml-cpu.h" +#include "ime_kernels.h" +#include "traits.h" + +#include +#include +#include +#include // for GGML_ASSERT +#include +#include + +// clang-format off +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#else +#include +#endif + +#if !defined(__riscv_zfh) +#error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +#error "RISCV64_SPACEMIT_IME1 not defined" +#endif + +#else + +#error "riscv not enabled in this build" + +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#define QGEMM_STRIDEN_THREAD_ALIGN 16 +#else +#define QGEMM_STRIDEN_THREAD_ALIGN 32 +#endif + +// clang-format on + +struct qnbitgemm_spacemit_ime_args { + const float * a_ptr = nullptr; + size_t lda = 0; + const std::byte * packed_quant_b_data = nullptr; + const float * quant_b_scale = nullptr; + const void * quant_b_zp = nullptr; + const float * quant_b_blksum = nullptr; + const float * bias = nullptr; + float * c_ptr = nullptr; + size_t ldc = 0; +}; + +constexpr size_t div_round_up(size_t up, size_t down) { + return (up + down - 1) / down; +} + +constexpr size_t q8_blk_size(size_t blk_len) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t); + // Currently, the strictest alignment requirement of a block is for a float. + // Ensure contiguous blocks are suitably aligned. + assert(blk_size % alignof(float) == 0); + return blk_size; +} + +namespace ggml::cpu::riscv64_spacemit { + +const int num_ai_cores = std::thread::hardware_concurrency() / 2; + +} // namespace ggml::cpu::riscv64_spacemit + +static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len, + const size_t gemm_k, + const qnbitgemm_spacemit_ime_args * gemm_args, + void * const per_gemm_ws, + const size_t m_start, + const size_t m_count, + const size_t n_start, + const size_t n_count) { + constexpr size_t scale_stride = sizeof(uint16_t); + constexpr size_t blk_bitwidth = 4; + + const size_t k_blks = div_round_up(gemm_k, blk_len); + + const size_t lda = k_blks * q8_blk_size(blk_len); + const size_t ldc = gemm_args->ldc; + const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8); + const std::byte * quant_a_ptr = static_cast(per_gemm_ws) + m_start * lda; + + const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0; + const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); + const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride; + + float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start; + + size_t count_n = 0; + const size_t compute_block_count_n = m_count == 1 ? 
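The helpers above pin down the quantized-activation layout: each `BlkLen`-wide block stores one float scale followed by `BlkLen` int8 quants, so a row of K values occupies `ceil(K/BlkLen) * (4 + BlkLen)` bytes. A sketch of that stride arithmetic with purely illustrative shapes:

```cpp
// Workspace math mirroring div_round_up / q8_blk_size above.
#include <cstddef>
#include <cstdio>

constexpr size_t div_round_up(size_t up, size_t down) { return (up + down - 1) / down; }
constexpr size_t q8_blk_size(size_t blk_len) { return sizeof(float) + blk_len * sizeof(int8_t); }

int main() {
    const size_t m = 7, k = 4096, blk_len = 32;           // example shapes
    const size_t k_blks = div_round_up(k, blk_len);       // 128 blocks per row
    const size_t lda    = k_blks * q8_blk_size(blk_len);  // bytes per quantized row
    std::printf("per-row %zu B, per-gemm %zu B\n", lda, m * lda);
}
```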
n_count : 16; + for (size_t n = 0; n < n_count; n += count_n) { + count_n = std::min(n_count - n, compute_block_count_n); + + const std::byte * a_row = quant_a_ptr; + const std::byte * b_col = packed_quant_b_data + n * packed_b_stride; + const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr; + float * c_blk = c_ptr + n; + + int32_t rows_remaining = m_count; + + while (rows_remaining > 0) { + const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4( + blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr, + scale_stride); + + c_blk += rows_handled * ldc; + a_row += rows_handled * lda; + + rows_remaining -= rows_handled; + } + } +} + +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), + "wrong block_with_zp<4,16> size/padding"); +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... 
[b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, + const uint8_t * GGML_RESTRICT q, + uint8_t * GGML_RESTRICT d, + uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = 
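`get_scale_min_k4` above unpacks the eight 6-bit (scale, min) pairs that Q4_K packs into its 12-byte `scales` array; the repack then folds each pair into a per-sub-block Q4_1 delta and (negated) min. The same bit layout in a standalone scalar sketch:

```cpp
// Q4_K stores 8 x (6-bit scale, 6-bit min) in 12 bytes; pair j decodes as below.
#include <cstdint>
#include <cstdio>

static void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
    if (j < 4) {                 // first four pairs: low 6 bits of bytes 0..7
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {                     // last four: nibbles of bytes 8..11 + spare high bits
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >> 4)  | ((q[j] >> 6) << 4);
    }
}

int main() {
    uint8_t scales[12] = {0};
    scales[0] = 63;              // sub-block 0: scale bits
    scales[4] = 21;              // sub-block 0: min bits
    uint8_t sc, mn;
    for (int j = 0; j < 8; ++j) {
        get_scale_min_k4(j, scales, &sc, &mn);
        std::printf("j=%d sc=%u min=%u\n", j, sc, mn);  // effective = d*sc, dmin*mn
    }
}
```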
GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +namespace ggml::cpu::riscv64_spacemit { + +template +int repack(struct ggml_tensor *, const void *, size_t); + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; +}; + +template class tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; + size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->type == GGML_TYPE_Q4_0 || // + op->src[0]->type == GGML_TYPE_Q4_1 || // + op->src[0]->type == GGML_TYPE_Q4_K) { + forward_mul_mat_q4(params, op); + return true; + } + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + int ith = params->ith; + int nth = params->nth; + + [[maybe_unused]] const enum ggml_type type = src0->type; + + void * w_data = (void *) src0->data; + const float * feature = (const float *) src1->data; + float * output = (float *) dst->data; + + const size_t batch_feature = ne12 * ne13; + [[maybe_unused]] const size_t batch_weight = ne02 * ne03; + const size_t gemm_m = ne11; + const size_t gemm_k = ne10; + const size_t gemm_n = ne01; + + GGML_ASSERT(batch_weight == 1); + + const size_t block_count_k = div_round_up(gemm_k, QK4_0); + const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0); + const size_t per_gemm_workspace_stride = + div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t); + const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride; + const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1; + + if (ith == 0 && params->wsize < desired_wsize) { + throw std::runtime_error("wsize less than desired_wsize"); + } + + std::vector 
qnbitgemm_args(batch_feature); + + for (size_t i = 0; i < batch_feature; i++) { + qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i; + qnbitgemm_args[i].lda = gemm_k; + qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data; + qnbitgemm_args[i].quant_b_scale = nullptr; + + if constexpr (std::is_same_v) { + qnbitgemm_args[i].quant_b_zp = nullptr; + } else { + qnbitgemm_args[i].quant_b_zp = w_data; + } + + qnbitgemm_args[i].bias = nullptr; + qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i; + qnbitgemm_args[i].ldc = gemm_n; + } + + const uintptr_t ws_ptr = reinterpret_cast(params->wdata); + void * ws = reinterpret_cast((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1))); + const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0); + + { + constexpr size_t block_size_m = 4; + size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m); + int32_t task_count = batch_feature * per_gemm_block_count_m; + int32_t task_per_thread = (task_count + nth - 1) / nth; + int32_t start = ith * task_per_thread; + int32_t end = std::min((ith + 1) * task_per_thread, task_count); + for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { + int32_t gemm_idx = compute_idx / block_size_m; + int32_t m_idx = compute_idx % block_size_m * block_size_m; + const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; + int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); + + if (rows_tobe_handled == block_size_m) { + const float * a_row_ptr = data.a_ptr + m_idx * data.lda; + std::byte * quant_a_row_ptr = + static_cast(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; + sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); + } else { + while (rows_tobe_handled) { + const float * a_row_ptr = data.a_ptr + m_idx * data.lda; + std::byte * quant_a_row_ptr = static_cast(ws) + + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; + sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); + rows_tobe_handled -= 1; + m_idx += 1; + } + } + } + } + + ggml_barrier(params->threadpool); + + if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) { + return; + } + nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores }); + + size_t threads_per_gemm = nth / batch_feature; + constexpr size_t gemm_m_stride = 128; + size_t nc = gemm_n; + const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride); + const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm); + if (max_nc < nc) { + nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN); + } + const size_t gemm_n_stride = nc; + const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride); + const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride); + threads_per_gemm = thread_count_m * thread_count_n; + + { + int task_count = batch_feature * threads_per_gemm; + int task_per_thread = (task_count + nth - 1) / nth; + int start = ith * task_per_thread; + int end = std::min((ith + 1) * task_per_thread, task_count); + for (int compute_idx = start; compute_idx < end; compute_idx++) { + const auto gemm_i = compute_idx / threads_per_gemm; + const auto blk_i = compute_idx % threads_per_gemm; + const auto * data = &qnbitgemm_args[gemm_i]; + + const auto tid_n = blk_i / thread_count_m; + const auto tid_m = blk_i % thread_count_m; + + const size_t m_start = tid_m * gemm_m_stride; + const size_t m_count 
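The quantization stage above slices work as flat task IDs: `task_count` tasks are ceil-divided across `nth` threads, each thread taking a contiguous `[start, end)` range. The partition in isolation, with example counts:

```cpp
// Static task slicing as used above: thread ith gets tasks [ith*per, min((ith+1)*per, n)).
#include <algorithm>
#include <cstdio>

int main() {
    const int task_count = 10, nth = 4;
    const int per = (task_count + nth - 1) / nth;        // ceil(10/4) = 3
    for (int ith = 0; ith < nth; ++ith) {
        const int start = ith * per;
        const int end   = std::min(start + per, task_count);
        std::printf("thread %d: tasks [%d, %d)\n", ith, start, end);
    }
}
```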
= std::min(gemm_m - m_start, (size_t) gemm_m_stride); + + const size_t n_start = tid_n * gemm_n_stride; + const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride); + + void * per_gemm_ws = reinterpret_cast(ws) + gemm_i * per_gemm_workspace_stride; + + sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count); + } + } + } + + int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), + (int) NB_COLS, (int) INTER_SIZE); + return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); + } +}; + +class tensor_traits_common : public tensor_traits_base { + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + size = 0; + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_NORM: + forward_norm_f32(params, op); + return true; + case GGML_OP_RMS_NORM: + forward_rms_norm_f32(params, op); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (float *) src0->data; + auto * output = (float *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto * p_input = const_cast(input + offset); + + auto * p_output = output + offset; + auto * p_temp_output = p_output; + auto * p_gamma_data = (const float *) nullptr; + auto * p_beta_data = (const float *) nullptr; + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean 
/= hidden_size; + + vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), + __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if (p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } + } + + void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (float *) src0->data; + auto * output = (float *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto * p_input = const_cast(input + offset); + auto * p_output = output + offset; + auto * p_temp_output = p_output; + auto * p_gamma_data = (const float *) nullptr; + auto * p_beta_data = (const float *) nullptr; + + size_t gvl = 
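The RVV loop that follows accumulates sum(x^2) with `vfmacc`/`vfredusum` and rescales in a second pass; mathematically it is y = x / sqrt(mean(x^2) + eps), times gamma when a weight is present. A scalar reference of that math (the formula only, not the intrinsics path):

```cpp
// Scalar RMS-norm reference for one row of hidden_size values.
#include <cmath>
#include <cstdio>
#include <vector>

static void rms_norm(const float * x, float * y, int n, float eps, const float * gamma) {
    float sum_sq = 0.0f;
    for (int i = 0; i < n; ++i) sum_sq += x[i] * x[i];
    const float inv_rms = 1.0f / std::sqrt(sum_sq / n + eps);   // 1 / sqrt(mean + eps)
    for (int i = 0; i < n; ++i) y[i] = x[i] * inv_rms * (gamma ? gamma[i] : 1.0f);
}

int main() {
    std::vector<float> x = {1, 2, 3, 4}, y(4);
    rms_norm(x.data(), y.data(), 4, 1e-6f, nullptr);
    for (float v : y) std::printf("%f\n", v);
}
```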
__riscv_vsetvlmax_e32m4(); + // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + // float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + + vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), + __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if (p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } + } + + int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } +}; + +static const tensor_traits q4_0_16x8_q8_0; +static const tensor_traits q4_1_16x8_q8_0; +static const tensor_traits q4_k_16x8_q8_0; +static const tensor_traits_common rvv_impl; + +} // namespace ggml::cpu::riscv64_spacemit + +static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) { + if (cur->type == GGML_TYPE_Q4_0) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_1) { + if (cur->ne[1] % 16 == 0) { + return 
&ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_K) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_F32) { + return &ggml::cpu::riscv64_spacemit::rvv_impl; + } + + return nullptr; +} + +static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, + struct ggml_tensor * tensor) { + tensor->extra = + (void *) const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); + + GGML_UNUSED(buffer); + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, + struct ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base *) tensor->extra; + if (tensor_traits) { + auto OK = tensor_traits->repack(tensor, data, size); + GGML_ASSERT(OK == 0); + } + + GGML_UNUSED(buffer); +} + +static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_RISCV64_SPACEMIT"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; +} + +static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 64; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, + const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) { + return 0; + } + } + + size_t nbytes; + const size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } + } else { + nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + if (tensor->type == GGML_TYPE_Q4_K) { + GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); + nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + } + } else { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } + } + } + + GGML_UNUSED(buft); + return nbytes; +} + +namespace ggml::cpu::riscv64_spacemit { + +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + 
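The buffer type in this file is deliberately thin: it allocates through the stock CPU buffer type and overrides only the `init_tensor`/`set_tensor` hooks, so weights are repacked exactly once on upload. The general shape of that hook-override pattern, reduced to a sketch (`buffer_iface` here is a hypothetical stand-in, not ggml's real backend struct):

```cpp
// Wrap a base allocator and swap individual function-pointer hooks.
#include <cstdio>

struct buffer_iface {
    void (*init_tensor)(const char * name);
    void (*set_tensor)(const char * name);
};

static void default_init(const char * n) { std::printf("default init %s\n", n); }
static void default_set(const char * n)  { std::printf("default set %s\n", n); }
static void repack_set(const char * n)   { std::printf("repack on upload: %s\n", n); }

int main() {
    buffer_iface buf = {default_init, default_set};  // "allocate" via the base type
    buf.set_tensor = repack_set;                     // override only the upload hook
    buf.init_tensor("t0 (unchanged path)");
    buf.set_tensor("weights (repacked path)");
}
```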
} + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + if (op->src[0]->type == GGML_TYPE_F32) { + return true; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl); + default: + // GGML_ABORT("fatal error"); + break; + } + + return nullptr; + } +}; + +} // namespace ggml::cpu::riscv64_spacemit + +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + /* .iface = */ + { + /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, + /* .get_alloc_size = */ ggml_backend_cpu_riscv64_spacemit_nbytes, + /* .is_host = */ nullptr, + }, + /* .device = */ + ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ + new ggml::cpu::riscv64_spacemit::extra_buffer_type(), + }; + + return &ggml_backend_cpu_buffer_type_riscv64_spacemit; +} diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h new file mode 100644 index 0000000000..800d91acda --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -0,0 +1,13 @@ +#pragma once + +#include "ggml-alloc.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp new file mode 100644 index 0000000000..cbbb6cd916 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -0,0 +1,3196 @@ +#include "ggml.h" +#include "ime_kernels.h" + +#include +#include + +// clang-format off +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif +// clang-format on +namespace sqnbitgemm_spacemit_ime { + +#define QUANTIZEM4ROW_KERNEL \ + "vmv.s.x v16, zero \n\t" \ + "vfabs.v v8, v0 \n\t" \ + "vfredmax.vs v16, v8, v16 \n\t" \ + "vfmv.f.s f10, v16 \n\t" \ + "fmul.s f10, f10, %[RMAXREC] \n\t" \ + "fsw f10, (a1) \n\t" \ + "fdiv.s f11, %[FONE], f10 \n\t" \ + "vfmul.vf v16, v0, f11 \n\t" \ + "vfcvt.x.f.v v16, v16 \n\t" \ + "vsetvli t0, zero, e16, mf2 \n\t" \ + "vnclip.wx v16, v16, zero \n\t" \ + "vnclip.wx v17, v17, zero \n\t" \ + "vnclip.wx v18, v18, zero \n\t" \ + "vnclip.wx v19, v19, zero \n\t" \ + "vnclip.wx v20, v20, zero \n\t" \ + "vnclip.wx v21, v21, zero \n\t" \ + "vnclip.wx v22, v22, zero \n\t" \ + "vnclip.wx v23, v23, zero \n\t" \ + "vsetvli t0, zero, e8, mf4 \n\t" \ + "vnclip.wx v24, v16, zero \n\t" \ + "vnclip.wx v25, v17, zero \n\t" \ + "vnclip.wx v26, v18, zero \n\t" \ + "vnclip.wx v27, v19, zero \n\t" \ + "vnclip.wx v28, v20, zero \n\t" \ + "vnclip.wx v29, v21, zero \n\t" \ + "vnclip.wx v30, v22, zero \n\t" \ + "vnclip.wx v31, v23, zero \n\t" + +#define QUANTIZEM4ROW_STORE \ + "addi t1, 
%[BlkLen], 0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v24, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v25, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v26, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v27, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v28, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v29, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v30, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v31, (s1) \n\t" + +namespace ime1 { +void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { + constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); + const float fone = 1.0f; + + if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + std::byte * DST = QuantA + row_index * sizeof(float); + + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL + + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" + + QUANTIZEM4ROW_STORE + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), + [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); + } + } else if (BlkLen == 128) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + std::byte * DST = QuantA + row_index * sizeof(float); + + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "li t6, 32 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "addi t2, t2, -128 \n\t" + + "QUANTIZE%=: \n\t" + "add s1, a1, %[OFFSET] \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vfmax.vv v16, v24, v16 \n\t" + 
"vfredmax.vs v24, v16, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e64, m4 \n\t" + "vsse64.v v16, (s1), t6 \n\t" + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "sub t2, t2, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "sub t2, t2, t2 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "jal x0, QUANTIZE%= \n\t" + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), + [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); + } + } else if (BlkLen == 256) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + std::byte * DST = QuantA + row_index * sizeof(float); + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "li t6, 32 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], -768 \n\t" + "addi t2, t2, -256 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v24, v24, v16 \n\t" + "vfmax.vv v8, v8, v24 \n\t" + "vfredmax.vs v24, v8, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + + "QUANTIZE%=: \n\t" + "add s1, a1, %[OFFSET] \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfmul.vf v8, v8, f11 \n\t" + "vfmul.vf v16, v16, f11 \n\t" + "vfmul.vf v24, v24, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vfcvt.x.f.v v8, v8 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vnclip.wx v8, v16, zero \n\t" + "vnclip.wx v12, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vsetvli t0, zero, e64, m8 \n\t" + "vsse64.v v0, (s1), t6 \n\t" + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, 
v0, v0 \n\t"
+                "vxor.vv v8, v8, v8 \n\t"
+                "vxor.vv v16, v16, v16 \n\t"
+                "vxor.vv v24, v24, v24 \n\t"
+                "addi t1, t2, 0 \n\t"
+                "vsetvli t0, t1, e32, m8 \n\t"
+                "sub t1, t1, t0 \n\t"
+                "vle32.v v0, (%[SRC]) \n\t"
+                "addi %[SRC], %[SRC], 256 \n\t"
+                "vsetvli t0, t1, e32, m8 \n\t"
+                "sub t1, t1, t0 \n\t"
+                "vle32.v v8, (%[SRC]) \n\t"
+                "addi %[SRC], %[SRC], 256 \n\t"
+                "vsetvli t0, t1, e32, m8 \n\t"
+                "sub t1, t1, t0 \n\t"
+                "vle32.v v16, (%[SRC]) \n\t"
+                "addi %[SRC], %[SRC], 256 \n\t"
+                "vsetvli t0, t1, e32, m8 \n\t"
+                "vle32.v v24, (%[SRC]) \n\t"
+                "addi %[SRC], %[SRC], -768 \n\t"
+                "vsetvli t0, zero, e32, m8 \n\t"
+                "vfabs.v v0, v0 \n\t"
+                "vfabs.v v8, v8 \n\t"
+                "vfabs.v v16, v16 \n\t"
+                "vfabs.v v24, v24 \n\t"
+                "vfmax.vv v8, v0, v8 \n\t"
+                "vfmax.vv v24, v16, v24 \n\t"
+                "vfmax.vv v8, v8, v24 \n\t"
+                "vfredmax.vs v24, v8, v24 \n\t"
+                "vfmv.f.s f10, v24 \n\t"
+                "add s1, a1, %[OFFSET] \n\t"
+                "fmul.s f10, f10, %[RMAXREC] \n\t"
+                "fsw f10, (a1) \n\t"
+                "fdiv.s f11, %[FONE], f10 \n\t"
+                "vsetvli t0, zero, e64, m8 \n\t"
+                "vxor.vv v0, v0, v0 \n\t"
+                "vsse64.v v0, (s1), t6 \n\t"
+
+                "TAIL_LOOP%=: \n\t"
+                "vsetvli t0, zero, e32, m4 \n\t"
+                "vxor.vv v0, v0, v0 \n\t"
+                "vsetvli t0, t2, e32, m1 \n\t"
+                "sub t2, t2, t0 \n\t"
+                "vle32.v v0, (%[SRC]) \n\t"
+                "addi %[SRC], %[SRC], 32 \n\t"
+                "vfmul.vf v1, v0, f11 \n\t"
+                "vfcvt.x.f.v v2, v1 \n\t"
+                "vsetvli t0, zero, e16, mf2 \n\t"
+                "vnclip.wx v3, v2, zero \n\t"
+                "vsetvli t0, zero, e8, mf4 \n\t"
+                "vnclip.wx v3, v3, zero \n\t"
+                "vse8.v v3, (s1) \n\t"
+                "addi s1, s1, 32 \n\t"
+                "bnez t2, TAIL_LOOP%= \n\t"
+
+                "QUIT%=: \n\t"
+                : [SRC] "+r"(SRC)
+                : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+                  [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+                : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
+        }
+    }
+}
+
+void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
+    const float * SRC = A;
+    std::byte * DST = QuantA;
+    constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
+    const float fone = 1.0f;
+    std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
+    size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
+
+    if (CountK <= BlkLen) {
+        float max_abs_A = 0.0f;
+        for (size_t k = 0; k < CountK; k++) {
+            max_abs_A = std::max(max_abs_A, fabsf(A[k]));
+        }
+        float scale_A = max_abs_A * range_max_reciprocal;
+
+        ((float *) QuantA)[0] = scale_A;
+
+        auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
+
+        for (size_t k = 0; k < CountK; k++) {
+            QuantAData_offset[k] =
+                (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
+                                    (float) std::numeric_limits<int8_t>::max());
+        }
+        for (size_t k = CountK; k < BlkLen; k++) {
+            QuantAData_offset[k] = 0;
+        }
+
+        return;
+    }
+
+    // Pre-zero the tail padding of the last block; the BlkLen == 32/64/128 paths
+    // below always store whole blocks and zero the tail themselves.
+    if (BlkLen != 32 && BlkLen != 64 && BlkLen != 128) {
+        __asm__ volatile(
+            "vsetvli t0, zero, e8, m8 \n\t"
+            "vxor.vv v24, v24, v24 \n\t"
+            "LOOP%=: \n\t"
+            "vsetvli t0, %[CNT], e8, m8 \n\t"
+            "vse8.v v24, (%[DST]) \n\t"
+            "addi %[DST], %[DST], 128 \n\t"
+            "sub %[CNT], %[CNT], t0 \n\t"
+            "bnez %[CNT], LOOP%= \n\t"
+            : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
+            :
+            : "cc", "t0");
+    }
+    if (BlkLen == 16) {
+        float buffer[64] = { 0.0f };
+        __asm__ volatile(
+            "addi t3, zero, 16*8 \n\t"
+            "addi t2, zero, 16 \n\t"
+            "blt %[K], t3, LOOP_K%= \n\t"
+            "blt %[K], t2, TAIL%= \n\t"
+            "LOOP_MAIN%=: \n\t"
+            "vsetvli t1, zero, e32, m2 \n\t"
+            "addi %[K], %[K], -128 \n\t"
+            "vle32.v v0, (%[SRC]) \n\t"
+            "addi 
%[SRC], %[SRC], 64 \n\t" + "vle32.v v2, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v4, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v6, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v10, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v12, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v14, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "addi a1, %[BUFFER], 0 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v18, v2 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v22, v6 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v26, v10 \n\t" + "vfabs.v v28, v12 \n\t" + "vfabs.v v30, v14 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v18, v18, v19 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v22, v22, v23 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v26, v26, v27 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + "vfmax.vv v30, v30, v31 \n\t" + "vse32.v v16, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v18, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v20, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v22, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v24, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v26, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v28, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v30, (a1) \n\t" + "addi a1, %[BUFFER], 0 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f11, f3, f7 \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fsw f11, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f12, f3, f7 \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fsw f12, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f13, f3, f7 \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f13, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + 
"flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f14, f3, f7 \n\t" + "fmul.s f14, f14, %[RMAXREC] \n\t" + "fsw f14, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f14, %[FONE], f14 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f15, f3, f7 \n\t" + "fmul.s f15, f15, %[RMAXREC] \n\t" + "fsw f15, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f15, %[FONE], f15 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f16, f3, f7 \n\t" + "fmul.s f16, f16, %[RMAXREC] \n\t" + "fsw f16, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f16, %[FONE], f16 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f17, f3, f7 \n\t" + "fmul.s f17, f17, %[RMAXREC] \n\t" + "fsw f17, (%[DST]) \n\t" + "addi %[DST], %[DST], -136 \n\t" + "fdiv.s f17, %[FONE], f17 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v18, v2, f11 \n\t" + "vfmul.vf v20, v4, f12 \n\t" + "vfmul.vf v22, v6, f13 \n\t" + "vfmul.vf v24, v8, f14 \n\t" + "vfmul.vf v26, v10, f15 \n\t" + "vfmul.vf v28, v12, f16 \n\t" + "vfmul.vf v30, v14, f17 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v18, v18 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v22, v22 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v26, v26 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vfcvt.x.f.v v30, v30 \n\t" + "vsetvli t0, zero, e16, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v18, v18, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v22, v22, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v26, v26, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vnclip.wx v30, v30, zero \n\t" + "vsetvli t0, t1, e8, mf2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v18, v18, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v22, v22, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v26, v26, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vnclip.wx v30, v30, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v18, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v20, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v22, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v24, (%[DST]) 
\n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v26, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v28, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v30, (%[DST]) \n\t" + "addi %[DST], %[DST], 16 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vse32.v v16, (%[BUFFER]) \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, t1, e8, mf2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 16 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m2 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer) + : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", + "f13", "f14", "f15", "f16", "f17"); + } else if (BlkLen == 32) { + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" + + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" + + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 
\n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); + } else if (BlkLen == 64) { + __asm__ volatile( + "addi t3, zero, 64*2 \n\t" + "addi t2, zero, 64 \n\t" + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 256 \n\t" + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 68 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v8, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v16, v16, v20 \n\t" + "vfmax.vv v24, v24, v28 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v25 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + 
"vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vsetvli t0, t1, e8, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 132 \n\t" + "vse8.v v24, (s2) \n\t" + "addi s2, s2, 132 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 256 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v16, v16, v20 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); + } else if (BlkLen == 128) { + __asm__ volatile( + "addi t2, zero, 128 \n\t" + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 256 \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v8, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "sub %[K], %[K], t2 \n\t" + "QUANT%=: \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vfmax.vv v24, v16, v24 \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "vfmax.vv v28, v24, v28 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v30, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v30, v30, v31 \n\t" + "vfredmax.vs v31, v30, v31 \n\t" + "vfmv.f.s f10, v31 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vsetvli t0, %[K], e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "sub %[K], %[K], t0 \n\t" + "vsetvli t0, %[K], e32, m8 \n\t" + "vle32.v v8, (a2) \n\t" + "sub %[K], %[K], t0 \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "jal x0, QUANT%= \n\t" + "END%=: \n\t" + + : [DST] "+r"(DST), [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC) + : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); + } else { + float buffer[8] = { 0.0f }; + size_t cnt = BlkLen / 256; + + __asm__ 
volatile( + "slli t3, %[BLK], 2 \n\t" + "blt %[K], %[BLK], LOOP_TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "addi t6, %[CNT], 0 \n\t" + "LOOP_CMP%=: \n\t" + "addi t6, t6, -1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v16, v16, v24 \n\t" + "vfmax.vv v0, v0, v16 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v0, v0, v4 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v0, v0, v2 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v0, v0, v1 \n\t" + "vle32.v v30, (%[BUFFER]) \n\t" + "vfmax.vv v31, v30, v0 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "bnez t6, LOOP_CMP%= \n\t" + "sub %[SRC], %[SRC], t3 \n\t" + "addi t6, %[CNT], 0 \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "addi t6, %[CNT], 0 \n\t" + "LOOP_QUANT%=: \n\t" + "addi t6, t6, -1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfmul.vf v8, v8, f11 \n\t" + "vfmul.vf v16, v16, f11 \n\t" + "vfmul.vf v24, v24, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vfcvt.x.f.v v8, v8 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vnclip.wx v8, v16, zero \n\t" + "vnclip.wx v12, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vse8.v v0, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "vse8.v v4, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "bnez t6, LOOP_QUANT%= \n\t" + "sub %[K], %[K], %[BLK] \n\t" + "bge %[K], %[BLK], LOOP_MAIN%= \n\t" + "blez %[K], END%= \n\t" + "LOOP_TAIL%=: \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "addi t6, %[K], 0 \n\t" + "addi s1, %[SRC], 0 \n\t" + "TAIL_CMP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t0, t6, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "sub t6, t6, t0 \n\t" + "vfabs.v v0, v0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v0, v0, v4 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v0, v0, v2 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v0, v0, v1 \n\t" + "vle32.v v30, (%[BUFFER]) \n\t" + "vfmax.vv v31, v30, v0 \n\t" + "vse32.v v31, (%[BUFFER]) 
\n\t" + "bnez t6, TAIL_CMP%= \n\t" + "addi t6, %[K], 0 \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "addi t6, %[K], 0 \n\t" + "TAIL_QUANT%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t1, t6, e32, m8 \n\t" + "vle32.v v0, (s1) \n\t" + "addi s1, s1, 256 \n\t" + "sub t6, t6, t1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vsetvli t0, t1, e8, m2 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vse8.v v0, (%[DST]) \n\t" + "addi %[DST], %[DST], 64 \n\t" + "bnez t6, TAIL_QUANT%= \n\t" + "END%=: \n\t" + : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer), + [CNT] "r"(cnt) + : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); + } +} + +} // namespace ime1 + +namespace { +#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \ + "vmadot v16, v14, v0 \n\t" \ + "vmadot v18, v14, v1 \n\t" \ + "vmadot v20, v14, v2 \n\t" \ + "vmadot v22, v14, v3 \n\t" \ + "vmadot v16, v15, v4 \n\t" \ + "vmadot v18, v15, v5 \n\t" \ + "vmadot v20, v15, v6 \n\t" \ + "vmadot v22, v15, v7 \n\t" + +#define SQ4BIT_KERNEL_ACC_1X4X4 \ + "vfcvt.f.x.v v16, v16 \n\t" \ + "vfcvt.f.x.v v18, v18 \n\t" \ + "vfcvt.f.x.v v20, v20 \n\t" \ + "vfcvt.f.x.v v22, v22 \n\t" \ + "addi s2, s1, 16 \n\t" \ + "addi s3, s1, 32 \n\t" \ + "addi s4, s1, 48 \n\t" \ + "addi s6, s5, 12 \n\t" \ + "vfmacc.vv v28, v16, v24 \n\t" \ + "vfmacc.vv v29, v18, v25 \n\t" \ + "vfmacc.vv v30, v20, v26 \n\t" \ + "vfmacc.vv v31, v22, v27 \n\t" + +#define SQ4BIT_KERNEL_ACC_F16_1X4X4 \ + "vfcvt.f.x.v v16, v16 \n\t" \ + "vfcvt.f.x.v v18, v18 \n\t" \ + "vfcvt.f.x.v v20, v20 \n\t" \ + "vfcvt.f.x.v v22, v22 \n\t" \ + "addi s2, s1, 8 \n\t" \ + "addi s3, s1, 16 \n\t" \ + "addi s4, s1, 24 \n\t" \ + "addi s6, s5, 12 \n\t" \ + "vfmacc.vv v28, v16, v24 \n\t" \ + "vfmacc.vv v29, v18, v25 \n\t" \ + "vfmacc.vv v30, v20, v26 \n\t" \ + "vfmacc.vv v31, v22, v27 \n\t" + +#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \ + "vle8.v v4, (s1) \n\t" \ + "addi s1, s1, 128 \n\t" \ + "vle8.v v5, (s2) \n\t" \ + "addi s2, s2, 128 \n\t" \ + "vle8.v v6, (s3) \n\t" \ + "addi s3, s3, 128 \n\t" \ + "vle8.v v7, (s4) \n\t" \ + "addi s4, s4, 128 \n\t" \ + "vsetvli t0, zero, e8, mf4 \n\t" \ + "vle8.v v14, (s5) \n\t" \ + "addi s5, s5, 16 \n\t" \ + "vle8.v v15, (s6) \n\t" \ + "addi s6, s6, 16 \n\t" \ + "addi t5, t5, -1 \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vand.vi v0, v4, 15 \n\t" \ + "vand.vi v1, v5, 15 \n\t" \ + "vand.vi v2, v6, 15 \n\t" \ + "vand.vi v3, v7, 15 \n\t" \ + "vsrl.vi v4, v4, 4 \n\t" \ + "vsrl.vi v5, v5, 4 \n\t" \ + "vsrl.vi v6, v6, 4 \n\t" \ + "vsrl.vi v7, v7, 4 \n\t" + +#define SQ4BIT_KERNEL_LOAD_ZP_16X1 \ + "vsetvli t0, zero, e8, mf2 \n\t" \ + "vle8.v v1, (s7) \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vrgather.vv v8, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v9, v1, v13 \n\t" \ + "vadd.vi v13, 
v13, 4 \n\t" \ + "vrgather.vv v10, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v11, v1, v13 \n\t" \ + "vadd.vi v13, v13, -12 \n\t" + +// using for M4Kernel +#define LOAD_B_16x8x2 \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vle8.v v6, (s1) \n\t" \ + "addi s1, s1, 32*4 \n\t" \ + "vle8.v v7, (s2) \n\t" \ + "addi s2, s2, 32*4 \n\t" \ + "vle8.v v8, (s3) \n\t" \ + "addi s3, s3, 32*4 \n\t" \ + "vle8.v v9, (s4) \n\t" \ + "addi s4, s4, 32*4 \n\t" \ + \ + "vand.vi v2, v6, 15 \n\t" \ + "vand.vi v3, v7, 15 \n\t" \ + "vand.vi v4, v8, 15 \n\t" \ + "vand.vi v5, v9, 15 \n\t" \ + \ + "vsrl.vi v6, v6, 4 \n\t" \ + "vsrl.vi v7, v7, 4 \n\t" \ + "vsrl.vi v8, v8, 4 \n\t" \ + "vsrl.vi v9, v9, 4 \n\t" + +// [s2|s5, s3, s4, s6] +#define LOAD_SCALE_4x16_FP16 \ + "addi s2, s5, -8 \n\t" \ + "addi s3, s5, 8 \n\t" \ + "addi s4, s5, 16 \n\t" \ + "addi s6, s5, 24 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e16, mf4 \n\t" \ + "vle16.v v9, (s5) \n\t" \ + "vle16.v v11, (s3) \n\t" \ + "vle16.v v13, (s4) \n\t" \ + "vle16.v v15, (s6) \n\t" \ + "vsetvli t0, zero, e16, mf2 \n\t" \ + "vle16.v v9, (s2), v0.t \n\t" \ + "vle16.v v11, (s5), v0.t \n\t" \ + "vle16.v v13, (s3), v0.t \n\t" \ + "vle16.v v15, (s4), v0.t \n\t" \ + "vfwcvt.f.f.v v8, v9 \n\t" \ + "vfwcvt.f.f.v v10, v11 \n\t" \ + "vfwcvt.f.f.v v12, v13 \n\t" \ + "vfwcvt.f.f.v v14, v15 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vmv.v.v v9, v8 \n\t" \ + "vmv.v.v v11, v10 \n\t" \ + "vmv.v.v v13, v12 \n\t" \ + "vmv.v.v v15, v14 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vfmul.vf v8, v8, f1 \n\t" \ + "vfmul.vf v10, v10, f1 \n\t" \ + "vfmul.vf v12, v12, f1 \n\t" \ + "vfmul.vf v14, v14, f1 \n\t" \ + "vfmul.vf v9, v9, f3 \n\t" \ + "vfmul.vf v11, v11, f3 \n\t" \ + "vfmul.vf v13, v13, f3 \n\t" \ + "vfmul.vf v15, v15, f3 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vfmul.vf v8, v8, f2, v0.t \n\t" \ + "vfmul.vf v10, v10, f2, v0.t \n\t" \ + "vfmul.vf v12, v12, f2, v0.t \n\t" \ + "vfmul.vf v14, v14, f2, v0.t \n\t" \ + "vfmul.vf v9, v9, f4, v0.t \n\t" \ + "vfmul.vf v11, v11, f4, v0.t \n\t" \ + "vfmul.vf v13, v13, f4, v0.t \n\t" \ + "vfmul.vf v15, v15, f4, v0.t \n\t" + +// [s2|s5, s3, s4, s6] +#define LOAD_SCALE_4x16 \ + "addi s2, s5, -16 \n\t" \ + "addi s3, s5, 16 \n\t" \ + "addi s4, s5, 32 \n\t" \ + "addi s6, s5, 48 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vle32.v v8, (s5) \n\t" \ + "vle32.v v10, (s3) \n\t" \ + "vle32.v v12, (s4) \n\t" \ + "vle32.v v14, (s6) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vle32.v v8, (s2), v0.t \n\t" \ + "vle32.v v10, (s5), v0.t \n\t" \ + "vle32.v v12, (s3), v0.t \n\t" \ + "vle32.v v14, (s4), v0.t \n\t" \ + "vmv.v.v v9, v8 \n\t" \ + "vmv.v.v v11, v10 \n\t" \ + "vmv.v.v v13, v12 \n\t" \ + "vmv.v.v v15, v14 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vfmul.vf v8, v8, f1 \n\t" \ + "vfmul.vf v10, v10, f1 \n\t" \ + "vfmul.vf v12, v12, f1 \n\t" \ + "vfmul.vf v14, v14, f1 \n\t" \ + "vfmul.vf v9, v9, f3 \n\t" \ + "vfmul.vf v11, v11, f3 \n\t" \ + "vfmul.vf v13, v13, f3 \n\t" \ + "vfmul.vf v15, v15, f3 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vfmul.vf v8, v8, f2, v0.t \n\t" \ + "vfmul.vf v10, v10, f2, v0.t \n\t" \ + "vfmul.vf v12, v12, f2, v0.t \n\t" \ + "vfmul.vf v14, v14, f2, v0.t \n\t" \ + "vfmul.vf v9, v9, f4, v0.t \n\t" \ + "vfmul.vf v11, v11, f4, v0.t \n\t" \ + "vfmul.vf v13, v13, f4, v0.t \n\t" \ + "vfmul.vf v15, v15, f4, v0.t \n\t" + +//[s1| BIAS, s2, s3, s4] +#define LOAD_BIAS \ + 
"vsetvli t0, zero, e32, mf2 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "addi s1, %[BIAS], -16 \n\t" \ + "addi s2, %[BIAS], 16 \n\t" \ + "addi s3, %[BIAS], 32 \n\t" \ + "addi s4, %[BIAS], 48 \n\t" \ + \ + "vle32.v v24, (%[BIAS]) \n\t" \ + "vle32.v v26, (s2) \n\t" \ + "vle32.v v28, (s3) \n\t" \ + "vle32.v v30, (s4) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vle32.v v24, (s1), v0.t \n\t" \ + "vle32.v v26, (%[BIAS]), v0.t \n\t" \ + "vle32.v v28, (s2), v0.t \n\t" \ + "vle32.v v30, (s3), v0.t \n\t" \ + "vmv.v.v v25, v24 \n\t" \ + "vmv.v.v v27, v26 \n\t" \ + "vmv.v.v v29, v28 \n\t" \ + "vmv.v.v v31, v30 \n\t" + +#define SQ4BIT_KERNEL_COMP_4x16x16 \ + "vmadot v16, v10, v2 \n\t" \ + "vmadot v18, v10, v3 \n\t" \ + "vmadot v20, v10, v4 \n\t" \ + "vmadot v22, v10, v5 \n\t" \ + "vmadot v16, v11, v6 \n\t" \ + "vmadot v18, v11, v7 \n\t" \ + "vmadot v20, v11, v8 \n\t" \ + "vmadot v22, v11, v9 \n\t" + +#define SAVE_RESULT_4x16 \ + "addi a1, %[C], 0 \n\t" \ + "add a2, %[C], %[LDC] \n\t" \ + "add a3, a2, %[LDC] \n\t" \ + "add a4, a3, %[LDC] \n\t" \ + "addi a2, a2, -16 \n\t" \ + "addi a4, a4, -16 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + \ + "vse32.v v24, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v25, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v26, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v27, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v28, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v29, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v30, (a1) \n\t" \ + "vse32.v v31, (a3) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + \ + "vse32.v v24, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v25, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v26, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v27, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v28, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v29, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v30, (a2), v0.t \n\t" \ + "vse32.v v31, (a4), v0.t \n\t" + +#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \ + "vsetvli t0, zero, e8, mf2 \n\t" \ + "vle8.v v11, (s6) \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vrgather.vv v12, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v13, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v14, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v15, v11, v1 \n\t" \ + "vadd.vi v1, v1, -12 \n\t" + +template +void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias, + const size_t ldc) { + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t LDC = ldc * sizeof(float); + const size_t INNER = BlkLen / 16; + float tmp[4 * 16]; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale + float * CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float * bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + + "addi t3, %[BlockCountK], 0 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 
\n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale + float * CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float * bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + 
"vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } + if (CountN % 16 != 0) { + // stroe output from tmp to C when NBLKS less than 16. + float * CPtr = C + CountN / 16 * 16; + const size_t N = CountN % 16; + LDC = ldc * sizeof(float); + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi s2, %[SRC], 64 \n\t" + "addi s3, %[SRC], 64*2 \n\t" + "addi s4, %[SRC], 64*3 \n\t" + "vle32.v v2, (s2) \n\t" + "vle32.v v4, (s3) \n\t" + "vle32.v v6, (s4) \n\t" + "add t2, %[DST], %[LDC] \n\t" + "add t3, t2, %[LDC] \n\t" + "add t4, t3, %[LDC] \n\t" + "vse32.v v0, (%[DST]) \n\t" + "vse32.v v2, (t2) \n\t" + "vse32.v v4, (t3) \n\t" + "vse32.v v6, (t4) \n\t" + : + : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) + : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); + } +} + +template +void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias, + const size_t ldc) { + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t LDC = ldc * sizeof(float); + const size_t INNER = BlkLen / 16; + float tmp[4 * 16]; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 
+
+template <bool HasZeroPoint>
+void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
+                                      const std::byte * QuantA,
+                                      const std::byte * QuantBData,
+                                      const float * QuantBScale,
+                                      const std::byte * QuantBZeroPoint,
+                                      float * C,
+                                      size_t CountN,
+                                      size_t BlockCountK,
+                                      const float * Bias,
+                                      const size_t ldc) {
+    GGML_UNUSED(QuantBScale);
+    GGML_UNUSED(QuantBZeroPoint);
+    size_t LDC = ldc * sizeof(float);
+    const size_t INNER = BlkLen / 16;
+    float tmp[4 * 16];
+
+    if constexpr (HasZeroPoint) {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+            std::byte * QuantBDataPtr = (std::byte *) QuantBData +           //
+                                        n * BlockCountK * BlkLen / 2 +      // b data
+                                        n * BlockCountK * sizeof(uint8_t) + // zp
+                                        n * BlockCountK * sizeof(float);    // scale
+            float * CPtr = C + n;
+            if (NBLKS < 16) {
+                CPtr = tmp;
+                LDC = 16 * sizeof(float);
+            }
+            if (Bias != nullptr) {
+                const float * bias = Bias + n;
+                if (NBLKS < 16) {
+                    __asm__ volatile(
+                        "vsetvli t0, %[N], e32, m2 \n\t"
+                        "vle32.v v0, (%[SRC]) \n\t"
+                        "vse32.v v0, (%[DST]) \n\t"
+                        :
+                        : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+                        : "cc", "t0");
+                    bias = tmp;
+                }
+
+                __asm__ volatile(LOAD_BIAS
+                    "addi t3, %[BlockCountK], 0 \n\t"
+                    "vsetvli t0, zero, e8, m1 \n\t"
+                    "li s1, 24 \n\t"
+                    "vmv.v.i v1, 3 \n\t"
+                    "vsetvli t0, s1, e8, m1 \n\t"
+                    "vmv.v.i v1, 2 \n\t"
+                    "vsetvli t0, zero, e8, mf2 \n\t"
+                    "vmv.v.i v1, 1 \n\t"
+                    "vsetvli t0, zero, e8, mf4 \n\t"
+                    "vmv.v.i v1, 0 \n\t"
+                    "addi a1, %[A], 0 \n\t"
+                    "addi s1, %[B], 0 \n\t"
+                    "BLOCK_COUNTK_LOOP%=: \n\t"
+                    // scale offset
+                    "addi s5, s1, 0 \n\t"
+                    // zp offset
+                    "addi s6, s1, 64 \n\t"
+                    "addi s1, s6, 16 \n\t"
+                    "addi s2, s1, 32 \n\t"
+                    "addi s3, s1, 32*2 \n\t"
+                    "addi s4, s1, 32*3 \n\t"
+                    "vsetvli t0, zero, e32, m8 \n\t"
+                    "vxor.vv v16, v16, v16 \n\t"
+                    // load a scale
+                    "flw f1, (a1) \n\t"
+                    "flw f2, 4(a1) \n\t"
+                    "flw f3, 8(a1) \n\t"
+                    "flw f4, 12(a1) \n\t"
+                    "addi a1, a1, 16 \n\t"
+                    "addi t2, %[INNER], 0 \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+                    "BLOCK_INNER_LOOP%=: \n\t"
+
+                    LOAD_B_16x8x2
+
+                    "vle8.v v10, (a1) \n\t"
+                    "addi a1, a1, 32 \n\t"
+                    "vle8.v v11, (a1) \n\t"
+                    "addi a1, a1, 32 \n\t"
+                    "vsub.vv v2, v2, v12 \n\t"
+                    "vsub.vv v6, v6, v12 \n\t"
+                    "vsub.vv v3, v3, v13 \n\t"
+                    "vsub.vv v7, v7, v13 \n\t"
+                    "vsub.vv v4, v4, v14 \n\t"
+                    "vsub.vv v8, v8, v14 \n\t"
+                    "vsub.vv v5, v5, v15 \n\t"
+                    "vsub.vv v9, v9, v15 \n\t"
+
+                    SQ4BIT_KERNEL_COMP_4x16x16
+
+                    "addi t2, t2, -1 \n\t"
+                    "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+                    LOAD_SCALE_4x16
+
+                    "vsetvli t0, zero, e32, m8 \n\t"
+                    "vfcvt.f.x.v v16, v16 \n\t"
+                    "vfmacc.vv v24, v16, v8 \n\t"
+                    "addi t3, t3, -1 \n\t"
+                    "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+                    "RESULT_SAVE%=: \n\t"
+
+                    SAVE_RESULT_4x16
+
+                    :
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                      [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+                    : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+                      "s2", "s3", "s4", "s5", "s6");
+
+            } else {
+                __asm__ volatile(
+                    "vsetvli t0, zero, e32, m8 \n\t"
+                    "vxor.vv v24, v24, v24 \n\t"
+                    "addi t3, %[BlockCountK], 0 \n\t"
+                    "vsetvli t0, zero, e8, m1 \n\t"
+                    "li s1, 24 \n\t"
+                    "vmv.v.i v1, 3 \n\t"
+                    "vsetvli t0, s1, e8, m1 \n\t"
+                    "vmv.v.i v1, 2 \n\t"
+                    "vsetvli t0, zero, e8, mf2 \n\t"
+                    "vmv.v.i v1, 1 \n\t"
+                    "vsetvli t0, zero, e8, mf4 \n\t"
+                    "vmv.v.i v1, 0 \n\t"
+                    "addi a1, %[A], 0 \n\t"
+                    "addi s1, %[B], 0 \n\t"
+                    "BLOCK_COUNTK_LOOP%=: \n\t"
+                    // scale offset
+                    "addi s5, s1, 0 \n\t"
+                    // zp offset
+                    "addi s6, s1, 64 \n\t"
+                    "addi s1, s6, 16 \n\t"
+                    "addi s2, s1, 32 \n\t"
+                    "addi s3, s1, 32*2 \n\t"
+                    "addi s4, s1, 32*3 \n\t"
+                    "vsetvli t0, zero, e32, m8 \n\t"
+                    "vxor.vv v16, v16, v16 \n\t"
+                    // load a scale
+                    "flw f1, (a1) \n\t"
+                    "flw f2, 4(a1) \n\t"
+                    "flw f3, 8(a1) \n\t"
+                    "flw f4, 12(a1) \n\t"
+                    "addi a1, a1, 16 \n\t"
+                    "addi t2, %[INNER], 0 \n\t"
+
+                    SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+                    "BLOCK_INNER_LOOP%=: \n\t"
+
+                    LOAD_B_16x8x2
+
+                    "vle8.v v10, (a1) \n\t"
+                    "addi a1, a1, 32 \n\t"
+                    "vle8.v v11, (a1) \n\t"
+                    "addi a1, a1, 32 \n\t"
+                    "vsub.vv v2, v2, v12 \n\t"
+                    "vsub.vv v6, v6, v12 
\n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float * bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 64 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 64 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, 
zero, e8, m1 \n\t"
+                    "vle8.v v10, (a1) \n\t"
+
+                    "addi a1, a1, 32 \n\t"
+                    "vle8.v v11, (a1) \n\t"
+                    "addi a1, a1, 32 \n\t"
+                    "vadd.vi v2, v2, -8 \n\t"
+                    "vadd.vi v3, v3, -8 \n\t"
+                    "vadd.vi v4, v4, -8 \n\t"
+                    "vadd.vi v5, v5, -8 \n\t"
+                    "vadd.vi v6, v6, -8 \n\t"
+                    "vadd.vi v7, v7, -8 \n\t"
+                    "vadd.vi v8, v8, -8 \n\t"
+                    "vadd.vi v9, v9, -8 \n\t"
+
+                    SQ4BIT_KERNEL_COMP_4x16x16
+
+                    "addi t2, t2, -1 \n\t"
+                    "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+                    LOAD_SCALE_4x16
+
+                    "vsetvli t0, zero, e32, m8 \n\t"
+                    "vfcvt.f.x.v v16, v16 \n\t"
+                    "vfmacc.vv v24, v16, v8 \n\t"
+                    "addi t3, t3, -1 \n\t"
+                    "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+                    "RESULT_SAVE%=: \n\t"
+
+                    SAVE_RESULT_4x16
+
+                    :
+                    : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+                      [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+                    : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+                      "s4", "s5", "s6");
+            }
+        }
+    }
+    if (CountN % 16 != 0) {
+        // store output from tmp to C when NBLKS is less than 16.
+        float * CPtr = C + CountN / 16 * 16;
+        const size_t N = CountN % 16;
+        LDC = ldc * sizeof(float);
+        __asm__ volatile(
+            "vsetvli t0, %[N], e32, m2 \n\t"
+            "vle32.v v0, (%[SRC]) \n\t"
+            "addi s2, %[SRC], 64 \n\t"
+            "addi s3, %[SRC], 64*2 \n\t"
+            "addi s4, %[SRC], 64*3 \n\t"
+            "vle32.v v2, (s2) \n\t"
+            "vle32.v v4, (s3) \n\t"
+            "vle32.v v6, (s4) \n\t"
+            "add t2, %[DST], %[LDC] \n\t"
+            "add t3, t2, %[LDC] \n\t"
+            "add t4, t3, %[LDC] \n\t"
+            "vse32.v v0, (%[DST]) \n\t"
+            "vse32.v v2, (t2) \n\t"
+            "vse32.v v4, (t3) \n\t"
+            "vse32.v v6, (t4) \n\t"
+            :
+            : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
+            : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
+    }
+}
+
+template <bool HasZeroPoint>
+void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
+                                                const std::byte * QuantA,
+                                                const std::byte * QuantBData,
+                                                const float * QuantBScale,
+                                                const std::byte * QuantBZeroPoint,
+                                                float * C,
+                                                size_t CountN,
+                                                size_t BlockCountK,
+                                                const float * Bias) {
+    GGML_UNUSED(QuantBScale);
+    GGML_UNUSED(QuantBZeroPoint);
+    size_t INNER = BlkLen / 16;
+
+    if constexpr (HasZeroPoint) {
+        for (size_t n = 0; n < CountN; n += 16) {
+            size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float * bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + // zp offset + "addi s7, %[B], 32 \n\t" + // a offset + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + + "vsetvli t0, zero, e32, mf2 \n\t" + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } else { + __asm__ volatile( + 
"vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s7, %[B], 32 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float * bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + + "vsetvli t0, zero, e32, mf2 \n\t" + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + 
"vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } + } + } +} + +template +void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias) { + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + const size_t INNER = BlkLen / 16; + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float * bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + // zp offset + "addi s7, %[B], 64 \n\t" + // a offset + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + + // load scale + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 80 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 96 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 112 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load a scale + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + + // a scale * b scale + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + "addi s7, s1, 64 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", 
"s7"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + + "addi s7, %[B], 64 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 80 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 96 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 112 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + "addi s7, s1, 64 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 
16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float * bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 80 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 112 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 80 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 112 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, 
%[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } + } + } +} + +template +inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float * Bias, + const size_t ldc, + const size_t scalestride) { + if (scalestride == 4) { + SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + CountN, BlockStrideQuantB, Bias, ldc); + + } else if (scalestride == 2) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl( + BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); + } +} + +template +inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float * Bias, + const size_t ldc, + const size_t scalestride) { + if (scalestride == 4) { + SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + CountN, BlockStrideQuantB, Bias); + } else if (scalestride == 2) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); + } +} + +} // namespace + +namespace ime1 { +size_t gemm_kernel_i8i4(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float * Bias, + const size_t ScaleStride) { + GGML_UNUSED(CountM); + GGML_UNUSED(CountK); + GGML_UNUSED(ldc); + if 
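+    // Hedged note: the return value below is the number of rows of A consumed per call;
+    // callers advance by 4 rows when the 4-row kernel runs and by 1 row otherwise.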
(CountM >= 4) { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + } else { + SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, + ldc, ScaleStride); + } + return 4; + } else { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, + ldc, ScaleStride); + } + return 1; + } +} +} // namespace ime1 +} // namespace sqnbitgemm_spacemit_ime diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h new file mode 100644 index 0000000000..7570634150 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -0,0 +1,26 @@ +#pragma once + +#include <cstddef> + +namespace sqnbitgemm_spacemit_ime { +namespace ime1 { +size_t gemm_kernel_i8i4(size_t blk_len, + const std::byte * quant_a_ptr, + const std::byte * quant_b_data, + const float * quant_b_scale, + const std::byte * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t count_k, + size_t block_count_k, + size_t ldc, + const float * bias, + const size_t scale_stride); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); + +} // namespace ime1 +} // namespace sqnbitgemm_spacemit_ime diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index ef334d089d..341e64e64f 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -610,7 +610,7 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co for (int i = 0; i < np; i += GGML_F32_STEP) { for (int j = 0; j < GGML_F32_ARR; j++) { ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); + ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs); GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 725e1a81a1..6024010274 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -54,7 +54,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z); - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) { + if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) { return; } diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 3b1349171b..c4246b65eb 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -586,17 +586,42 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, #endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA)) } +static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) { +#ifdef FAST_FP16_AVAILABLE + acc += v*u; +#else + const float2 tmpv = __half22float2(v); + const float2 tmpu = __half22float2(u); + float2 
tmpacc = __half22float2(acc); + tmpacc.x += tmpv.x * tmpu.x; + tmpacc.y += tmpv.y * tmpu.y; + acc = make_half2(tmpacc.x, tmpacc.y); +#endif // FAST_FP16_AVAILABLE +} + // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. -template +template static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { - if constexpr (nbytes == 4) { - *(int *) dst = *(const int *) src; - } else if constexpr (nbytes == 8) { - *(int2 *) dst = *(const int2 *) src; - } else if constexpr (nbytes == 16) { - *(int4 *) dst = *(const int4 *) src; - } else { - static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); + if constexpr (alignment != 0) { + static_assert(nbytes % alignment == 0, "bad alignment"); + } + constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment; + +#pragma unroll + for (int i = 0; i < nbytes/nb_per_cpy; ++i) { + if constexpr (nb_per_cpy == 1) { + ((char *) dst)[i] = ((const char *) src)[i]; + } else if constexpr (nb_per_cpy == 2) { + ((short *) dst)[i] = ((const short *) src)[i]; + } else if constexpr (nb_per_cpy == 4) { + ((int *) dst)[i] = ((const int *) src)[i]; + } else if constexpr (nb_per_cpy == 8) { + ((int2 *) dst)[i] = ((const int2 *) src)[i]; + } else if constexpr (nb_per_cpy == 16) { + ((int4 *) dst)[i] = ((const int4 *) src)[i]; + } else { + static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); + } } } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 1b763a6289..746f43966b 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -329,7 +329,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + if (src0->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + } } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); @@ -400,7 +404,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - return nullptr; + // Prioritize CUDA graph compatibility over direct memory copy optimization. + // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs. 
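+        // Hedged note mirroring the ggml_cuda_cpy change above: a contiguous F32 -> F32
+        // copy now resolves to the element-wise copy kernel instead of cudaMemcpyAsync,
+        // so graph capture records a kernel whose source/destination pointers can still
+        // be redirected through the dest_ptrs indirection.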
+ if (src0->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; + } else { + return nullptr; + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 142a3a88d1..33d2f0f49e 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -33,276 +33,230 @@ typedef void (* fattn_kernel_t)( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33); -typedef half (*vec_dot_KQ_f16_t)( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); -typedef float (*vec_dot_KQ_f32_t)( +typedef float (*vec_dot_KQ_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const int ib = k_KQ / QI8_1; - const int iqs4 = k_KQ % QI4_0; - const int shift = k_KQ & (QI8_1/2); - - const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/warp_size]; - - const int sumi = ggml_cuda_dp4a(v, u, 0); - -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - - const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size]; - sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */); - } else -#endif // FP16_AVAILABLE - { - const float2 * Q_ds = (const float2 *) Q_ds_v; - - sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); - } - } - - return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const int ib = k_KQ / QI8_1; - const int iqs4 = k_KQ % QI4_1; - const int shift = k_KQ & (QI8_1/2); - - const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int u = Q_q8[k_KQ_0/warp_size]; - - const int sumi = ggml_cuda_dp4a(v, u, 0); - -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - - const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size]; - const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); - sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); - } else -#endif // FP16_AVAILABLE - { - const float2 * Q_ds = (const float2 *) Q_ds_v; - - const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; - const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; - - sum += (T) (sumid4d8 + m4s8scaled); - } - } - - 
return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const int ib = k_KQ / QI8_1; - const int iqs4 = k_KQ % QI5_0; - const int iqs8 = k_KQ % QI8_1; - const int shift = k_KQ & (QI8_1/2); - - int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0); - v |= (vh << 4) & 0x00000010; // 0 -> 4 - v |= (vh << 11) & 0x00001000; // 1 -> 12 - v |= (vh << 18) & 0x00100000; // 2 -> 20 - v |= (vh << 25) & 0x10000000; // 3 -> 28 - - const int u = Q_q8[k_KQ_0/warp_size]; - - const int sumi = ggml_cuda_dp4a(v, u, 0); - -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - - const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size]; - sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */; - } else -#endif // FP16_AVAILABLE - { - const float2 * Q_ds = (const float2 *) Q_ds_v; - - sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y)); - } - } - - return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const int ib = k_KQ / QI8_1; - const int iqs4 = k_KQ % QI5_1; - const int iqs8 = k_KQ % QI8_1; - const int shift = k_KQ & (QI8_1/2); - - int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1); - v |= (vh << 4) & 0x00000010; // 0 -> 4 - v |= (vh << 11) & 0x00001000; // 1 -> 12 - v |= (vh << 18) & 0x00100000; // 2 -> 20 - v |= (vh << 25) & 0x10000000; // 3 -> 28 - - const int u = Q_q8[k_KQ_0/warp_size]; - - const int sumi = ggml_cuda_dp4a(v, u, 0); - -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - - const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size]; - const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); - sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); - } else -#endif // FP16_AVAILABLE - { - const float2 * Q_ds = (const float2 *) Q_ds_v; - - const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi; - const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1; - - sum += (T) (sumid5d8 + m5s8scaled); - } - } - - return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( - const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { - - const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; - GGML_UNUSED(Q_v); - - T sum = 0.0f; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + 
threadIdx.x; - - const int ib = k_KQ / QI8_0; - const int iqs = k_KQ % QI8_0; - - const int v = get_int_b2(K_q8_0[ib].qs, iqs); - - T Q_d; - if (std::is_same::value) { - const half2 * Q_ds = (const half2 *) Q_ds_v; - Q_d = __low2half(Q_ds[k_KQ_0/warp_size]); - } else { - const float2 * Q_ds = (const float2 *) Q_ds_v; - Q_d = Q_ds[k_KQ_0/warp_size].x; - } - - sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d); - } - - return sum; -} - -template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { const half2 * K_h2 = (const half2 *) K_c; GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); -#ifdef FP16_AVAILABLE - if (std::is_same::value) { - const half2 * Q_h2 = (const half2 *) Q_v; - - half2 sum2 = make_half2(0.0f, 0.0f); - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[k_KQ]; - sum2 += K_ik * Q_h2[k_KQ_0/warp_size]; - } - - return __low2half(sum2) + __high2half(sum2); - } -#endif // FP16_AVAILABLE - - const float2 * Q_f2 = (const float2 *) Q_v; + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; float sum = 0.0f; #pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[k_KQ]; - sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x; - sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y; + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + half2 tmp[cpy_ne]; + ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#else + ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // FP16_AVAILABLE + } } return sum; } -template +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? 
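+        // Hedged reading: when nthreads < WARP_SIZE several K rows share a warp, and
+        // threadIdx.x % nthreads is this thread's lane within its nthreads-wide group.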
threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_0; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1(&v, K_q4_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y); + } + + return sum; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1(&v, K_q4_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 K_dm = __half22float2(K_q4_1[ib].dm); + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + + sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1; + } + + return sum; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_0; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1(&v, K_q5_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_cuda_memcpy_1(&vh, K_q5_0[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + + sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y); + } + + return sum; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_1; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1(&v, K_q5_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_cuda_memcpy_1(&vh, K_q5_1[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 K_dm = __half22float2(K_q5_1[ib].dm); + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + + sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1; + } + + return sum; +} + +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_0; + const int iqs = k_KQ % QI8_0; + + int v; + ggml_cuda_memcpy_1(&v, K_q8_0[ib].qs + 4*iqs); + + const float2 * Q_ds = (const float2 *) Q_ds_v; + const float Q_d = Q_ds[k_KQ_0/nthreads].x; + + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d); + } + + return sum; +} + +template static __device__ __forceinline__ void quantize_q8_1_to_shared( const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { float vals[sizeof(int)] = {0.0f}; #pragma unroll for (int l = 0; l < int(sizeof(int)); ++l) { - vals[l] = scale * x[4*threadIdx.x + l]; + vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? 
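+        // Hedged note: lanes at index ni and above load 0.0f, zero-padding the q8_1
+        // block so the amax/sum reductions that follow stay well defined for partial tiles.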
scale * x[4*threadIdx.x + l] : 0.0f; } float amax = fabsf(vals[0]); @@ -330,7 +284,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( } yq32[threadIdx.x] = q32; - if (threadIdx.x % QI8_1 == 0) { + if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) { if (std::is_same::value) { ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum); } else { @@ -339,167 +293,276 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared( } } -typedef half (*dequantize_1_f16_t)(const void *, const int64_t); -typedef float (*dequantize_1_f32_t)(const void *, const int64_t); +typedef void (*dequantize_V_t)(const void *, void *, const int64_t); -template -static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + if constexpr (std::is_same_v) { + ggml_cuda_memcpy_1(dst, (const half *) vx + i0); + } else if constexpr (std::is_same_v) { + static_assert(ne % 2 == 0, "bad ne"); + half2 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const half *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = __half22float2(tmp[l]); + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +} + +template +static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_0 * x = (const block_q4_0 *) vx; - const int64_t ib = i / QK4_0; - const int iqs = i % (QK4_0/2); - const int shift = (i % QK4_0) / (QK4_0/2); + const int64_t ib = i0 / QK4_0; + const int iqs = i0 % (QK4_0/2); + const int shift = (i0 % QK4_0) / (QK4_0/2); - const T d = x[ib].d; - const int q0 = x[ib].qs[iqs]; - const int q = ((q0 >> (4*shift)) & 0x0F) - 8; + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + q = __vsubss4(q, 0x08080808); + + const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if (std::is_same::value) { - return ((half) d)*((half) q); - } -#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 d = __half2half2(x[ib].d); - return ((float) d)*((float) q); +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v, "bad type"); + } } -template -static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_1 * x = (const block_q4_1 *) vx; - const int64_t ib = i / QK4_1; - const int iqs = i % (QK4_1/2); - const int shift = (i % QK4_1) / (QK4_1/2); + const int64_t ib = i0 / QK4_1; + const int iqs = i0 % (QK4_1/2); + const int shift = (i0 % QK4_1) / (QK4_1/2); - const half2 dm = x[ib].dm; - const int q0 = x[ib].qs[iqs]; - const int q = ((q0 >> (4*shift)) & 0x0F); + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if (std::is_same::value) { 
- return __low2half(dm)*((half) q) + __high2half(dm); - } -#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 dm = x[ib].dm; + const half2 d = __half2half2( __low2half(dm)); + const half2 m = __half2half2(__high2half(dm)); - return __low2float(dm)*((float) q) + __high2float(dm); +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float2 dm = __half22float2(x[ib].dm); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x * q8[l] + dm.y; + } + } else { + static_assert(std::is_same_v, "bad type"); + } } -template -static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q5_0 * x = (const block_q5_0 *) vx; - const int64_t ib = i / QK5_0; - const int idq = i % QK5_0; - const int iqs = i % (QK5_0/2); - const int shift = (i % QK5_0) / (QK5_0/2); + const int64_t ib = i0 / QK5_0; + const int idq = i0 % QK5_0; + const int iqs = i0 % (QK5_0/2); + const int shift = (i0 % QK5_0) / (QK5_0/2); - const T d = x[ib].d; - const int ql0 = x[ib].qs[iqs]; - const int qh0 = get_int_b2(x[ib].qh, 0); - const int ql = ((ql0 >> (4*shift)) & 0x0F); - const int qh = ((qh0 >> idq) << 4) & 0x10; - const int q = (ql | qh) - 16; + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_cuda_memcpy_1(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4); + } + } + + q = __vsubss4(q, 0x10101010); + + const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if (std::is_same::value) { - return ((half) d)*((half) q); - } -#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 d = __half2half2(x[ib].d); - return ((float) d)*((float) q); +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v, "bad type"); + } } -template -static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q5_1 * x = (const block_q5_1 *) vx; - const int64_t ib = i / QK5_1; - const int idq = i % QK5_1; - const int iqs = i % (QK5_1/2); - const int shift = (i % QK5_1) / (QK5_1/2); + const int64_t ib = i0 / QK5_1; + const int idq = i0 % QK5_1; + const int iqs = i0 % (QK5_1/2); + const int shift = (i0 % QK5_1) / (QK5_1/2); - const half2 dm = x[ib].dm; - const int ql0 = x[ib].qs[iqs]; - const int qh0 = get_int_b4(x[ib].qh, 0); - const int ql = ((ql0 >> (4*shift)) & 0x0F); - const int qh = ((qh0 >> idq) << 4) & 0x10; - const int q = (ql | qh); + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_cuda_memcpy_1(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + 
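+        // Hedged reading of this bit scatter: bit (idq + l) of qh supplies the fifth bit
+        // of nibble l and lands at bit position 8*l + 4 of the packed 4-byte integer q.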
l)) & 0x00000001) << (8*l + 4); + } + } + + const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if (std::is_same::value) { - return __low2half(dm)*((half) q) + __high2half(dm); - } -#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const half2 dm = x[ib].dm; + const half2 d = __half2half2( __low2half(dm)); + const half2 m = __half2half2(__high2half(dm)); - return __low2float(dm)*((float) q) + __high2float(dm); +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v) { + const float2 dm = __half22float2(x[ib].dm); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x * q8[l] + dm.y; + } + } else { + static_assert(std::is_same_v, "bad type"); + } } -template -static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) { +template +static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q8_0 * x = (const block_q8_0 *) vx; - const int64_t ib = i / QK8_0; - const int iqs = i % QK8_0; + const int64_t ib = i0 / QK8_0; + const int iqs = i0 % QK8_0; - const T d = x[ib].d; - const int q = x[ib].qs[iqs]; + static_assert(ne % 2 == 0, "bad ne"); + int8_t qs[ne]; + ggml_cuda_memcpy_1(qs, x[ib].qs + iqs); #ifdef FP16_AVAILABLE - if (std::is_same::value) { - return ((half) d)*((half) q); - } + if constexpr (std::is_same::value) { + const half2 d = __half2half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]); + } + } else #endif // FP16_AVAILABLE + if constexpr (std::is_same::value) { + const float d = x[ib].d; - return ((float) d)*((float) q); +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * qs[l]; + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } } -template -static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { - const half * x = (const half *) vx; - - return x[i]; +template +constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { + if constexpr (type_K == GGML_TYPE_F16) { + return vec_dot_fattn_vec_KQ_f16; + } else if constexpr (type_K == GGML_TYPE_Q4_0) { + return vec_dot_fattn_vec_KQ_q4_0; + } else if constexpr (type_K == GGML_TYPE_Q4_1) { + return vec_dot_fattn_vec_KQ_q4_1; + } else if constexpr (type_K == GGML_TYPE_Q5_0) { + return vec_dot_fattn_vec_KQ_q5_0; + } else if constexpr (type_K == GGML_TYPE_Q5_1) { + return vec_dot_fattn_vec_KQ_q5_1; + } else if constexpr (type_K == GGML_TYPE_Q8_0) { + return vec_dot_fattn_vec_KQ_q8_0; + } else { + static_assert(type_K == -1, "bad type"); + return nullptr; + } } -template -constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : - nullptr; -} - -template -constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { - return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : - type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : - type_K == GGML_TYPE_Q5_0 ? 
vec_dot_fattn_vec_KQ_q5_0 : - type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : - type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : - type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : - nullptr; -} - -constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) { - return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 : - type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 : - type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 : - type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : - type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : - type_V == GGML_TYPE_F16 ? dequantize_1_f16 : - nullptr; -} - -constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { - return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 : - type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 : - type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 : - type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : - type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : - type_V == GGML_TYPE_F16 ? dequantize_1_f16 : - nullptr; +template +constexpr __device__ dequantize_V_t get_dequantize_V() { + if constexpr (type_V == GGML_TYPE_F16) { + return dequantize_V_f16; + } else if constexpr (type_V == GGML_TYPE_Q4_0) { + return dequantize_V_q4_0; + } else if constexpr (type_V == GGML_TYPE_Q4_1) { + return dequantize_V_q4_1; + } else if constexpr (type_V == GGML_TYPE_Q5_0) { + return dequantize_V_q5_0; + } else if constexpr (type_V == GGML_TYPE_Q5_1) { + return dequantize_V_q5_1; + } else if constexpr (type_V == GGML_TYPE_Q8_0) { + return dequantize_V_q8_0; + } else { + static_assert(type_V == -1, "bad type"); + return nullptr; + } } template @@ -870,7 +933,7 @@ void launch_fattn( const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. - if (efficiency_percent_best >= 90 && nwaves > nwaves_best) { + if (efficiency_percent_best >= 95 && nwaves > nwaves_best) { break; } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh deleted file mode 100644 index 27a2dd6ae4..0000000000 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ /dev/null @@ -1,495 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" - -// Currenlty llvm with the amdgcn target dose not support unrolling loops -// that contain a break that can not be resolved at compile time. 
-#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wpass-failed" -#endif // __clang__ -template // D == head size -#ifndef GGML_USE_HIP -__launch_bounds__(D, 1) -#endif // GGML_USE_HIP -static __global__ void flash_attn_vec_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) - - // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - if (ncols > 1) { - NO_DEVICE_CODE; - return; - } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; - constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V); - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - Q += nb03*sequence + nb02* head + nb01*ic0; - K += nb13*sequence + nb12*(head / gqa_ratio); - V += nb23*sequence + nb22*(head / gqa_ratio); - - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); - - __shared__ half KQ[ncols*D]; - half2 * KQ2 = (half2 *) KQ; - - half kqmax[ncols]; - half kqsum[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax[j] = -HALF_MAX_HALF; - kqsum[j] = 0.0f; - } - - __shared__ half kqmax_shared[ncols][WARP_SIZE]; - __shared__ half kqsum_shared[ncols][WARP_SIZE]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.y == 0) { - kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF; - kqsum_shared[j][threadIdx.x] = 0.0f; - } - } - - __shared__ half maskh_shared[ncols*D]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - maskh_shared[j*D + tid] = 0.0f; - } - - __syncthreads(); - - // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: - half2 Q_h2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; - half2 Q_ds[ncols][D/QK8_1 == 0 ? 
1 : D/QK8_1]; - if (Q_q8_1) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > ncols && j >= ncols) { - break; - } - - // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); - - // Set memory to zero if out of bounds: - if (ncols > 2 && ic0 + j >= ne01) { -#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - tmp_q_i32[i] = 0; - } - if (threadIdx.x < D/QK8_1) { - tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); - } - continue; - } - - const float * Q_f = (const float *) (Q + j*nb01); -#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { - quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); - -#pragma unroll - for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; - Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; - } - } - - __syncthreads(); - } else { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); - Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } - } - } - - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -HALF_MAX_HALF; - } - __syncthreads(); - - half2 VKQ[ncols] = {{0.0f, 0.0f}}; - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - K += blockIdx.y*D * nb11; - V += blockIdx.y*D * nb21; - maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D, - // Increment pointers after each loop: - K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { - - // Calculate KQ tile and keep track of new maximum KQ values: - - if (mask) { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid]; - } - __syncthreads(); - } - - // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, - // see https://github.com/ggerganov/llama.cpp/pull/7061 . - // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). 
- half kqmax_new = kqmax[0]; - half kqmax_new_arr[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax_new_arr[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { - break; - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); - sum = warp_reduce_sum((float)sum); - - if (use_logit_softcap) { - sum = logit_softcap*tanhf(sum); - } - - sum += maskh_shared[j*D + i_KQ]; - - if (ncols == 1) { - kqmax_new = ggml_cuda_hmax(kqmax_new, sum); - } else { - kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum); - } - - if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; - } - } - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; - - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = kqmax_new_j; - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const half val = hexp(KQ[j*D + tid] - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; - - VKQ[j] *= __half2half2(KQ_max_scale); - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < D; k0 += 2) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { - break; - } - - half2 V_k; - reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); - reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); -#pragma unroll - for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; - } - } - - __syncthreads(); - } - - if (sinksf && blockIdx.y == 0) { - const half sink = __float2half(sinksf[head]); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - half kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const half val = hexp(sink - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale; - - if (tid == 0) { - kqsum[j] += val; - } - - VKQ[j] *= __half2half2(KQ_max_scale); - } - - __syncthreads(); - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqsum[j] = warp_reduce_sum((float)kqsum[j]); - if (threadIdx.x == 0) { - kqsum_shared[j][threadIdx.y] = kqsum[j]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 2 && ic0 + j_VKQ >= ne01) { - break; - } - - kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; - kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]); - - half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); - if (gridDim.y == 1) { - dst_val /= kqsum[j_VKQ]; - } - dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; - } - - if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, 
ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) -} -#ifdef __clang__ -#pragma clang diagnostic pop -#endif // __clang__ - -template -void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); -} - -template -void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - - const int32_t precision = KQV->op_params[3]; - GGML_ASSERT(precision == GGML_PREC_DEFAULT); - - GGML_ASSERT(K->type == type_K); - GGML_ASSERT(V->type == type_V); - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - - if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { - constexpr int cols_per_block = 1; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } - return; - } - - if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } - return; - } - - if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 8; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); - } -} - -#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ - template void ggml_cuda_flash_attn_ext_vec_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); 
-extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); - -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); -extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); - -extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh deleted file mode 100644 index da195d0334..0000000000 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ /dev/null @@ -1,486 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" - -// Currenlty llvm with the amdgcn target dose not support unrolling loops -// that contain a break that can not be resolved at compile time. 
-#ifdef __clang__ -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wpass-failed" -#endif // __clang__ -template // D == head size -#ifndef GGML_USE_HIP -__launch_bounds__(D, 1) -#endif // GGML_USE_HIP -static __global__ void flash_attn_vec_ext_f32( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; - return; - } -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - if (ncols > 1) { - NO_DEVICE_CODE; - return; - } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; - constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V); - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
- Q += nb03*sequence + nb02* head + nb01*ic0; - K += nb13*sequence + nb12*(head / gqa_ratio); - V += nb23*sequence + nb22*(head / gqa_ratio); - - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - constexpr int nwarps = D / WARP_SIZE; - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < D); - - __shared__ float KQ[ncols*D]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -FLT_MAX/2.0f; - } - - float kqmax[ncols]; - float kqsum[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax[j] = -FLT_MAX/2.0f; - kqsum[j] = 0.0f; - } - - __shared__ float kqmax_shared[ncols][WARP_SIZE]; - __shared__ float kqsum_shared[ncols][WARP_SIZE]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.y == 0) { - kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f; - kqsum_shared[j][threadIdx.x] = 0.0f; - } - } - - __shared__ float maskf_shared[ncols*D]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - maskf_shared[j*D + tid] = 0.0f; - } - - __syncthreads(); - - // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: - float2 Q_f2[ncols][D/(2*WARP_SIZE)]; - int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)]; - float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; - if (Q_q8_1) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > ncols && j >= ncols) { - break; - } - - // Reuse KQ as temporary storage for converting Q to q8_1: - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); - - // Set memory to zero if out of bounds: - if (ncols > 2 && ic0 + j >= ne01) { -#pragma unroll - for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - tmp_q_i32[i] = 0; - } - if (threadIdx.x < D/QK8_1) { - tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); - } - continue; - } - - const float * Q_f = (const float *) (Q + j*nb01); -#pragma unroll - for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { - quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - int * tmp_q_i32 = (int *) &KQ[j*D]; - float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); - -#pragma unroll - for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; - Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; - } - } - - __syncthreads(); - } else { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); - Q_f2[j][i0/WARP_SIZE].x *= scale; - Q_f2[j][i0/WARP_SIZE].y *= scale; - } - } - } - - float VKQ[ncols] = {0.0f}; - - const int k_VKQ_max = KV_max ? 
KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - K += blockIdx.y*D * nb11; - V += blockIdx.y*D * nb21; - maskh += blockIdx.y*D; - for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D, - // Increment pointers after each loop: - K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) { - - // Calculate KQ tile and keep track of new maximum KQ values: - - if (mask) { -#pragma unroll - for (int j = 0; j < ncols; ++j) { - maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]); - } - __syncthreads(); - } - - float kqmax_new_arr[ncols]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqmax_new_arr[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - - if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { - break; - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); - sum = warp_reduce_sum(sum); - - if (use_logit_softcap) { - sum = logit_softcap*tanhf(sum); - } - - sum += maskf_shared[j*D + i_KQ]; - - kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); - - if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum; - } - } - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - float kqmax_new_j = kqmax_new_arr[j]; - - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = kqmax_new_j; - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - float kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const float val = expf(KQ[j*D + tid] - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale + val; - KQ[j*D + tid] = val; - - VKQ[j] *= KQ_max_scale; - } - - __syncthreads(); - -#pragma unroll - for (int k = 0; k < D; ++k) { - if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) { - break; - } - - const float V_ki = dequantize_1_v(V + k*nb21, tid); -#pragma unroll - for (int j = 0; j < ncols; ++j) { - VKQ[j] += V_ki*KQ[j*D + k]; - } - } - - __syncthreads(); - } - - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - if (threadIdx.x == 0) { - kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink); - } - } - - __syncthreads(); - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - float kqmax_new_j = kqmax_shared[j][threadIdx.x]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); - kqmax[j] = kqmax_new_j; - - const float val = expf(sink - kqmax[j]); - kqsum[j] = kqsum[j]*KQ_max_scale; - - if (tid == 0) { - kqsum[j] += val; - } - - VKQ[j] *= KQ_max_scale; - } - - __syncthreads(); - } - -#pragma unroll - for (int j = 0; j < ncols; ++j) { - kqsum[j] = warp_reduce_sum(kqsum[j]); - if (threadIdx.x == 0) { - kqsum_shared[j][threadIdx.y] = kqsum[j]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { - if (ncols > 2 && ic0 + j_VKQ >= ne01) { - break; - } - - kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; - kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); - - float dst_val = VKQ[j_VKQ]; - if (gridDim.y == 1) { - dst_val /= kqsum[j_VKQ]; - } - dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val; - } - - if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { - dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + 
head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]); - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} -#ifdef __clang__ -#pragma clang diagnostic pop -#endif // __clang__ - -template -void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; - constexpr bool need_f16_K = D != 128; - constexpr bool need_f16_V = D != 128 && D != 64; - constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); -} - -template -void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * V = dst->src[2]; - - GGML_ASSERT(K->type == type_K); - GGML_ASSERT(V->type == type_V); - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - - if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { - constexpr int cols_per_block = 1; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } - return; - } - - if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } - return; - } - - if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 8; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); - } -} - -#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ - template void ggml_cuda_flash_attn_ext_vec_f32_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ - -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); -extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); - -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); -extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, 
GGML_TYPE_Q4_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0);
-
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1);
-
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0);
-
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1);
-
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0);
-
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
-extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16);
-
-extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
new file mode 100644
index 0000000000..59c62553b0
--- /dev/null
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -0,0 +1,593 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {
+    return 128;
+    GGML_UNUSED(cc);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
+    return 128;
+}
+
+// Currently llvm with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
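+// (The pragma below merely silences the resulting "-Wpass-failed" warning; the
+// failed unroll is a missed optimization, not a correctness problem.)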
+#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template // D == head size +__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) +static __global__ void flash_attn_ext_vec( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + +#ifdef GGML_USE_HIP +#ifdef RDNA + constexpr int nthreads_KQ_q = 2; +#else + constexpr int nthreads_KQ_q = 4; +#endif // RDNA + constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); +#else + constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32); + constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); +#endif // GGML_USE_HIP + + constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); + constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + + static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); + static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); + + constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; + + constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; +#ifdef FAST_FP16_AVAILABLE + constexpr dequantize_V_t dequantize_V = get_dequantize_V(); +#else + constexpr dequantize_V_t dequantize_V = get_dequantize_V(); +#endif // FAST_FP16_AVAILABLE + + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
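+    // Hypothetical example: ne02 = 32 Q heads and ne12 = 8 K/V heads give gqa_ratio = 4,
+    // i.e. Q heads 0..3 read K/V head 0, Q heads 4..7 read K/V head 1, and so on.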
+ Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); + + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = nthreads / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < nthreads); + + constexpr int ne_KQ = ncols*D; + constexpr int ne_combine = nwarps*V_cols_per_iter*D; +#ifdef FAST_FP16_AVAILABLE + half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; + __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; +#else + float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; + __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; +#endif // FAST_FP16_AVAILABLE + + float KQ_max[ncols]; + float KQ_sum[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max[j] = -FLT_MAX/2.0f; + KQ_sum[j] = 0.0f; + } + + // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: +#ifdef FAST_FP16_AVAILABLE + half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. +#else + float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. +#endif // FAST_FP16_AVAILABLE + int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + if constexpr (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 1 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) { + tmp_q_i32[i] = 0; + } + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); + } + } else { + const float * Q_f = (const float *) (Q + j*nb01); + constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE; +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) { + quantize_q8_1_to_shared + (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1); + } + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) { + const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ); + + Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i]; + Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#ifdef FAST_FP16_AVAILABLE + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_j = (const float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; + + float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; + if (ncols == 1 || ic0 + j < ne01) { + ggml_cuda_memcpy_1(tmp, &Q_j[i]); + ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); + } +#pragma unroll + for (int i1 = 0; i1 < cpy_ne; ++i1) { + Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k] *= scale_h2; + } + } +#else +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_j = (const float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; + if (ncols == 1 || ic0 + j < ne01) { + ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); + ggml_cuda_memcpy_1(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k].x *= scale; + Q_reg[j][k].y *= scale; + } + } +#endif // FAST_FP16_AVAILABLE + } + + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + K += blockIdx.y*nthreads * nb11; + V += blockIdx.y*nthreads * nb21; + maskh += blockIdx.y*nthreads; + for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads, + // Increment pointers after each loop: + K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) { + + // Calculate KQ tile and keep track of new maximum KQ values: + float KQ_reg[ncols]; // KQ in registers. + + float KQ_max_new[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max_new[j] = KQ_max[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) { + const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + + if (mask) { + sum += slope*__half2float(maskh[j*ne11 + i_KQ]); + } + + KQ_max_new[j] = fmaxf(KQ_max_new[j], sum); + + if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) { + KQ_reg[j] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) { + KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE)); + } + const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]); + KQ_max[j] = KQ_max_new[j]; + + KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]); + KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; + KQ[j*nthreads + tid] = KQ_reg[j]; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + +#ifndef GGML_USE_HIP + __syncwarp(); +#endif // GGML_USE_HIP + +#pragma unroll + for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) { + const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 
0 : threadIdx.x / nthreads_V); + +#ifdef FAST_FP16_AVAILABLE + half2 KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = __half2half2(KQ[j*nthreads + k]); + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + half2 tmp[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j]; + } + } + } +#else + float KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = KQ[j*nthreads + k]; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + float2 tmp[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j]; + } + } + } +#endif // FAST_FP16_AVAILABLE + } + } + + if (sinks && blockIdx.y == 0) { + const float sink = ((const float *) sinks)[head]; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + const float kqmax_new_j = fmaxf(sink, KQ_max[j]); + const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j); + KQ_max[j] = kqmax_new_j; + + KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f); + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + __shared__ float KQ_max_shared[ncols][WARP_SIZE]; + __shared__ float KQ_sum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f; + KQ_sum_shared[j][threadIdx.x] = 0.0f; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + KQ_max_shared[j][threadIdx.y] = KQ_max[j]; + } + } + __syncthreads(); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ncols > 1 && ic0 + j_VKQ >= ne01) { + break; + } + + float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x]; + kqmax_new = warp_reduce_max(kqmax_new); + const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new); + KQ_max[j_VKQ] = kqmax_new; + +#ifdef FAST_FP16_AVAILABLE + half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) + + (nthreads_V == WARP_SIZE ? 
0 : threadIdx.x / nthreads_V)*(D/2); + + const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); + + ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); + } +#else + float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) + + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); + +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale; + VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); + + ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); + ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); + } +#endif // FAST_FP16_AVAILABLE + + KQ_sum[j_VKQ] *= kqmax_scale; + KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); + if (threadIdx.x == 0) { + KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ]; + } + + __syncthreads(); + + if (nthreads <= D || tid < D) { + KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x]; + KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += nthreads) { + float dst_val = 0; +#pragma unroll + for (int w = 0; w < nwarps; ++w) { +#pragma unroll + for (int v = 0; v < V_cols_per_iter; ++v) { + dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]); + } + } + if (gridDim.y == 1) { + dst_val /= KQ_sum[j_VKQ]; + } + dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; + } + } + + if (j_VKQ < ncols-1) { + __syncthreads(); + } + + } + + if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) { + dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +template +void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); + const int nwarps = nthreads / WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_ext_vec; + constexpr bool need_f16_K = false; + constexpr bool need_f16_V = false; + constexpr size_t nbytes_shared = 0; + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); +} + +template +void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + 
GGML_ASSERT(K->type == type_K);
+    GGML_ASSERT(V->type == type_V);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block = 1;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 2;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
+}
+
+#define DECL_FATTN_VEC_CASE(D, type_K, type_V)                          \
+    template void ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V> \
+    (ggml_backend_cuda_context & ctx, ggml_tensor * dst)               \
+
+#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K)              \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16);  \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
+
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 7626d89ca0..1cbd4f5bd6 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -2,8 +2,7 @@
 #include "fattn-common.cuh"
 #include "fattn-mma-f16.cuh"
 #include "fattn-tile.cuh"
-#include "fattn-vec-f16.cuh"
-#include "fattn-vec-f32.cuh"
+#include "fattn-vec.cuh"
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
@@ -117,151 +116,68 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     }
 }
 
-#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
-    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
-        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
-        return;                                                             \
-    }                                                                       \
+#define FATTN_VEC_CASE(D, type_K, type_V)                                \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
+        ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);  \
+        return;                                                          \
+    }                                                                    \
 
-static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
+    FATTN_VEC_CASE( 64, type_K, type_V)       \
+    FATTN_VEC_CASE(128, type_K, type_V)       \
+    FATTN_VEC_CASE(256, type_K, type_V)       \
+
+static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor 
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 7626d89ca0..1cbd4f5bd6 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -2,8 +2,7 @@
 #include "fattn-common.cuh"
 #include "fattn-mma-f16.cuh"
 #include "fattn-tile.cuh"
-#include "fattn-vec-f16.cuh"
-#include "fattn-vec-f32.cuh"
+#include "fattn-vec.cuh"
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"

@@ -117,151 +116,68 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     }
 }

-#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
-    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
-        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
-        return;                                                             \
-    }                                                                       \
+#define FATTN_VEC_CASE(D, type_K, type_V)                                \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
+        ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);  \
+        return;                                                          \
+    }                                                                    \

-static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
+    FATTN_VEC_CASE( 64, type_K, type_V)       \
+    FATTN_VEC_CASE(128, type_K, type_V)       \
+    FATTN_VEC_CASE(256, type_K, type_V)       \
+
+static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_tensor * Q = dst->src[0];
     ggml_tensor * K = dst->src[1];
     ggml_tensor * V = dst->src[2];

 #ifdef GGML_CUDA_FA_ALL_QUANTS
-    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
-    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
-    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
-    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
-    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)

-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)

-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)

-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)

-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)

-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-
-    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
 #else
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
-
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
-
-    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
-#endif // GGML_CUDA_FA_ALL_QUANTS
-
-    GGML_ABORT("fatal error");
-}
-
-#define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
-    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
-        ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
-        return;                                                             \
-    }                                                                       \
-
-static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * Q = dst->src[0];
-    ggml_tensor * K = dst->src[1];
-    ggml_tensor * V = dst->src[2];
-
-#ifdef GGML_CUDA_FA_ALL_QUANTS
-    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
-    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
-    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
-    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
-    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
-
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
-
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
-
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
-
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-
-    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
-#else
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
-
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
-
-    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
 #endif // GGML_CUDA_FA_ALL_QUANTS

     GGML_ABORT("fatal error");

@@ -271,8 +187,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 enum best_fattn_kernel {
     BEST_FATTN_KERNEL_NONE     =   0,
     BEST_FATTN_KERNEL_TILE     = 200,
-    BEST_FATTN_KERNEL_VEC_F32  = 100,
-    BEST_FATTN_KERNEL_VEC_F16  = 110,
+    BEST_FATTN_KERNEL_VEC      = 100,
     BEST_FATTN_KERNEL_WMMA_F16 = 300,
     BEST_FATTN_KERNEL_MMA_F16  = 400,
 };

@@ -283,7 +198,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     return BEST_FATTN_KERNEL_NONE;
 #endif // FLASH_ATTN_AVAILABLE

-    const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];

@@ -293,8 +207,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);

     const int cc = ggml_cuda_info().devices[device].cc;
-    const int warp_size = ggml_cuda_info().devices[device].warp_size;
-    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

     switch (K->ne[0]) {
         case 64:

@@ -343,31 +255,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 #endif // GGML_CUDA_FA_ALL_QUANTS
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q8_0:
-#ifdef GGML_CUDA_FA_ALL_QUANTS
-            if (K->ne[0] != 128 && K->ne[0] != 64) {
-                return BEST_FATTN_KERNEL_NONE;
-            }
-#else
-            if (K->ne[0] != 128) {
-                return BEST_FATTN_KERNEL_NONE;
-            }
-#endif // GGML_CUDA_FA_ALL_QUANTS
-            break;
-        default:
-            return BEST_FATTN_KERNEL_NONE;
-    }
-
-    switch (V->type) {
-        case GGML_TYPE_F16:
-            break;
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q8_0:
-            if (K->ne[0] != 128) {
-                return BEST_FATTN_KERNEL_NONE;
-            }
             break;
         default:
             return BEST_FATTN_KERNEL_NONE;

@@ -377,30 +264,39 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_NONE;
     }

-    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;

     // If Turing tensor cores available, use them except for some cases with batch size 1:
     if (turing_mma_available(cc)) {
-        const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations
-        const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-        const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192);
-        const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion &&
-            (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
-        if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
-            if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-                return BEST_FATTN_KERNEL_VEC_F16;
+        best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
+
+        if (can_use_vector_kernel) {
+            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
+                    best = BEST_FATTN_KERNEL_VEC;
+                }
+            } else {
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    if (Q->ne[1] <= 2) {
+                        best = BEST_FATTN_KERNEL_VEC;
+                    }
+                } else {
+                    if (Q->ne[1] == 1) {
+                        best = BEST_FATTN_KERNEL_VEC;
+                    }
+                }
+            }
+            if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
+                best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
             }
-            return BEST_FATTN_KERNEL_VEC_F32;
         }
-        return BEST_FATTN_KERNEL_MMA_F16;
+
+        return best;
     }

-    // Use kernels specializes for small batch sizes if possible:
+    // Use kernels specialized for small batch sizes if possible:
     if (Q->ne[1] <= 8 && can_use_vector_kernel) {
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            return BEST_FATTN_KERNEL_VEC_F16;
-        }
-        return BEST_FATTN_KERNEL_VEC_F32;
+        return BEST_FATTN_KERNEL_VEC;
     }

     // For large batch sizes, use the WMMA kernel if possible:

@@ -420,11 +316,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         case BEST_FATTN_KERNEL_TILE:
             ggml_cuda_flash_attn_ext_tile(ctx, dst);
             break;
-        case BEST_FATTN_KERNEL_VEC_F32:
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-            break;
-        case BEST_FATTN_KERNEL_VEC_F16:
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        case BEST_FATTN_KERNEL_VEC:
+            ggml_cuda_flash_attn_ext_vec(ctx, dst);
             break;
         case BEST_FATTN_KERNEL_WMMA_F16:
             ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
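With the two precision variants merged into BEST_FATTN_KERNEL_VEC, the choice between the vector kernel and the mma kernel above reduces to batch size, GQA layout and KV-cache type. A condensed restatement of that policy as a standalone predicate (names are illustrative; thresholds as in the hunk above):

    #include <cstdint>

    // ada_or_newer corresponds to cc >= GGML_CUDA_CC_ADA_LOVELACE.
    static bool prefer_vec_kernel(bool kv_f16, bool ada_or_newer, int64_t n_tokens,
                                  int64_t n_samples, int64_t gqa_ratio, int64_t kv_len, bool has_mask) {
        if (kv_f16) {
            // F16 KV cache: vec only for single-token, single-sample decode on Ada+,
            // unless a large GQA workload makes the mma kernel faster.
            if (ada_or_newer && n_tokens == 1 && n_samples == 1 && !(gqa_ratio > 4 && kv_len >= 8192)) {
                return true;
            }
        } else if (n_tokens <= (ada_or_newer ? 2 : 1)) {
            // Quantized KV cache forces a data conversion in the mma kernel.
            return true;
        }
        // For batch size 1 the mma GQA optimizations need an even ratio and a mask.
        return n_tokens == 1 && (gqa_ratio % 2 != 0 || !has_mask);
    }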
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 8c8647b147..b7e81b21bc 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2031,7 +2031,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
             const int cc = ggml_cuda_info().devices[id].cc;
             const int warp_size = ggml_cuda_info().devices[id].warp_size;
             use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+            use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
             use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
             any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
         }
@@ -2039,7 +2039,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         const int cc = ggml_cuda_info().devices[ctx.device].cc;
         const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
         use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]);
+        use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1], /*mul_mat_id=*/false);
         use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
         any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
     }
@@ -2111,7 +2111,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         return;
     }

-    if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
+    if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2], /*mul_mat_id=*/true)) {
         ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
         return;
     }
@@ -2641,6 +2641,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
     const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
     const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
     const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
+    const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
+    const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";

     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
@@ -2669,7 +2671,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
                 (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
                 strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
                 strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
-                strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
+                strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
+                strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
+                strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
                 // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
                 // by means of matching node names. See
                 // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
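The two new prefixes extend an existing strncmp chain that opts specific fused nodes out of the batch-size restriction for CUDA graphs. The pattern in isolation (prefix list abridged to the entries added here plus one existing one):

    #include <cstring>
    #include <string>
    #include <vector>

    // Mirrors the strncmp chain above: a node is exempt if its name starts
    // with any known prefix.
    static bool node_name_has_known_prefix(const char * name) {
        static const std::vector<std::string> prefixes = {
            "ffn_moe_down_biased",
            "nemotron_h_block_out",
            "mamba2_y_add_d",
        };
        for (const std::string & p : prefixes) {
            if (strncmp(name, p.c_str(), p.size()) == 0) {
                return true;
            }
        }
        return false;
    }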
@@ -3639,9 +3643,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
-        case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
             return true;
+        case GGML_OP_ARGSORT:
+            // TODO: Support arbitrary column width
+            return op->src[0]->ne[0] <= 1024;
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_GROUP_NORM:
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index 16331e9ecf..599e085ee9 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -84,7 +84,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
     }
 }

-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, bool mul_mat_id) {

     if (ggml_is_quantized(type)) {
         return false;
@@ -96,8 +96,18 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
     if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
         return false;
     }
-    if (src1_ncols > 16) {
-        return false;
+
+    if (mul_mat_id) {
+        if (type == GGML_TYPE_F32 && src1_ncols > 32) {
+            return false;
+        }
+        if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
+            return false;
+        }
+    } else {
+        if (src1_ncols > 16) {
+            return false;
+        }
     }

     switch (type) {
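The new mul_mat_id flag relaxes the column cap for indexed (MoE) multiplications: F32 is accepted up to 32 columns and F16/BF16 up to 64, while the plain mul_mat path keeps the old cap of 16. The shape gate in isolation (a sketch; the real ggml_cuda_should_use_mmf also checks compute capability, warp size and row counts):

    #include <cstdint>

    static bool mmf_ncols_ok(bool is_f32, bool mul_mat_id, int64_t src1_ncols) {
        if (mul_mat_id) {
            return src1_ncols <= (is_f32 ? 32 : 64); // per-type caps for the ids path
        }
        return src1_ncols <= 16;                     // unchanged cap otherwise
    }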
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index 61e3bf3015..a6c3adfcf1 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -9,13 +9,13 @@ using namespace ggml_cuda_mma;

 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);

 template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
 __launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
 static __global__ void mul_mat_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
         const int stride_col_id, const int stride_row_id,
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
@@ -31,9 +31,20 @@ static __global__ void mul_mat_f(

     const int row0 = blockIdx.x * rows_per_block;

-    const int expert_idx = has_ids ? blockIdx.y : 0;
+    int expert_idx = 0;
+    int col_base = 0;
+
     const int channel_dst = has_ids ? 0 : blockIdx.y;
+
+    if constexpr (has_ids) {
+        // experts + tiles of ncols_dst are packed in the y dimension
+        int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
+        const int nchannels_x = gridDim.y / col_tiles;
+        const int tile_idx = blockIdx.y / nchannels_x;
+        expert_idx = blockIdx.y - tile_idx * nchannels_x;
+        col_base = tile_idx * cols_per_block;
+    }
+
     const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
     const int channel_y = channel_dst;
     const int sample_dst = blockIdx.z;
@@ -44,6 +55,14 @@ static __global__ void mul_mat_f(
     y   += int64_t(sample_y)  *stride_sample_y   + (has_ids ? 0 : channel_y  *stride_channel_y);
     dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);

+    if constexpr (has_ids) {
+        constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
+        const int64_t col_offset = col_base;
+        y   += col_offset * stride_col_y * y_stride_scale;
+        dst += col_offset * stride_col_dst;
+        ids += col_offset * stride_row_id;
+    }
+
     const float2 * y2 = (const float2 *) y;

     extern __shared__ char data_mmv[];
@@ -61,12 +80,17 @@ static __global__ void mul_mat_f(
         for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
             const int j = j0 + threadIdx.y;

-            const int32_t * __restrict__ id_row = ids + j*stride_row_id;

             if (threadIdx.x == 0) {
                 slot_map[j] = -1;
             }

+            if (col_base + j >= ncols_dst_total) {
+                continue;
+            }
+
+            const int32_t * __restrict__ id_row = ids + j*stride_row_id;
+
             for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
                 int match = id_row[k*stride_col_id] == expert_idx;
@@ -108,7 +132,8 @@ static __global__ void mul_mat_f(
             if constexpr (!has_ids) {
                 tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
             } else {
-                tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
+                const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
+                tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
             }
         }
     } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@@ -120,7 +145,8 @@ static __global__ void mul_mat_f(
                 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
                 tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
             } else {
-                float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
+                const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
+                float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
                 tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
             }
         }
@@ -183,14 +209,14 @@ static __global__ void mul_mat_f(
             dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
         } else {
             const int slot = (j < cols_per_block) ? slot_map[j] : -1;
-            if (slot >= 0) {
+            if (slot >= 0 && (col_base + j) < ncols_dst_total) {
                 dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
             }
         }
     }
 #else
     GGML_UNUSED_VARS(x, y, ids, dst,
-        ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
         stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
@@ -201,20 +227,23 @@ static __global__ void mul_mat_f(
 template <typename T, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols_x, const int64_t nchannels_dst,
+        const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
         const int64_t stride_col_id, const int64_t stride_row_id,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
     if (ids) {
-        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
+        dim3 block_nums_ids = block_nums;
+        block_nums_ids.y *= col_tiles;
+        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+            (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {
         mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
-            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+            (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     }
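When ids are present, mul_mat_f_switch_ids multiplies the y grid dimension by the number of column tiles and mul_mat_f decodes blockIdx.y back into an (expert, tile) pair, with experts varying fastest. A host-side sketch of that round trip under the same packing:

    #include <cassert>

    static void check_grid_packing(int nchannels_x, int ncols_dst, int cols_per_block) {
        const int col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
        // gridDim.y = nchannels_x * col_tiles, exactly as in the launch above.
        for (int block_y = 0; block_y < nchannels_x * col_tiles; ++block_y) {
            const int tile_idx   = block_y / nchannels_x;            // which tile of columns
            const int expert_idx = block_y - tile_idx * nchannels_x; // which expert
            const int col_base   = tile_idx * cols_per_block;        // first column of the tile
            assert(expert_idx >= 0 && expert_idx < nchannels_x);
            assert(col_base < col_tiles * cols_per_block);
        }
    }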
@@ -223,7 +252,8 @@ static inline void mul_mat_f_switch_ids(
 template <typename T, int cols_per_block>
 void mul_mat_f_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
         const int64_t stride_col_id, const int64_t stride_row_id,
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
@@ -268,49 +298,49 @@ void mul_mat_f_cuda(
     switch (nwarps_best) {
         case 1: {
             mul_mat_f_switch_ids<T, cols_per_block, 1>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 2: {
             mul_mat_f_switch_ids<T, cols_per_block, 2>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 3: {
             mul_mat_f_switch_ids<T, cols_per_block, 3>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 4: {
             mul_mat_f_switch_ids<T, cols_per_block, 4>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 5: {
             mul_mat_f_switch_ids<T, cols_per_block, 5>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 6: {
             mul_mat_f_switch_ids<T, cols_per_block, 6>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 7: {
             mul_mat_f_switch_ids<T, cols_per_block, 7>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
         case 8: {
             mul_mat_f_switch_ids<T, cols_per_block, 8>(
-                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
         } break;
@@ -332,84 +362,89 @@ static void mul_mat_f_switch_cols_per_block(
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, cudaStream_t stream) {
-    switch (ncols_dst) {
+
+    const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
+
+    GGML_ASSERT(ids || ncols_dst <= 16);
+
+    switch (ncols_case) {
         case 1: {
-            mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 2: {
-            mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 3: {
-            mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 4: {
-            mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 5: {
-            mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 6: {
-            mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 7: {
-            mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 8: {
-            mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 9: {
-            mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 10: {
-            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 11: {
-            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 12: {
-            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 13: {
-            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 14: {
-            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 15: {
-            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case 16: {
-            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
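mul_mat_f_switch_cols_per_block only instantiates cols_per_block up to 16; the clamp above maps larger MoE column counts onto the 16-column template, and the surplus columns are covered by the extra grid tiles packed into blockIdx.y. How a column count maps to a template/tile pair, as a sketch:

    // ncols_case is the template parameter, tiles the blockIdx.y multiplier.
    // Without ids, ncols_dst > 16 is rejected by the GGML_ASSERT above.
    static void tile_shape(bool has_ids, int ncols_dst, int & ncols_case, int & tiles) {
        ncols_case = (has_ids && ncols_dst > 16) ? 16 : ncols_dst;
        tiles      = (ncols_dst + ncols_case - 1) / ncols_case;
        // e.g. has_ids with ncols_dst = 40  ->  ncols_case = 16, tiles = 3
    }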
@@ -422,7 +457,7 @@ static void mul_mat_f_switch_cols_per_block(
 #define DECL_MMF_CASE_HELPER(T, ncols_dst) \
     template void mul_mat_f_cuda<T, ncols_dst>( \
         const T * x, const float * y, const int32_t * ids, float * dst, \
-        const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
+        const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
         const int64_t stride_col_id, const int64_t stride_row_id, \
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 714b23f9f4..12bdc629bd 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -81,7 +81,7 @@ static __global__ void mmq_ids_helper(
 #pragma unroll
         for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
             const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
-            if (threadIdx.x >= offset) {
+            if (threadIdx.x >= static_cast<uint32_t>(offset)) {
                 it_compact_add_lower += tmp;
             }
         }
@@ -110,7 +110,7 @@ static __global__ void mmq_ids_helper(

     expert_bounds[expert] = nex_prev;

-    if (expert < gridDim.x - 1) {
+    if (expert < static_cast<int>(gridDim.x) - 1) {
         return;
     }
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 52de4e78d1..3bf0c9ed25 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -220,7 +220,7 @@ static __global__ void mul_mat_vec_q(
             tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }

-        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) {
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
             dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
diff --git a/ggml/src/ggml-cuda/pad_reflect_1d.cu b/ggml/src/ggml-cuda/pad_reflect_1d.cu
index 0478889da1..32993eb591 100644
--- a/ggml/src/ggml-cuda/pad_reflect_1d.cu
+++ b/ggml/src/ggml-cuda/pad_reflect_1d.cu
@@ -51,6 +51,8 @@ static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
     }
     const float value = *(const float *) (src0_ptr + src_idx * nb00);
     *(float *) (dst_ptr + i0 * nb0) = value;
+
+    GGML_UNUSED(p1);
 }

 void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
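The signedness fixes in mmq.cu above touch mmq_ids_helper's warp-level inclusive scan. A generic shuffle-based prefix sum showing the same pattern (a sketch; the real helper steps by neu_padded rather than powers of two):

    #include <cstdint>

    // Warp-wide inclusive prefix sum over 32 lanes; launch as <<<1, 32>>>.
    __global__ void warp_inclusive_scan(int * data) {
        int v = data[threadIdx.x];
        for (int offset = 1; offset < 32; offset *= 2) {
            const int tmp = __shfl_up_sync(0xFFFFFFFF, v, offset, 32);
            if (threadIdx.x >= static_cast<uint32_t>(offset)) { // unsigned compare, as in the fix above
                v += tmp;
            }
        }
        data[threadIdx.x] = v;
    }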
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu
deleted file mode 100644
index 6696a23847..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu
deleted file mode 100644
index dd070db285..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu
deleted file mode 100644
index 54dcde6f52..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu
deleted file mode 100644
index 4ec22f7919..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu
deleted file mode 100644
index 3c15bf7f0e..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu
deleted file mode 100644
index 7e61b5fdcd..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu
deleted file mode 100644
index fdb15b580c..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu
deleted file mode 100644
index 0f7c417d2c..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
- -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu deleted file mode 100644 index 851f33c43f..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu deleted file mode 100644 index 763809cbeb..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu deleted file mode 100644 index f2a276e50e..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu deleted file mode 100644 index cb227f6f5c..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu deleted file mode 100644 index 97ac0520c7..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu deleted file mode 100644 index c772b42634..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
- -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu deleted file mode 100644 index 5cb7430819..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu deleted file mode 100644 index 98a709d171..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu deleted file mode 100644 index 4f2f947ae8..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu deleted file mode 100644 index 11f96b6f65..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu deleted file mode 100644 index b39bdc0611..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu deleted file mode 100644 index bbd6a2c7f4..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
- -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu deleted file mode 100644 index 9d84ff2b19..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu deleted file mode 100644 index bc8a5bff68..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu deleted file mode 100644 index a679100c83..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu deleted file mode 100644 index 8f21bccf7f..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu deleted file mode 100644 index 858b00fd74..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu deleted file mode 100644 index 0fc8011fac..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
- -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu deleted file mode 100644 index 261fdf623e..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu deleted file mode 100644 index 0fb8247383..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu deleted file mode 100644 index a9d9d089bd..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu deleted file mode 100644 index 7d7b27920a..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu deleted file mode 100644 index a092ee2d50..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. - -#include "../fattn-vec-f16.cuh" - -DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu deleted file mode 100644 index db55927a19..0000000000 --- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +++ /dev/null @@ -1,5 +0,0 @@ -// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu
deleted file mode 100644
index c3c21cefae..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu
deleted file mode 100644
index 35dd9f5208..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu
deleted file mode 100644
index 050c22ac7c..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu
deleted file mode 100644
index de4866c5e6..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu
deleted file mode 100644
index 57a10bc4be..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu
deleted file mode 100644
index e0f08b46a7..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu
deleted file mode 100644
index 1c8e8a467a..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu
deleted file mode 100644
index cefed83fb9..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu
deleted file mode 100644
index aede6e3588..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu
deleted file mode 100644
index 1a1a92c788..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu
deleted file mode 100644
index ad667473d1..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f16.cuh"
-
-DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu
deleted file mode 100644
index c499f455da..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu
deleted file mode 100644
index 8286ebf373..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu
deleted file mode 100644
index 4587868825..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu
deleted file mode 100644
index d89103ce0c..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu
deleted file mode 100644
index bb75fd42ff..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu
deleted file mode 100644
index b1629817e7..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu
deleted file mode 100644
index d8657604da..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu
deleted file mode 100644
index 2e5bd2f1a3..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu
deleted file mode 100644
index be5f302d9f..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu
deleted file mode 100644
index 8dd91cd72e..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu
deleted file mode 100644
index 4cb791502a..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu
deleted file mode 100644
index 09dea42673..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu
deleted file mode 100644
index 0fbb607694..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu
deleted file mode 100644
index 2aeab83b20..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu
deleted file mode 100644
index 599415b494..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu
deleted file mode 100644
index e4f8e3083b..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu
deleted file mode 100644
index 34d166527e..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu
deleted file mode 100644
index 4bebef45a3..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu
deleted file mode 100644
index 326468da2f..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu
deleted file mode 100644
index 511b58f4ec..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu
deleted file mode 100644
index d9906d142e..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu
deleted file mode 100644
index f61c183abb..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu
deleted file mode 100644
index c10450fd29..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu
deleted file mode 100644
index 2d5cb195c4..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu
deleted file mode 100644
index b384f34d7d..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu
deleted file mode 100644
index 446e293b16..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu
deleted file mode 100644
index 6f43029889..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu
deleted file mode 100644
index 1cd8ba88fd..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu
deleted file mode 100644
index 1ee2eab65a..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu
deleted file mode 100644
index 2bc77816a5..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu
deleted file mode 100644
index d55ced08bc..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu
deleted file mode 100644
index 8361e99c4e..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu
deleted file mode 100644
index 7507a67c4c..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu
deleted file mode 100644
index 61f050b235..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu
deleted file mode 100644
index d4a49d9c99..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu
deleted file mode 100644
index d146278976..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu
deleted file mode 100644
index e73f917a1f..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu
deleted file mode 100644
index d40825dfc2..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu
deleted file mode 100644
index b5c6869f4e..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu
deleted file mode 100644
index 4e21b0ccae..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu
deleted file mode 100644
index 2eac321b37..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu
deleted file mode 100644
index f7d2c3b4e0..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu
deleted file mode 100644
index a013f400bd..0000000000
--- a/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu
+++ /dev/null
@@ -1,5 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-vec-f32.cuh"
-
-DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
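(Note: the per-head-size fattn-vec-f16-*/fattn-vec-f32-* instance files deleted above are replaced by the unified fattn-vec-instance-* files added below — one file per (K, V) type pair, each declaring head sizes 64, 128 and 256 through the new fattn-vec.cuh; see the regenerated generate_cu_files.py further down, as well as the note after it.)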
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
new file mode 100644
index 0000000000..c357abd80d
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu
new file mode 100644
index 0000000000..4b148656f9
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu
new file mode 100644
index 0000000000..ef7715758c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu
new file mode 100644
index 0000000000..9ae11cc542
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu
new file mode 100644
index 0000000000..10ed48affa
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu
new file mode 100644
index 0000000000..4fcc3f3377
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu
new file mode 100644
index 0000000000..7ca50531fb
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
new file mode 100644
index 0000000000..6ef1a48fdb
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu
new file mode 100644
index 0000000000..4c0532ca7e
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu
new file mode 100644
index 0000000000..ed3d7bad39
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu
new file mode 100644
index 0000000000..687f254068
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu
new file mode 100644
index 0000000000..41107c45f4
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu
new file mode 100644
index 0000000000..d523ce01cc
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu
new file mode 100644
index 0000000000..8b9ed358ec
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu
new file mode 100644
index 0000000000..0553e464c4
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu
new file mode 100644
index 0000000000..8390eaf1c8
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu
new file mode 100644
index 0000000000..f61e19d6a3
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu
new file mode 100644
index 0000000000..86a188269c
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu
new file mode 100644
index 0000000000..1d7af474b4
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu
new file mode 100644
index 0000000000..837224d360
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu
new file mode 100644
index 0000000000..0dd7dd693f
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu
new file mode 100644
index 0000000000..41b859f45d
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu
new file mode 100644
index 0000000000..d2e5ffd0ac
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu
new file mode 100644
index 0000000000..81ff740b58
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu
new file mode 100644
index 0000000000..a38dae1922
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu
new file mode 100644
index 0000000000..2304571e24
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu
new file mode 100644
index 0000000000..84b83e5544
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu
new file mode 100644
index 0000000000..39f80e218d
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu
new file mode 100644
index 0000000000..cf4e66112b
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu
new file mode 100644
index 0000000000..65654182e5
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu
new file mode 100644
index 0000000000..a1bc3f5a6a
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu
new file mode 100644
index 0000000000..4b76a9be23
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu
new file mode 100644
index 0000000000..77d04125f7
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu
new file mode 100644
index 0000000000..6e170fe36f
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu
new file mode 100644
index 0000000000..b617cd73b5
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
new file mode 100644
index 0000000000..a5b768b111
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index da2d7b7c3b..d410080fab 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -3,13 +3,15 @@
 from glob import glob
 import os
 
-TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_F16"]
+TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
 
 SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
 
-#include "../fattn-vec-f{vkq_size}.cuh"
+#include "../fattn-vec.cuh"
 
-DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
+DECL_FATTN_VEC_CASE( 64, {type_k}, {type_v});
+DECL_FATTN_VEC_CASE(128, {type_k}, {type_v});
+DECL_FATTN_VEC_CASE(256, {type_k}, {type_v});
 """
 
 SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
@@ -46,23 +48,13 @@ def get_short_name(long_quant_name):
     return long_quant_name.replace("GGML_TYPE_", "").lower()
 
 
-def get_head_sizes(type_k, type_v):
-    if type_k == "GGML_TYPE_F16" and type_v == "GGML_TYPE_F16":
-        return [64, 128, 256]
-    if type_k == "GGML_TYPE_F16":
-        return [64, 128]
-    return [128]
-
-
 for filename in glob("*.cu"):
     os.remove(filename)
 
-for vkq_size in [16, 32]:
-    for type_k in TYPES_KV:
-        for type_v in TYPES_KV:
-            for head_size in get_head_sizes(type_k, type_v):
-                with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
-                    f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
+for type_k in TYPES_KV:
+    for type_v in TYPES_KV:
+        with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
+            f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v))
 
 for ncols in [8, 16, 32, 64]:
     for ncols2 in [1, 2, 4, 8, 16]:
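(Note: with get_head_sizes() removed, the generator's output shrinks from 86 files — 2 VKQ sizes x 43 head-size/type combinations: 3 head sizes for f16/f16, 2 for an f16 K with quantized V, 1 otherwise — to 36 files, 6 K types x 6 V types, each instantiating all three head sizes. The f16/f32 accumulator split no longer appears in the file names and is presumably handled inside the unified fattn-vec.cuh.)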
diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m
index af9ff21436..052efb7ace 100644
--- a/ggml/src/ggml-metal/ggml-metal-context.m
+++ b/ggml/src/ggml-metal/ggml-metal-context.m
@@ -222,7 +222,28 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
         ctx->cmd_buf_last = nil;
     }
 
-    // release any completed command buffers
+    // check status of all command buffers
+    {
+        const int n_cb = ctx->n_cb;
+
+        for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+            if (!cmd_buf) {
+                continue;
+            }
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, cb_idx, (int) status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+                GGML_ABORT("fatal error");
+            }
+        }
+    }
+
+    // release any completed extra command buffers
     if (ctx->cmd_bufs_ext.count > 0) {
         for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];
@@ -260,6 +281,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor,
                                             length:size
                                            options:MTLResourceStorageModeShared];
 
+    GGML_ASSERT(buf_src);
+
     struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor);
     if (bid_dst.metal == nil) {
         GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
@@ -299,6 +322,8 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
                                            options:MTLResourceStorageModeShared
                                        deallocator:nil];
 
+    GGML_ASSERT(buf_dst);
+
     struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor);
     if (bid_src.metal == nil) {
         GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
@@ -542,13 +567,13 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
             ctx->debug_graph,
             ctx->debug_fusion);
 
-        for (int idx = idx_start; idx < idx_end;) {
+        for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
             const int res = ggml_metal_op_encode(ctx_op, idx);
             if (res == 0) {
                 break;
             }
 
-            idx += res;
+            idx += res - 1;
         }
 
         ggml_metal_op_free(ctx_op);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 03be2c01a7..819f31c8a3 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -438,21 +438,35 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_libr
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
     char base[256];
     char name[256];
 
+    const ggml_type tsrc0 = op->src[0]->type;
+    const ggml_type tsrc1 = op->src[1]->type;
+
+    const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
+    const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
+
     snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
         return res;
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_pipeline_set_smem(res, 8192);
+    ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
+    ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
+
+    // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
+    ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
 
     return res;
 }
@@ -481,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
             {
-                if (ne00 == 4) {
+                if (ne00 < 32) {
                     nsg = 1;
                     nr0 = 32;
-                    nr1 = 4;
-                    suffix = "_c4";
-                } else if (ne00 % 4 == 0) {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
                     nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
-                    suffix = "_4";
+                    suffix = "_short";
                 } else {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
+                    nsg = std::min(4, (ne00 + 127) / 128);
+                    nr0 = 2;
                     nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
+                    smem = 32*sizeof(float)*nr0;
+                    suffix = ne00 % 4 == 0 ? "_4" : "";
                 }
             } break;
         case GGML_TYPE_Q4_0:
@@ -659,19 +668,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
     char base[256];
     char name[256];
 
+    const ggml_type tsrc0 = op->src[0]->type;
+    const ggml_type tsrc1 = op->src[1]->type;
+
+    const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
+
     snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
-    snprintf(name, 256, "%s", base);
+    snprintf(name, 256, "%s_bci=%d", base, bc_inp);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
         return res;
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
 
     ggml_metal_pipeline_set_smem(res, 8192);
@@ -702,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
             {
-                if (ne00 % 4 == 0) {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
-                    nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
-                    suffix = "_4";
-                } else {
-                    nsg = N_SG_F;
-                    nr0 = N_R0_F;
-                    nr1 = 1;
-                    smem = 32*sizeof(float)*N_R0_F;
-                }
+                nsg = std::min(4, (ne00 + 127) / 128);
+                nr0 = 2;
+                nr1 = 1;
+                smem = 32*sizeof(float)*nr0;
+                suffix = ne00 % 4 == 0 ? "_4" : "";
             } break;
         case GGML_TYPE_Q4_0:
            {
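(Note: the mul_mm change above keys the Metal pipeline cache on the new boundary-check function constants. A minimal self-contained sketch of that pattern follows; FC_MUL_MM, bc_inp and bc_out are from the patch, while the shapes and the printed kernel name are made-up examples, not llama.cpp API.)

    // Sketch: cache name encodes the function-constant values, so each
    // (bci, bco) combination is compiled once and then reused.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne00 = 100, ne0 = 100, ne1 = 50;       // hypothetical src0 rows and output dims

        const bool bc_inp = ne00 % 32 != 0;                  // input rows not a multiple of 32 -> true
        const bool bc_out = ne0 % 64 != 0 || ne1 % 32 != 0;  // output not a multiple of a 64x32 tile -> true

        char name[256];
        snprintf(name, sizeof(name), "kernel_mul_mm_f16_f32_bci=%d_bco=%d", bc_inp, bc_out);
        printf("%s\n", name); // kernel_mul_mm_f16_f32_bci=1_bco=1
        return 0;
    }

Only the bc_out variant reserves the full 8192 bytes of threadgroup memory; aligned outputs get by with 4096 + 2048, per the comment in the patch.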
"_4" : ""; } } break; case GGML_TYPE_Q4_0: @@ -659,19 +668,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_ return res; } -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) { +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + + const bool bc_inp = op->src[0]->ne[0] % 32 != 0; + snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); - snprintf(name, 256, "%s", base); + snprintf(name, 256, "%s_bci=%d", base, bc_inp); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); if (res) { return res; } - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); ggml_metal_pipeline_set_smem(res, 8192); @@ -702,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra case GGML_TYPE_F16: case GGML_TYPE_BF16: { - if (ne00 % 4 == 0) { - nsg = N_SG_F; - nr0 = N_R0_F; - nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - suffix = "_4"; - } else { - nsg = N_SG_F; - nr0 = N_R0_F; - nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - } + nsg = std::min(4, (ne00 + 127) / 128); + nr0 = 2; + nr1 = 1; + smem = 32*sizeof(float)*nr0; + suffix = ne00 % 4 == 0 ? "_4" : ""; } break; case GGML_TYPE_Q4_0: { diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index dda7eca85a..f6ebf90a00 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -115,10 +115,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); -ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 5f744d1a0b..523f9d71ba 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                    (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_TIMESTEP_EMBEDDING:
-        case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARGSORT:
+            // TODO: Support arbitrary column width
+            return op->src[0]->ne[0] <= 1024;
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
@@ -717,8 +719,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
-            return has_simdgroup_reduction &&
-                (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
+            return has_simdgroup_reduction;
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
@@ -1176,6 +1177,8 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
                                            options:MTLResourceStorageModeShared
                                        deallocator:nil];
 
+    GGML_ASSERT(buf_src);
+
     // dst
     struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor);
     bid_dst.offs += offset;
@@ -1232,6 +1235,8 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten
                                            options:MTLResourceStorageModeShared
                                        deallocator:nil];
 
+    GGML_ASSERT(buf_dst);
+
     id<MTLCommandQueue> queue = buf->queue;
     id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index ab51b76c8c..88c98423eb 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -8,9 +8,6 @@
 //
 // TODO: for optimal performance, become function of the device and work size
 
-#define N_R0_F 2
-#define N_SG_F 4
-
 #define N_R0_Q4_0 4
 #define N_SG_Q4_0 2
 
@@ -76,6 +73,7 @@
 #define FC_FLASH_ATTN_EXT_VEC 200
 #define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
 #define FC_MUL_MV 400
+#define FC_MUL_MM 500
 
 // kernel argument structs
 //
@@ -351,6 +349,7 @@ typedef struct {
     uint64_t nb13;
     int32_t  ne0;
     int32_t  ne1;
+    int32_t  nr0;
     int16_t  r2;
     int16_t  r3;
 } ggml_metal_kargs_mul_mv;
@@ -426,6 +425,7 @@ typedef struct {
     int32_t  ne0;
     int32_t  ne1;
     uint64_t nb1;
+    int32_t  nr0;
 } ggml_metal_kargs_mul_mv_id;
 
 // NORM
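(Note: FC_MUL_MM = 500 continues the 100-spaced function-constant bases already used for FC_FLASH_ATTN_EXT_VEC (200), FC_FLASH_ATTN_EXT_VEC_REDUCE (300) and FC_MUL_MV (400), which appears to leave room for up to 100 constants per kernel family. The new nr0 members in ggml_metal_kargs_mul_mv and ggml_metal_kargs_mul_mv_id pass the per-simdgroup row count at dispatch time, replacing the removed compile-time N_R0_F/N_SG_F constants.)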
this can be removed when the allocator starts filtering them earlier + // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830 + for (int i = idx_start; i < idx_end; i++) { + if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) { + idxs.push_back(i); + } + } + } + + ~ggml_metal_op() { + ggml_metal_encoder_end_encoding(this->enc); + ggml_metal_encoder_free(this->enc); + ggml_mem_ranges_free(this->mem_ranges); + } + + int n_nodes() const { + return idxs.size(); + } + + ggml_tensor * node(int i) const { + assert(i >= 0 && i < (int) idxs.size()); + return ggml_graph_node(gf, idxs[i]); + } + + bool can_fuse(int i0, const ggml_op * ops, int n_ops) const { + assert(use_fusion); + assert(i0 >= 0 && i0 < n_nodes()); + + if (i0 + n_ops > n_nodes()) { + return false; + } + + return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops); + } + ggml_metal_device_t dev; ggml_metal_library_t lib; ggml_metal_encoder_t enc; ggml_mem_ranges_t mem_ranges; - ggml_cgraph * gf; - - int idx_start; - int idx_end; - bool use_fusion; bool use_concurrency; bool use_capture; int debug_graph; int debug_fusion; + +private: + ggml_cgraph * gf; + + int idx_start; + int idx_end; + + // non-empty node indices + std::vector idxs; }; ggml_metal_op_t ggml_metal_op_init( @@ -53,34 +119,29 @@ ggml_metal_op_t ggml_metal_op_init( bool use_capture, int debug_graph, int debug_fusion) { - ggml_metal_op_t res = new ggml_metal_op(); - - *res = { - /*.dev =*/ dev, - /*.lib =*/ ggml_metal_device_get_library(dev), - /*.enc =*/ ggml_metal_encoder_init(cmd_buf, use_concurrency), - /*.mem_ranges =*/ ggml_mem_ranges_init(debug_graph), - /*.gf =*/ gf, - /*.idx_start =*/ idx_start, - /*.idx_end =*/ idx_end, - /*.use_fusion =*/ use_fusion, - /*.use_concurrency =*/ use_concurrency, - /*.use_capture =*/ use_capture, - /*.debug_graph =*/ debug_graph, - /*.debug_fusion =*/ debug_fusion, - }; + ggml_metal_op_t res = new ggml_metal_op( + dev, + cmd_buf, + gf, + idx_start, + idx_end, + use_fusion, + use_concurrency, + use_capture, + debug_graph, + debug_fusion); return res; } void ggml_metal_op_free(ggml_metal_op_t ctx) { - ggml_metal_encoder_end_encoding(ctx->enc); - ggml_metal_encoder_free(ctx->enc); - ggml_mem_ranges_free(ctx->mem_ranges); - delete ctx; } +int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) { + return ctx->n_nodes(); +} + static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) { if (!ctx->mem_ranges) { return true; @@ -110,10 +171,7 @@ static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor } static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { - struct ggml_cgraph * gf = ctx->gf; - - struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx; - struct ggml_tensor * node = nodes[0]; + struct ggml_tensor * node = ctx->node(idx); //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); @@ -129,6 +187,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { case GGML_OP_PERMUTE: { // noop -> next node + if (ctx->debug_graph > 0) { + GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)"); + } } return 1; default: { @@ -352,7 +413,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { // update the mem ranges in the encoding context for (int i = 0; i < n_fuse; ++i) { - if (!ggml_metal_op_concurrency_add(ctx, nodes[i])) { + if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) { ggml_metal_op_concurrency_reset(ctx); } } @@ -362,11 +423,11 @@ 
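The constructor above filters empty nodes once and records the surviving graph indices in `idxs`, so every later lookup goes through `node(i)` and encoder indices no longer coincide with graph indices. A stripped-down model of that indirection, with `Node` standing in for `ggml_tensor`:

#include <cassert>
#include <utility>
#include <vector>

struct Node {
    bool empty; // stand-in for ggml_op_is_empty(node->op) || ggml_is_empty(node)
};

struct OpContext {
    std::vector<Node> graph; // all graph nodes, empty ones included
    std::vector<int>  idxs;  // graph indices of the non-empty nodes

    explicit OpContext(std::vector<Node> g) : graph(std::move(g)) {
        idxs.reserve(graph.size());
        for (int i = 0; i < (int) graph.size(); ++i) {
            if (!graph[i].empty) {
                idxs.push_back(i); // encoder index -> graph index
            }
        }
    }

    int n_nodes() const {
        return (int) idxs.size();
    }

    const Node & node(int i) const {
        assert(i >= 0 && i < n_nodes());
        return graph[idxs[i]];
    }
};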
static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) { if (ctx->use_capture) { - ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ggml_graph_node(ctx->gf, idx))); + ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx))); } int res = ggml_metal_op_encode_impl(ctx, idx); - if (idx + res > ctx->idx_end) { + if (idx + res > ctx->n_nodes()) { GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s", "https://github.com/ggml-org/llama.cpp/pull/14849"); } @@ -379,8 +440,7 @@ int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -438,8 +498,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -483,8 +542,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -594,8 +652,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -634,8 +691,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -674,8 +730,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -703,8 +758,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -774,8 +828,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -838,8 +891,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -876,8 +928,7 @@ int 
ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -939,8 +990,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1030,8 +1080,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1076,8 +1125,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1170,8 +1218,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1212,8 +1259,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1286,8 +1332,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1347,8 +1392,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1476,22 +1520,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { !ggml_is_transposed(op->src[1]) && // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - props_dev->has_simdgroup_mm && - op->src[1]->type == GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && + props_dev->has_simdgroup_mm && ne00 >= 64 && (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (op->src[0]->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); 
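With the bounds-checked kernels in place, the matrix-matrix path above no longer requires `ne00 % 32 == 0` or an F32 `src1`, and the alignment asserts are commented out. A hedged sketch of the resulting dispatch predicate; the struct fields and the `ne11_mm_min` parameter are simplified stand-ins for the values computed in `ggml_metal_op_mul_mat`:

#include <cstdint>

struct MulMatShape {
    int64_t ne00;             // src0 row width (K)
    int64_t ne11;             // src1 columns (N)
    int64_t ne12;             // src1 batch
    bool    src0_quantized;
    bool    has_simdgroup_mm; // A14+/M1+ simdgroup matrix support
};

// the divisibility and src1-type checks moved into the kernel, so only the
// minimum-width and batch-size heuristics remain on the host
static bool use_mul_mm(const MulMatShape & s, int64_t ne11_mm_min) {
    return s.has_simdgroup_mm && s.ne00 >= 64 &&
           (s.ne11 > ne11_mm_min || (s.src0_quantized && s.ne12 > 1));
}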
break; - default: break; - } + //switch (op->src[0]->type) { + // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + // default: break; + //} - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type); + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); ggml_metal_kargs_mul_mm args = { /*.ne00 =*/ ne00, @@ -1523,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { } else { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + ggml_metal_kargs_mul_mv args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -1540,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { /*.nb13 =*/ nb13, /*.ne0 =*/ ne0, /*.ne1 =*/ ne1, + /*.nr0 =*/ nr0, /*.r2 =*/ r2, /*.r3 =*/ r3, }; - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); - - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); @@ -1589,8 +1632,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) { } int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1612,8 +1654,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(!ggml_is_transposed(op->src[0])); GGML_ASSERT(!ggml_is_transposed(op->src[1])); - GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); - GGML_ASSERT(ne03 == 1); GGML_ASSERT(ne13 == 1); @@ -1631,19 +1671,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { // ne21 = n_rows (batch size) const int ne21_mm_id_min = 32; - if (props_dev->has_simdgroup_mm && - ne00 % 32 == 0 && ne00 >= 64 && - (ne21 >= ne21_mm_id_min)) { - GGML_ASSERT(ne00 % 4 == 0); - + if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) { // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (op->src[0]->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } + //switch (op->src[0]->type) { + // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + // default: break; + //} // extra buffers for intermediate id mapping ggml_metal_buffer_id bid_tpe = bid_dst; @@ -1687,7 +1723,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); { - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, GGML_TYPE_F16); + ggml_metal_pipeline_t pipeline = 
ggml_metal_library_get_pipeline_mul_mm_id(lib, op); ggml_metal_kargs_mul_mm_id args = { /*.ne00 =*/ ne00, @@ -1723,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); } } else { + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + ggml_metal_kargs_mul_mv_id args = { /*.nei0 =*/ ne20, /*.nei1 =*/ ne21, @@ -1743,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { /*.ne0 =*/ ne0, /*.ne1 =*/ ne1, /*.nb1 =*/ nb1, + /*.nr0 =*/ nr0, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); - - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); - - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - if (ggml_is_quantized(op->src[0]->type)) { GGML_ASSERT(ne00 >= nsg*nr0); } @@ -1783,8 +1820,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -1856,8 +1892,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { } int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2176,16 +2211,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); - - ggml_tensor ** ops = ggml_graph_nodes(gf) + idx; + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; - const int idx_end = ctx->idx_end; - const bool use_fusion = ctx->use_fusion; const int debug_fusion = ctx->debug_fusion; @@ -2258,22 +2288,25 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops // across splits. idx_end indicates the last node in the current split - for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) { - if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) { + for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { + if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { break; } - if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) { + ggml_tensor * f0 = ctx->node(idx + n_fuse); + ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); + + if (f0 != f1->src[0]) { break; } // b[0] === b[1] === ... 
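The rewritten fusion loop above walks the candidate chain through `ctx->can_fuse` and `ctx->node` instead of raw graph pointers. A simplified model of the chain conditions (each op consumes the previous op's output as `src[0]`, and the bias tensors share a layout and a Metal buffer); `Tensor` and `same_layout` are stand-ins, and the real code additionally checks op-type compatibility via `ggml_can_fuse_ext`:

#include <vector>

struct Tensor {
    const Tensor * src0;
    const Tensor * src1;
    int            buffer_id; // which Metal buffer the tensor lives in
};

static bool same_layout(const Tensor *, const Tensor *) {
    return true; // stub for ggml_are_same_layout
}

// returns how many consecutive ops can be encoded as one fused kernel
static int count_fused(const std::vector<const Tensor *> & nodes, int idx, int max_fuse) {
    const int bid_src1 = nodes[idx]->src1->buffer_id;

    int n_fuse = 0;
    for (; n_fuse < max_fuse && idx + n_fuse + 1 < (int) nodes.size(); ++n_fuse) {
        const Tensor * f0 = nodes[idx + n_fuse];
        const Tensor * f1 = nodes[idx + n_fuse + 1];

        if (f1->src0 != f0) {
            break; // ops must form a chain: each consumes the previous result
        }
        if (!same_layout(f0->src1, f1->src1)) {
            break; // b[0] === b[1] === ...
        }
        if (f1->src1->buffer_id != bid_src1) {
            break; // only fuse if src1 stays in the same Metal buffer
        }
    }
    return n_fuse + 1;
}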
- if (!ggml_are_same_layout(ops[n_fuse]->src[1], ops[n_fuse + 1]->src[1])) { + if (!ggml_are_same_layout(f0->src[1], f1->src[1])) { break; } // only fuse ops if src1 is in the same Metal buffer - ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]); + ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]); if (bid_fuse.metal != bid_src1.metal) { break; } @@ -2309,10 +2342,10 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } if (n_fuse > 1) { - bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]); + bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); for (int i = 1; i < n_fuse; ++i) { - if (!ggml_metal_op_concurrency_check(ctx, ops[i])) { + if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { ggml_metal_op_concurrency_reset(ctx); break; @@ -2344,8 +2377,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2393,8 +2425,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2445,20 +2476,15 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; - const int idx_end = ctx->idx_end; - const bool use_fusion = ctx->use_fusion; const int debug_fusion = ctx->debug_fusion; - ggml_tensor ** ops = ggml_graph_nodes(gf) + idx; - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); GGML_TENSOR_LOCALS( int32_t, ne, op, ne); @@ -2499,38 +2525,41 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { fops[1] = GGML_OP_MUL; fops[2] = GGML_OP_ADD; - for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { - if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) { + for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { + if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { break; } - if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) { + ggml_tensor * f0 = ctx->node(idx + n_fuse); + ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); + + if (f0 != f1->src[0]) { break; } - if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) { + if (f1->src[1]->ne[0] != op->ne[0]) { break; } - if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) { + if (!ggml_is_contiguous_rows(f1->src[1])) { break; } - if (ops[n_fuse + 1]->type != GGML_TYPE_F32) { + if (f1->type != GGML_TYPE_F32) { break; } - //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; + //ctx->fuse_cnt[f1->op]++; - bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]); + bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]); - args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1]; - args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2]; - args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3]; + args.nef1[n_fuse + 1] = f1->src[1]->ne[1]; + args.nef2[n_fuse + 1] = f1->src[1]->ne[2]; + args.nef3[n_fuse + 1] = f1->src[1]->ne[3]; - args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1]; - 
args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2]; - args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3]; + args.nbf1[n_fuse + 1] = f1->src[1]->nb[1]; + args.nbf2[n_fuse + 1] = f1->src[1]->nb[2]; + args.nbf3[n_fuse + 1] = f1->src[1]->nb[3]; } ++n_fuse; @@ -2546,10 +2575,10 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { } if (n_fuse > 1) { - bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]); + bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); for (int i = 1; i < n_fuse; ++i) { - if (!ggml_metal_op_concurrency_check(ctx, ops[i])) { + if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { ggml_metal_op_concurrency_reset(ctx); break; @@ -2585,8 +2614,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2681,8 +2709,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2752,8 +2779,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2798,8 +2824,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2852,8 +2877,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2897,8 +2921,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2944,8 +2967,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -2985,8 +3007,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -3020,8 +3041,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, 
idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -3060,8 +3080,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; @@ -3103,8 +3122,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { } int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { - ggml_cgraph * gf = ctx->gf; - ggml_tensor * op = ggml_graph_node(gf, idx); + ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index a1151f881b..8df4c72e7c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -22,6 +22,8 @@ ggml_metal_op_t ggml_metal_op_init( void ggml_metal_op_free(ggml_metal_op_t ctx); +int ggml_metal_op_n_nodes(ggml_metal_op_t ctx); + int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx); // diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 48856c79b7..96df6f0ce6 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -33,6 +33,7 @@ using namespace metal; #if defined(GGML_METAL_HAS_BF16) typedef matrix bfloat4x4; +typedef matrix bfloat2x4; #endif constexpr constant static float kvalues_iq4nl_f[16] = { @@ -3530,7 +3531,25 @@ void kernel_mul_mv_t_t_impl( helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } -template +template +void kernel_mul_mv_t_t_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + } +} + +template kernel void kernel_mul_mv_t_t( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3540,17 +3559,17 @@ kernel void kernel_mul_mv_t_t( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_t_t_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; +typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; -template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template 
[[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #endif template @@ -3636,7 +3655,25 @@ void kernel_mul_mv_t_t_4_impl( helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } -template +template +void kernel_mul_mv_t_t_4_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + }; +} + +template kernel void kernel_mul_mv_t_t_4( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3646,23 +3683,21 @@ kernel void kernel_mul_mv_t_t_4( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_t_t_4_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; +typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; #endif -#define N_MV_T_T 4 - -template -void kernel_mul_mv_c4_impl( +template +void kernel_mul_mv_t_t_short_impl( args_t args, device const char * src0, device const char * src1, @@ -3670,7 +3705,7 @@ void kernel_mul_mv_c4_impl( uint3 tgpig, ushort tiisg) { const int r0 = tgpig.x*32 + tiisg; - const int rb = tgpig.y*N_MV_T_T; + const int r1 = tgpig.y; const int im = tgpig.z; if (r0 >= args.ne01) { @@ -3682,33 +3717,32 @@ void kernel_mul_mv_c4_impl( const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - device const T04 * x = (device const T04 *) (src0 + offset0); + device const T0 * x = (device const T0 *) (src0 + offset0); device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; - } + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - const uint64_t offset1 = r1*args.nb11 + (i12 
)*args.nb12 + (i13 )*args.nb13; + device const T1 * y = (device const T1 *) (src1 + offset1); - device const T14 * y = (device const T14 *) (src1 + offset1); + float res = 0.0f; - dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]); + for (int i = 0; i < args.ne00; ++i) { + res += (float) x[i] * (float) y[i]; } + + dst_f32[(uint64_t)r1*args.ne0 + r0] = res; } -template -kernel void kernel_mul_mv_c4( +template +kernel void kernel_mul_mv_t_t_short( constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_c4_impl( + kernel_mul_mv_t_t_short_impl( args, src0, src1, @@ -3717,14 +3751,14 @@ kernel void kernel_mul_mv_c4( tiisg); } -typedef decltype(kernel_mul_mv_c4) mul_mv_c4_t; +typedef decltype(kernel_mul_mv_t_t_short) mul_mv_t_t_short_t; -template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #endif static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -7856,6 +7890,9 @@ kernel void kernel_set_rows_f( } } +constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; +constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; + #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32 @@ -7868,7 +7905,7 @@ kernel void kernel_set_rows_f( #define SG_MAT_ROW 8 // each block_q contains 16*nl weights -template +template kernel void kernel_mul_mm( constant ggml_metal_kargs_mul_mm & args, device const char * src0, @@ -7879,8 +7916,8 @@ kernel void kernel_mul_mm( ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shmem); - threadgroup float * sb = (threadgroup float *)(shmem + 4096); + threadgroup S0 * sa = (threadgroup S0 *)(shmem); + threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; @@ -7894,8 +7931,9 @@ kernel void kernel_mul_mm( const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? 
((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_T8x8 ma[4]; - simdgroup_float8x8 mb[2]; + S0_8x8 ma[4]; + S1_8x8 mb[2]; + simdgroup_float8x8 mc[8]; for (short i = 0; i < 8; i++){ @@ -7913,27 +7951,45 @@ kernel void kernel_mul_mm( device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - device const float * y = (device const float *)(src1 + const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)); + + device const T1 * y = (device const T1 *)(src1 + args.nb13*i13 + args.nb12*i12 + args.nb11*(r1*BLOCK_SIZE_N + thread_col) - + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + args.nb10*iy); for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - T4x4 temp_a; - dequantize_func(x, il, temp_a); + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup_barrier(mem_flags::mem_threadgroup); + // no need for dequantization + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); - #pragma unroll(16) - for (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } } - *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y); + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0; + } + } else { + *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y)); + } il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? 
x + (2 + nl - 1)/nl : x; @@ -7942,8 +7998,8 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); - threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { @@ -7971,7 +8027,8 @@ kernel void kernel_mul_mm( } } - if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { + if (!FC_mul_mm_bc_out || ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1)) { + // if no bounds checks on the output are needed, we can directly write to device memory device float * C = (device float *) dst + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; @@ -8076,7 +8133,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_ template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; -template +template kernel void kernel_mul_mm_id( constant ggml_metal_kargs_mul_mm_id & args, device const char * src0, @@ -8090,8 +8147,8 @@ kernel void kernel_mul_mm_id( ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shmem); - threadgroup half * sb = (threadgroup half *)(shmem + 4096); + threadgroup S0 * sa = (threadgroup S0 *)(shmem); + threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; @@ -8114,8 +8171,9 @@ kernel void kernel_mul_mm_id( const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_T8x8 ma[4]; - simdgroup_half8x8 mb[2]; + S0_8x8 ma[4]; + S1_8x8 mb[2]; + simdgroup_float8x8 mc[8]; for (short i = 0; i < 8; i++){ @@ -8136,27 +8194,45 @@ kernel void kernel_mul_mm_id( device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - device const float * y = (device const float *)(src1 + const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)); + + device const T1 * y = (device const T1 *)(src1 + args.nb13*i13 + args.nb12*i12 + args.nb11*i11 - + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + args.nb10*iy); for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - T4x4 temp_a; - dequantize_func(x, il, temp_a); + if (is_same::value && FC_mul_mm_bc_inp) { + threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup_barrier(mem_flags::mem_threadgroup); + // no need for dequantization + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? 
((device T0 *) x)[i] : 0; + } + } else { + S0_4x4 temp_a; + dequantize_func(x, il, temp_a); - #pragma unroll(16) - for (short i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ - + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ - + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + FOR_UNROLL (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } } - *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (half2x4)(*((device float2x4 *) y)); + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0; + } + } else { + *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y)); + } il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -8165,8 +8241,8 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); - threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); #pragma unroll(4) for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { @@ -8299,72 +8375,123 @@ template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kerne // matrix-matrix multiplication // -typedef decltype(kernel_mul_mm) mul_mm_t; +typedef decltype(kernel_mul_mm) mul_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; -template 
[[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; + +template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm; +#endif +template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm; 
+template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication // -typedef decltype(kernel_mul_mm_id) mul_mm_id; +typedef decltype(kernel_mul_mm_id) mul_mm_id; -template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id; #endif -template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template 
[[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +#endif +template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; // // matrix-vector multiplication // -typedef void (kernel_mul_mv_impl_t)( +typedef void (kernel_mul_mv_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -8372,7 +8499,7 @@ typedef void 
(kernel_mul_mv_impl_t)( uint3 tgpig, ushort tiisg); -typedef void (kernel_mul_mv2_impl_t)( +typedef void (kernel_mul_mv2_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -8382,7 +8509,7 @@ typedef void (kernel_mul_mv2_impl_t)( ushort tiisg, ushort sgitg); -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -8393,10 +8520,10 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, tgpig, tiisg); + disp_fn(args, src0, src1, dst, tgpig, tiisg); } -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -8407,12 +8534,12 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_disp_fn_t; -template +template kernel void kernel_mul_mv_id( constant ggml_metal_kargs_mul_mv_id & args, device const char * src0s, @@ -8459,11 +8586,12 @@ kernel void kernel_mul_mv_id( /*.nb13 =*/ args.nb12, // ne12 == 1 /*.ne0 =*/ args.ne0, /*.ne1 =*/ 1, // args.ne1, + /*.nr0 =*/ args.nr0, /*.r2 =*/ 1, /*.r3 =*/ 1, }; - impl_fn( + disp_fn( args0, /* src0 */ src0_cur, /* src1 */ src1_cur, @@ -8475,19 +8603,19 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #endif template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0cf3b92464..79d2148744 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_REPEAT: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded case GGML_OP_PAD: - return op->src[0]->type == GGML_TYPE_F32 && op->type == 
GGML_TYPE_F32 && - op->src[0]->ne[3] == 1 && op->ne[3] == 1 && - (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && - (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -4222,15 +4219,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const int ne00 = src0 ? src0->ne[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const int ne10 = src1 ? src1->ne[0] : 0; - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb1 = dst ? dst->nb[1] : 0; - const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const int ne00 = src0->ne[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + const int ne10 = src1->ne[0]; + const cl_ulong nb10 = src1->nb[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -4267,14 +4268,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); - size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1}; - size_t local_work_size[] = {1, 1, 1}; + size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12}; + size_t local_work_size[] = {64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } @@ -5874,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t GGML_ASSERT(dst->extra); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -5892,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t const int s_ne0 = src0->ne[0]; const int s_ne1 = src0->ne[1]; const int s_ne2 = src0->ne[2]; + const int s_ne3 = src0->ne[3]; + + const cl_ulong s_nb0 =
src0->nb[0]; + const cl_ulong s_nb1 = src0->nb[1]; + const cl_ulong s_nb2 = src0->nb[2]; + const cl_ulong s_nb3 = src0->nb[3]; const int d_ne0 = dst->ne[0]; const int d_ne1 = dst->ne[1]; const int d_ne2 = dst->ne[2]; + const int d_ne3 = dst->ne[3]; + + const cl_ulong d_nb0 = dst->nb[0]; + const cl_ulong d_nb1 = dst->nb[1]; + const cl_ulong d_nb2 = dst->nb[2]; + const cl_ulong d_nb3 = dst->nb[3]; + + const int lp0 = ((const int*)(dst->op_params))[0]; + const int rp0 = ((const int*)(dst->op_params))[1]; + const int lp1 = ((const int*)(dst->op_params))[2]; + const int rp1 = ((const int*)(dst->op_params))[3]; + const int lp2 = ((const int*)(dst->op_params))[4]; + const int rp2 = ((const int*)(dst->op_params))[5]; + const int lp3 = ((const int*)(dst->op_params))[6]; + const int rp3 = ((const int*)(dst->op_params))[7]; cl_kernel kernel = backend_ctx->kernel_pad; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3)); size_t lws0 = 64; size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0; - size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 }; + size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 }; size_t local_work_size[] = { lws0, 1, 1 }; size_t *
local_work_size_ptr = local_work_size; diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl index b3fea2923d..c2962edc98 100644 --- a/ggml/src/ggml-opencl/kernels/get_rows.cl +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -69,11 +69,14 @@ kernel void kernel_get_rows_f32( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -81,14 +84,19 @@ kernel void kernel_get_rows_f32( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + if (ind >= ne00) { + return; + } + ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; } } @@ -102,11 +110,14 @@ kernel void kernel_get_rows_f16( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -114,14 +125,19 @@ kernel void kernel_get_rows_f16( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + if (ind >= ne00) { + return; + } + ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; } } @@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { float16 temp; + if (ind >= ne00) { + return; + } dequantize_q4_0_f32( - ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); - *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = 
temp; } } diff --git a/ggml/src/ggml-opencl/kernels/pad.cl b/ggml/src/ggml-opencl/kernels/pad.cl index 747fa7febc..31fb7ccd3b 100644 --- a/ggml/src/ggml-opencl/kernels/pad.cl +++ b/ggml/src/ggml-opencl/kernels/pad.cl @@ -1,30 +1,39 @@ kernel void kernel_pad( - global const void * src0_ptr, - ulong src0_offset, - global void * dst_ptr, - ulong dst_offset, - int s_ne0, int s_ne1, int s_ne2, - int d_ne0, int d_ne1, int d_ne2 + global void * src0, + ulong offset0, + global void * dst, + ulong offsetd, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne0, int ne1, int ne2, int ne3, + ulong nb0, ulong nb1, ulong nb2, ulong nb3, + int lp0, int rp0, + int lp1, int rp1, + int lp2, int rp2, + int lp3, int rp3 ) { - global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset); - global float * dst = (global float *)((global char *)dst_ptr + dst_offset); + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); - int nidx = get_global_id(0); - int idx_d1 = get_group_id(1); - int idx_d2 = get_group_id(2); + int i0 = get_global_id(0); + int i1 = get_group_id(1); + int i2 = get_group_id(2) % ne2; + int i3 = get_group_id(2) / ne2; - if (nidx >= d_ne0) { + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } - int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1; + uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + uint dst_idx = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; - bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2); + global float * src0_ptr = (global float *)((global char *)src0 + src0_idx); + global float * dst_ptr = (global float *)((global char *)dst + dst_idx); - if (in_src_bounds) { - int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1; - dst[dst_el_offset] = src0[src_el_offset]; - } else { - dst[dst_el_offset] = 0.0f; - } + bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) && + (i1 >= lp1 && i1 < ne1 - rp1) && + (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3); + + *dst_ptr = in_src_bounds ? *src0_ptr : 0.0f; } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ebbb412e55..003a901067 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7,12 +7,20 @@ // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- #define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 +// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE +// to avoid conflicts with applications or other libraries that might use it.
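+// Instead, the backend keeps a library-local DispatchLoaderDynamic instance, returned by
+// ggml_vk_default_dispatcher() below, and points VULKAN_HPP_DEFAULT_DISPATCHER at it before
+// including the Vulkan-Hpp headers. Headers >= 301 moved DispatchLoaderDynamic into the
+// vk::detail namespace, hence the version check.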
+#if VK_HEADER_VERSION >= 301 +namespace vk::detail { class DispatchLoaderDynamic; } +using vk::detail::DispatchLoaderDynamic; +#else +namespace vk { class DispatchLoaderDynamic; } +using vk::DispatchLoaderDynamic; +#endif +DispatchLoaderDynamic & ggml_vk_default_dispatcher(); +#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() #include -// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- -VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE - #include #include #include @@ -406,6 +414,8 @@ struct vk_device_struct { bool subgroup_ballot; bool subgroup_clustered; bool multi_add; + bool shader_int64; + bool buffer_device_address; bool add_rms_fusion; uint32_t partials_binding_alignment; @@ -653,6 +663,7 @@ struct vk_buffer_struct { vk::MemoryPropertyFlags memory_property_flags; void * ptr; size_t size = 0; + vk::DeviceAddress bda_addr {}; vk_device device; @@ -985,6 +996,7 @@ struct vk_op_argsort_push_constants { }; struct vk_op_im2col_push_constants { + uint64_t dst_addr; uint32_t batch_offset; uint32_t offset_delta; uint32_t IC; uint32_t IW; uint32_t IH; @@ -998,6 +1010,7 @@ struct vk_op_im2col_push_constants { }; struct vk_op_im2col_3d_push_constants { + uint64_t dst_addr; uint32_t nb10; uint32_t nb11; uint32_t nb12; @@ -2010,10 +2023,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std return buf; } + vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst; + vk::MemoryAllocateFlags mem_flags {}; + if (device->buffer_device_address) { + usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress; + mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress; + } + vk::BufferCreateInfo buffer_create_info{ vk::BufferCreateFlags(), size, - vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst, + usage_flags, vk::SharingMode::eExclusive, 0, nullptr, @@ -2025,6 +2045,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); + const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags }; + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { const auto & req_flags = *it; @@ -2036,7 +2058,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std buf->memory_property_flags = req_flags; try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info }); break; } catch (const vk::SystemError& e) { // loop and retry @@ -2064,6 +2086,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std buf->device = device; buf->size = size; + if (device->buffer_device_address) { + const vk::BufferDeviceAddressInfo addressInfo(buf->buffer); + buf->bda_addr = device->device.getBufferAddress(addressInfo); + } + #ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger->log_allocation(buf, size); #endif @@ -3256,6 +3283,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, 
device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -3275,6 +3307,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, 
sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -3520,14 +3557,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); +#define IM2COL(bda) \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", 
im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + } + if (device->shader_int64 && device->buffer_device_address) { + IM2COL(_bda) } else { - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); + IM2COL() } ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); @@ -4005,6 +4048,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device->vendor_id != VK_VENDOR_ID_INTEL && getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr; + device->shader_int64 = device_features2.features.shaderInt64; + device->buffer_device_address = vk12_features.bufferDeviceAddress; + if (device->subgroup_size_control) { device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; @@ -4498,6 +4544,11 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); +static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; +DispatchLoaderDynamic & ggml_vk_default_dispatcher() { + return ggml_vk_default_dispatcher_instance; +} + static void ggml_vk_instance_init() { if (vk_instance_initialized) { return; @@ -4505,13 +4556,13 @@ static void ggml_vk_instance_init() { VK_LOG_DEBUG("ggml_vk_instance_init()"); // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- - VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr); + ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr); uint32_t api_version = vk::enumerateInstanceVersion(); if (api_version < VK_API_VERSION_1_2) { std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." 
<< std::endl; - GGML_ABORT("fatal error"); + throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, "Vulkan 1.2 required"); } vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; @@ -5643,8 +5694,12 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz ggml_vk_queue_command_pools_cleanup(dst->device); } -static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) { - VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")"); + + if (disable_split_k) { + return 1; + } uint32_t split_k = 1; if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) { @@ -5969,7 +6024,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_sync_buffers(ctx, subctx); } -static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; @@ -5987,8 +6042,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t ne12 = src1->ne[2]; const uint64_t ne13 = src1->ne[3]; - const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; + const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type); + const uint32_t stride_batch_d = stride_d*ne21; const uint64_t r2 = ne12 / ne02; const uint64_t r3 = ne13 / ne03; @@ -6057,7 +6113,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const int y_ne = padded_n * ne10; const int d_ne = ne11 * ne01; - const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); + const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); @@ -6216,13 +6272,16 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub y_sz_total 
= CEIL_DIV(y_sz_total, 144) * 144; } + // No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange. + VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange}); + // compute ggml_vk_matmul( ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, - { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + { d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, - ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, + ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d, split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n ); // NOLINT @@ -6700,9 +6759,36 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } -static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); - if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && + + // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases + // where the M dimension is very large. + // Split_k doesn't work with M splitting. + const size_t nbytes = ggml_nbytes(src0); + const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange; + if (needs_split) { + // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets) + const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]); + uint32_t m_offset = 0; + while (m_offset < dst->ne[0]) { + const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset)); + ggml_tensor dst2 = *dst; + ggml_tensor src02 = *src0; + + dst2.view_src = dst->view_src ? dst->view_src : dst; + src02.view_src = src0->view_src ? 
src0->view_src : src0; + + dst2.view_offs += m_offset * dst->nb[0]; + src02.view_offs += m_offset * src0->nb[1]; + dst2.ne[0] = cur_M_size; + src02.ne[1] = cur_M_size; + + ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true, dryrun); + + m_offset += cur_M_size; + } + } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && // detect 0213 permutation, and batch size of 1 src0->nb[0] <= src0->nb[2] && src0->nb[2] <= src0->nb[1] && @@ -6722,7 +6808,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); } else { - ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); + ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false, dryrun); } } @@ -8582,6 +8668,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { + if (ctx->device->shader_int64 && ctx->device->buffer_device_address) { + // buffer device address path doesn't use dst buffer + d_sz = 1; + } // im2col uses only src1 and dst buffers ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_COUNT_EQUAL) { @@ -9433,7 +9523,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t pelements = OW * KW * KH; + const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + const vk_buffer d_buf = d_buf_ctx->dev_buffer; + + const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { + dst_addr, batch_offset, offset_delta, IC, IW, IH, OW, OH, KW, KH, pelements, @@ -9469,8 +9565,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const int64_t OH = ne2; const int64_t OW = ne1; + const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + const vk_buffer d_buf = d_buf_ctx->dev_buffer; + + const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; + vk_op_im2col_3d_push_constants pc {}; + pc.dst_addr = dst_addr; pc.nb10 = nb10 / ggml_type_size(src1->type); pc.nb11 = nb11 / ggml_type_size(src1->type); pc.nb12 = nb12 / ggml_type_size(src1->type); @@ -10657,10 +10759,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); ctx->semaphore_idx = 0; - const ggml_tensor * src0 = node->src[0]; - const ggml_tensor * src1 = node->src[1]; - const ggml_tensor * src2 = node->src[2]; - const ggml_tensor * src3 = node->src[3]; + ggml_tensor * src0 = node->src[0]; + ggml_tensor * src1 = node->src[1]; + ggml_tensor * src2 = node->src[2]; + ggml_tensor * src3 = node->src[3]; switch (node->op) { // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor @@ -12613,6 +12715,11 @@ static bool 
ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: @@ -12894,6 +13001,12 @@ ggml_backend_reg_t ggml_backend_vk_reg() { } catch (const vk::SystemError& e) { VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what()); return nullptr; + } catch (const std::exception &e) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: " << e.what()); + return nullptr; + } catch (...) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init"); + return nullptr; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index d3127fbd98..73fef4fa65 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -478,3 +478,139 @@ vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); } #endif + +#if defined(DATA_A_Q2_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]); + const uint scales = data_a[a_offset + ib].scales[scalesi]; + const vec2 d = vec2(data_a[a_offset + ib].d); + + return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q3_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[a_offset + ib].d) * float(us - 32); + + return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi ] & m) != 0) ? 0 : 4)), + dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q4_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[a_offset + ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 
0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF), m), + fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q5_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[a_offset + ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), + fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 
16 : 0), m)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q6_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]); + + return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), + dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 482445c6fe..43b906e5ed 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -117,6 +117,9 @@ void main() { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); @@ -155,7 +158,11 @@ void main() { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br) { - masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + if (!KV_bounds_check || j * Bc + c < KV) { + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + } else { + masksh[c][r] = float(0); + } } } barrier(); @@ -172,8 +179,11 @@ void main() { float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - rowmaxf[r] = Sf[r][0]; + rowmaxf[r] = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); } Moldf[r] = Mf[r]; @@ -190,6 +200,9 @@ void main() { // Compute sum across row of P rowsumf[r] = 0.0; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } rowsumf[r] += Pf[r][c]; } @@ -203,6 +216,9 @@ void main() { } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index f73e17e1fa..9b1f153bf7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -13,6 +13,8 @@ layout (constant_id = 6) const uint32_t D_split = 16; const uint32_t HSK_pad = (HSK + 15) & ~15; const uint32_t HSV_pad = (HSV + 15) & ~15; +const bool 
KV_bounds_check = Clamp != 0; + layout (push_constant) uniform parameter { uint32_t N; uint32_t KV; @@ -65,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; +layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } } #endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + if (binding_idx == BINDING_IDX_K) { + const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } else { + const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 63b32171b0..ddb1246e0b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -152,14 +152,17 @@ void main() { uint32_t d = (idx + tid) % (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4); if (c < Bc && d < HSK / 4) { + f16vec4 K_Tf = 
f16vec4(0); + if (!KV_bounds_check || j * Bc + c < KV) { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); #else - f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif + } ksh[c * kshstride + d] = K_Tf; } @@ -202,7 +205,9 @@ void main() { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); + if (!KV_bounds_check || j * Bc + c < KV) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); + } } } barrier(); @@ -210,8 +215,11 @@ void main() { float eMf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride]; + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); } float Moldf = Mf[r]; @@ -233,6 +241,9 @@ void main() { } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } float Pf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index fdbcf7eba0..f0f19a019c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -5,8 +5,11 @@ #include "rte.comp" +#include "types.comp" + layout (push_constant) uniform parameter { + BDA_STORAGE_T dst_addr; uint batch_offset; uint offset_delta; uint IC; uint IW; uint IH; @@ -19,8 +22,6 @@ layout (push_constant) uniform parameter int d0; int d1; } p; -#include "types.comp" - layout(constant_id = 0) const uint BLOCK_SIZE = 32; const uint NUM_ITER = 512 / BLOCK_SIZE; @@ -30,6 +31,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +#if BDA +layout (buffer_reference) buffer D_ptr {D_TYPE d;}; +#endif + void main() { const uint gidx = gl_GlobalInvocationID.x; @@ -38,7 +43,7 @@ void main() { const uint ic = gl_GlobalInvocationID.z % p.IC; const uint src_base = ic * p.offset_delta + batch * p.batch_offset; - const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); + const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); const int oh_s1 = int(oh) * p.s1; const uint ksize = p.OW * p.KH; @@ -50,7 +55,7 @@ void main() { uint current_ix = rem % p.OW; A_TYPE values[NUM_ITER]; - uint offset_dst[NUM_ITER]; + BDA_OFFSET_T offset_dst[NUM_ITER]; 
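+    // note: with BDA enabled, BDA_OFFSET_T is uint64_t (see types.comp), so these
+    // destination offsets can address im2col outputs beyond the reach of 32-bit
+    // offsets; without BDA it falls back to plain uint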
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { values[idx] = A_TYPE(0); } @@ -66,7 +71,7 @@ void main() { const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; const uint iih = oh_s1 + current_ky * p.d1 - p.p1; - offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx; + offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; if ((iih < p.IH) && (iiw < p.IW)) { values[idx] = data_a[src_base + iih * p.IW + iiw]; @@ -89,7 +94,11 @@ void main() { continue; } +#if BDA + D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); + dst_addr.d = D_TYPE(values[idx]); +#else data_d[offset_dst[idx]] = D_TYPE(values[idx]); +#endif } - } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp index 3b010bdeb5..9faa636ac2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -6,8 +6,11 @@ #include "rte.comp" +#include "types.comp" + layout (push_constant) uniform parameter { + BDA_STORAGE_T dst_addr; uint32_t nb10; uint32_t nb11; uint32_t nb12; @@ -38,8 +41,6 @@ layout (push_constant) uniform parameter uint32_t misalign_offsets; } p; -#include "types.comp" - uint get_aoffset() { return p.misalign_offsets >> 16; } uint get_doffset() { return p.misalign_offsets & 0xFFFF; } @@ -50,6 +51,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +#if BDA +layout (buffer_reference) buffer D_ptr {D_TYPE d;}; +#endif + void main() { const uint32_t i = gl_GlobalInvocationID.x; @@ -100,13 +105,22 @@ void main() { const uint32_t iih = ioh * s1 + ikh * d1 - p1; const uint32_t iid = iod * s2 + ikd * d2 - p2; - const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10; +#if BDA + D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst); + if (iih >= IH || iiw >= IW || iid >= ID) { + dst_addr.d = D_TYPE(0.0f); + } else { + dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]); + } +#else if (iih >= IH || iiw >= IW || iid >= ID) { data_d[offset_dst + get_doffset()] = D_TYPE(0.0f); } else { - const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10; data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]); } +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 69ac38fd41..0e3065e014 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -265,7 +265,6 @@ void main() { tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, 
p.stride_d, 1); #if QUANT_K > 1 tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); @@ -281,6 +280,8 @@ void main() { tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k); + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); #if !defined(MUL_MAT_ID) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 75aa22eae4..2fa54ce51f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -245,6 +245,7 @@ struct block_q2_K_packed32 #if defined(DATA_A_Q2_K) #define QUANT_K QUANT_K_Q2_K +#define QUANT_R 1 #define A_TYPE block_q2_K #define A_TYPE_PACKED16 block_q2_K_packed16 #define A_TYPE_PACKED32 block_q2_K_packed32 @@ -270,6 +271,7 @@ struct block_q3_K_packed16 #if defined(DATA_A_Q3_K) #define QUANT_K QUANT_K_Q3_K +#define QUANT_R 1 #define A_TYPE block_q3_K #define A_TYPE_PACKED16 block_q3_K_packed16 #endif @@ -304,6 +306,7 @@ struct block_q4_K_packed128 #if defined(DATA_A_Q4_K) #define QUANT_K QUANT_K_Q4_K +#define QUANT_R 1 #define A_TYPE block_q4_K #define A_TYPE_PACKED16 block_q4_K_packed16 #define A_TYPE_PACKED32 block_q4_K_packed32 @@ -334,6 +337,7 @@ struct block_q5_K_packed128 #if defined(DATA_A_Q5_K) #define QUANT_K QUANT_K_Q5_K +#define QUANT_R 1 #define A_TYPE block_q5_K #define A_TYPE_PACKED16 block_q5_K_packed16 #endif @@ -358,6 +362,7 @@ struct block_q6_K_packed16 #if defined(DATA_A_Q6_K) #define QUANT_K QUANT_K_Q6_K +#define QUANT_R 1 #define A_TYPE block_q6_K #define A_TYPE_PACKED16 block_q6_K_packed16 #endif @@ -1442,4 +1447,19 @@ float e8m0_to_fp32(uint8_t x) { return uintBitsToFloat(bits); } +#if BDA + +#extension GL_EXT_buffer_reference : enable +#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable + +#define BDA_STORAGE_T uint64_t +#define BDA_OFFSET_T uint64_t + +#else + +#define BDA_STORAGE_T uvec2 +#define BDA_OFFSET_T uint + +#endif + #endif // !defined(GGML_TYPES_COMP) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 79701544ff..84bb9df9a0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -589,16 +589,14 @@ void process_shaders() { string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); } - if (!string_ends_with(tname, "_k")) { - shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp"; + shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? 
"get_rows.comp" : "get_rows_quant.comp"; - if (tname == "f16") { - string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); - } else { - string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); - } - string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + if (tname == "f16") { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + } else { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); } + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); } string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); @@ -777,13 +775,15 @@ void process_shaders() { string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); - string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); - string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); - - string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); - string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + for (std::string dim_str : {"", "_3d"}) { + for (bool bda : {false, true}) { + std::string bda_str = bda ? "_bda" : ""; + std::string bda_def = bda ? 
"1" : "0"; + string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}})); + string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}})); + string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}})); + } + } string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index cee4b08366..93200a4d29 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -130,13 +130,15 @@ struct webgpu_context_struct { wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline get_rows_pipeline[30]; wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; - wgpu::ComputePipeline cpy_pipeline; - wgpu::ComputePipeline add_pipeline[2]; - wgpu::ComputePipeline add_ip_pipeline[2]; - wgpu::ComputePipeline mul_pipeline[2]; - wgpu::ComputePipeline mul_ip_pipeline[2]; - wgpu::ComputePipeline rms_norm_pipeline; - wgpu::ComputePipeline rms_norm_ip_pipeline; + wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type + wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace + wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace + wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split + wgpu::ComputePipeline scale_pipeline[2]; // inplace size_t memset_bytes_per_thread; @@ -489,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shape — same for both tensors even if permuted - (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3] + // Logical shapes + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] }; std::vector entries = { @@ -506,7 +509,8 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x, + ggml_op_name(dst->op)); } static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { @@ -649,7 +653,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * dst, wgpu::ComputePipeline & pipeline, - bool in_place) { + bool inplace) { 
std::vector<uint32_t> params = { (uint32_t) ggml_nelements(dst), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -678,7 +682,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, src1), .size = ggml_webgpu_tensor_binding_size(ctx, src1) } }; - if (!in_place) { + if (!inplace) { entries.push_back({ .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), @@ -691,30 +695,23 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t - bool in_place = ggml_webgpu_tensor_equal(src, dst); - - uint32_t eps; - memcpy(&eps, dst->op_params, sizeof(float)); + int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader }; - if (!in_place) { - params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type))); - } - params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type))); - params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type))); - params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type))); - if (!in_place) { - params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type))); - params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type))); - params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type))); - } - params.push_back((uint32_t) src->ne[0]); - params.push_back((uint32_t) src->ne[1]); - params.push_back((uint32_t) src->ne[2]); - params.push_back((uint32_t) src->ne[3]); - params.push_back(eps); // epsilon, will be bitcast to float in shader std::vector<wgpu::BindGroupEntry> entries = { { .binding = 0, @@ -722,24 +719,199 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t .offset = ggml_webgpu_tensor_align_offset(ctx, src), .size = ggml_webgpu_tensor_binding_size(ctx, src) } }; - if (!in_place) { + if (!inplace) { entries.push_back({ .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - wgpu::ComputePipeline pipeline; - if (in_place) { - pipeline = ctx->rms_norm_ip_pipeline; - } else { - pipeline = ctx->rms_norm_pipeline; - } size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x, + ggml_op_name(dst->op)); +} + +static void ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_freq_factor = (src2 != nullptr); + + const int n_dims = ((int32_t *) dst->op_params)[1]; + const 
int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + int sections[4]; + memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int)); + + float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(src0) / 2, + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + (uint32_t) n_dims, + (uint32_t) mode, + *(uint32_t *) &theta_scale, + *(uint32_t *) &attn_factor, + *(uint32_t *) &freq_scale, + *(uint32_t *) &ext_factor, + *(uint32_t *) &corr_dims[0], + *(uint32_t *) &corr_dims[1], + (uint32_t) sections[0], + (uint32_t) sections[1], + (uint32_t) sections[2], + (uint32_t) sections[3] + }; + + std::vector<wgpu::BindGroupEntry> entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) } + }; + uint32_t dst_binding = 2; + if (has_freq_factor) { + dst_binding = 3; + entries.push_back({ .binding = 2, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + } + if (!inplace) { + entries.push_back({ .binding = dst_binding, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } +static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + const int split = (src1 != nullptr); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + src1 != nullptr ? 
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai + *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + }; + + std::vector<wgpu::BindGroupEntry> entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + }; + uint32_t dst_binding = 1; + if (split) { + dst_binding = 2; + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + } + entries.push_back({ .binding = dst_binding, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); +} + +static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + int inplace = ggml_webgpu_tensor_equal(src, dst); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + *(uint32_t *) dst->op_params, // scale + *(uint32_t *) &dst->op_params[1] // bias + }; + + std::vector<wgpu::BindGroupEntry> entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) } + }; + if (!inplace) { + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + 
max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x, + ggml_op_name(dst->op)); +} + // Returns true if node has enqueued work into the queue, false otherwise static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { @@ -749,6 +921,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; + ggml_tensor * src2 = node->src[2]; switch (node->op) { // no-ops @@ -759,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_RESHAPE: return false; case GGML_OP_CPY: + case GGML_OP_CONT: ggml_webgpu_cpy(ctx, src0, node); break; case GGML_OP_SET_ROWS: @@ -771,22 +945,41 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { ggml_webgpu_mul_mat(ctx, src0, src1, node); break; case GGML_OP_ADD: - if (ggml_webgpu_tensor_equal(src0, node)) { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true); - } else { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false); + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); + break; + } + case GGML_OP_SUB: + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); + break; } - break; case GGML_OP_MUL: - if (ggml_webgpu_tensor_equal(src0, node)) { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true); - } else { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false); + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); + break; + } + case GGML_OP_DIV: + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); + break; } - break; case GGML_OP_RMS_NORM: ggml_webgpu_rms_norm(ctx, src0, node); break; + case GGML_OP_ROPE: + ggml_webgpu_rope(ctx, src0, src1, src2, node); + break; + case GGML_OP_GLU: + ggml_webgpu_glu(ctx, src0, src1, node); + break; + case GGML_OP_SCALE: + ggml_webgpu_scale(ctx, src0, node); + break; default: return false; } @@ -1170,40 +1363,153 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", - ggml_webgpu_max_wg_size_entry(webgpu_ctx)); + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], + wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16], + wgsl_cpy_f32_f16, "cpy_f32_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], + wgsl_cpy_f16_f32, "cpy_f16_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], + wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } static void 
ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32, - "add_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16, - "add_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace, + "add_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace, + "add_f16_inplace", constants); +} + +static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace, + "sub_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace, + "sub_f16_inplace", constants); } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32, - "mul_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16, - "mul_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace, + "mul_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace, + "mul_f16_inplace", constants); +} + +static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], 
wgsl_div_f16, "div_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace, + "div_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace, + "div_f16_inplace", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place, - "rms_norm_in_place", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace, + "rms_norm_inplace", constants); +} + +static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32, + "rope_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1], + wgsl_rope_f32_inplace, "rope_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff, + "rope_f32_ff", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1], + wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16, + "rope_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1], + wgsl_rope_f16_inplace, "rope_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff, + "rope_f16_ff", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1], + wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); +} + +static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + // reglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0], + wgsl_reglu_f32, "reglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0], + wgsl_reglu_f16, "reglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1], + wgsl_reglu_f32_split, "reglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1], + wgsl_reglu_f16_split, "reglu_f16_split", constants); + // geglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0], + wgsl_geglu_f32, "geglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0], + wgsl_geglu_f16, "geglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1], + wgsl_geglu_f32_split, "geglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1], + wgsl_geglu_f16_split, "geglu_f16_split", constants); + // swiglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0], + wgsl_swiglu_f32, "swiglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0], + wgsl_swiglu_f16, "swiglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1], + wgsl_swiglu_f32_split, "swiglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1], + wgsl_swiglu_f16_split, "swiglu_f16_split", constants); + // swiglu_oai + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0], + wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1], + wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); + // geglu_erf + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0], + wgsl_geglu_erf_f32, "geglu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0], + wgsl_geglu_erf_f16, "geglu_erf_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1], + wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1], + wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); + // geglu_quick + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0], + wgsl_geglu_quick_f32, "geglu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0], + wgsl_geglu_quick_f16, "geglu_quick_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1], + wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1], + wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); +} + +static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace, + "scale_f32_inplace", constants); } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { @@ -1287,6 +1593,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; + // on smaller devices (or CI), tensors may be larger than the max storage buffer size if 
(ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || @@ -1304,28 +1611,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = true; break; case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) && - (op->src[1]->type == op->type); + case GGML_OP_DIV: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && + (src1->type == op->type); break; case GGML_OP_CPY: + case GGML_OP_CONT: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; case GGML_OP_SET_ROWS: supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64); break; case GGML_OP_GET_ROWS: - if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || - op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) { + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 || + ggml_webgpu_supported_qtype(src0->type)) { supports_op = (op->type == GGML_TYPE_F32); } break; case GGML_OP_MUL_MAT: { - switch (op->src[1]->type) { + switch (src1->type) { case GGML_TYPE_F16: - supports_op = (op->src[0]->type == GGML_TYPE_F16); + supports_op |= (src0->type == GGML_TYPE_F16); break; case GGML_TYPE_F32: - switch (op->src[0]->type) { + switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -1358,7 +1671,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: - supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_ROPE: + supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; + break; + case GGML_GLU_OP_SWIGLU_OAI: + supports_op = op->type == GGML_TYPE_F32; + break; + default: + break; + } + break; + case GGML_OP_SCALE: + supports_op = op->type == GGML_TYPE_F32; break; default: break; @@ -1484,8 +1819,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_get_rows_pipeline(ctx); ggml_webgpu_init_cpy_pipeline(ctx); ggml_webgpu_init_add_pipeline(ctx); + ggml_webgpu_init_sub_pipeline(ctx); ggml_webgpu_init_mul_pipeline(ctx); + ggml_webgpu_init_div_pipeline(ctx); ggml_webgpu_init_rms_norm_pipeline(ctx); + ggml_webgpu_init_rope_pipeline(ctx); + ggml_webgpu_init_glu_pipeline(ctx); + ggml_webgpu_init_scale_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl deleted file mode 100644 index f261cbb553..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - 
-#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl deleted file mode 100644 index 903f7bdbcc..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +++ /dev/null @@ -1,41 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl new file mode 100644 index 0000000000..1ce4d83fa8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl @@ -0,0 +1,188 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "add_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "+" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "add_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "+" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "add_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "+" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "add_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "+" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "mul_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "*" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "mul_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "*" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "mul_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "*" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "mul_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "*" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sub_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "-" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sub_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "-" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sub_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "-" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sub_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "-" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "div_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "/" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "div_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "/" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "div_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "/" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "div_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "/" + }, + "DECLS": ["INPLACE"] + } +] + +#end(VARIANTS) + 
+#define(DECLS) + +#decl(NOT_INPLACE) + +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { + dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; +} + +@group(0) @binding(2) +var<storage, read_write> dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var<uniform> params: Params; + +#enddecl(NOT_INPLACE) + +#decl(INPLACE) + +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { + src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; +} + +@group(0) @binding(2) +var<uniform> params: Params; + +#enddecl(INPLACE) + +#end(DECLS) + + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var<storage, read_write> src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var<storage, read_write> src1: array<{{TYPE}}>; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x < params.ne) { + update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl new file mode 100644 index 0000000000..db1aa34903 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl @@ -0,0 +1,101 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "SRC_TYPE": "f32", + "DST_TYPE": "f32" + } + }, + { + "REPLS": { + "SRC_TYPE": "f32", + "DST_TYPE": "f16" + } + }, + { + "REPLS": { + "SRC_TYPE": "f16", + "DST_TYPE": "f16" + } + }, + { + "REPLS": { + "SRC_TYPE": "f16", + "DST_TYPE": "f32" + } + } +] + +#end(VARIANTS) + +#define(SHADER) +enable f16; + +@group(0) @binding(0) +var<storage, read_write> src: array<{{SRC_TYPE}}>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<{{DST_TYPE}}>; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +@group(0) @binding(2) +var<uniform> params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); +} #end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl deleted file mode 100644 index 6fe924c554..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ /dev/null @@ -1,60 +0,0 @@ -enable f16; - -@group(0) @binding(0) -var src: array; - -@group(0) 
@binding(1) -var dst: array; - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shape (same for both tensors) - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, -}; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 + - i2 * params.stride_dst2 + i3 * params.stride_dst3; - - dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]); -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index d9dfd7d6f4..251051eaec 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -88,15 +88,20 @@ def generate_variants(fname, input_dir, output_dir, outfile): raise ValueError(f"DECLS key '{key}' not found.") decls_code += decls_map[key] + "\n\n" - shader_variant = replace_placeholders(shader_template, variant["REPLS"]) - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant) + final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) + if "REPLS" in variant: + final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = expand_includes(final_shader, input_dir) - if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: + if "SHADER_NAME" in variant: + output_name = variant["SHADER_NAME"] + elif "SHADER_SUFFIX" in variant: + output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] + elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) - elif "TYPE_SUFFIX" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"] - elif "TYPE" in variant["REPLS"]: + elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]: + output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]]) + elif "REPLS" in variant and "TYPE" in variant["REPLS"]: output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] else: output_name = shader_base_name diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl index e3fe311b26..f80ce1fc55 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl @@ -2,9 +2,9 @@ [ { + "SHADER_SUFFIX": "f32_vec", "REPLS": { "TYPE" : "vec4", - "TYPE_SUFFIX": "f32_vec", "DST_TYPE": "vec4", "BLOCK_SIZE": 4 }, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl new file 
mode 100644 index 0000000000..03fcd54868 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl @@ -0,0 +1,323 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "reglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "geglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "swiglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_oai_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] + }, + { + "SHADER_NAME": "swiglu_oai_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "SWIGLU_OAI"] + }, + { + "SHADER_NAME": "geglu_erf_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_quick_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU_QUICK"] + }, +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(REGLU) +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return max(a, 0) * b; +} +#enddecl(REGLU) + +#decl(GEGLU) +const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; +const GELU_COEF_A: {{TYPE}} = 0.044715; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); + return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; +} +#enddecl(GEGLU) + +#decl(SWIGLU) +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return a / (1.0 + exp(-a)) * b; +} +#enddecl(SWIGLU) + +#decl(SWIGLU_OAI) +fn op(a: f32, b: f32) -> f32 { + let xi = min(a, params.limit); + let gi = max(min(b, params.limit), -params.limit); + var out_glu = xi / (1.0 + exp(-xi * params.alpha)); + out_glu = out_glu * (1.0 + gi); + return out_glu; +} +#enddecl(SWIGLU_OAI) + +#decl(GEGLU_ERF) +const p_erf: 
{{TYPE}} = 0.3275911; +const a1_erf: {{TYPE}} = 0.254829592; +const a2_erf: {{TYPE}} = -0.284496736; +const a3_erf: {{TYPE}} = 1.421413741; +const a4_erf: {{TYPE}} = -1.453152027; +const a5_erf: {{TYPE}} = 1.061405429; +const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + let a_div_sqr2 = a * SQRT_2_INV; + let sign_x = sign(a_div_sqr2); + let x = abs(a_div_sqr2); + let t = 1.0 / (1.0 + p_erf * x); + let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); + let erf_approx = sign_x * y; + return 0.5 * a * (1.0 + erf_approx) * b; +} +#enddecl(GEGLU_ERF) + +#decl(GEGLU_QUICK) +const GELU_QUICK_COEF: {{TYPE}} = -1.702; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; +} +#enddecl(GEGLU_QUICK) + +#decl(NO_SPLIT) +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +fn a_value(base: u32) -> {{TYPE}} { + let offset: u32 = select(0, params.ne0, params.swapped != 0); + return src0[base + offset]; +} + +fn b_value(base: u32) -> {{TYPE}} { + let offset: u32 = select(params.ne0, 0, params.swapped != 0); + return src0[base + offset]; +} +#enddecl(NO_SPLIT) + +#decl(SPLIT) +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +fn a_value(base: u32) -> {{TYPE}} { + return src0[base]; +} + +fn b_value(base: u32) -> {{TYPE}} { + return src1[base]; +} +#enddecl(SPLIT) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + swapped: u32, + alpha: f32, + limit: f32, +} + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; + let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + dst[i_dst] = op(a_value(i_a), b_value(i_b)); +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl deleted file mode 100644 index 12506e1420..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) 
-var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl deleted file mode 100644 index e467e59edb..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +++ /dev/null @@ -1,41 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index f919a51336..a275eeb978 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -1,9 +1,48 @@ -@group(0) @binding(0) -var src: array; +#define(VARIANTS) + +[ + { + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_SUFFIX": "inplace", + "DECLS": ["INPLACE"] + }, +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) + +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + dst[dst_offset] = scale * src[src_offset]; +} @group(0) @binding(1) var dst: array; +@group(0) @binding(2) +var params: Params; + +#enddecl(NOT_INPLACE) + +#decl(INPLACE) + +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + src[dst_offset] = scale * src[src_offset]; +} + +@group(0) @binding(1) +var params: Params; + +#enddecl(INPLACE) + +#end(DECLS) + +#define(SHADER) + struct Params { offset_src: u32, // in elements offset_dst: u32, // in elements @@ -23,11 +62,13 @@ struct Params { ne2: u32, ne3: u32, - eps: u32 + eps: f32 }; -@group(0) @binding(2) -var params: Params; +@group(0) @binding(0) +var src: array; + +DECLS override wg_size: u32; @compute @workgroup_size(wg_size) @@ -49,9 +90,9 @@ fn main(@builtin(global_invocation_id) gid: vec3) { for (var j: u32 = 0; j < params.ne0; j++) { sum += src[i_src_row + j] * src[i_src_row + j]; } - let eps = bitcast(params.eps); - let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); for (var j: u32 = 0; j < params.ne0; j++) { - dst[i_dst_row + j] = scale * src[i_src_row + j]; + update(i_src_row + j, i_dst_row + j, scale); } } +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl deleted file mode 100644 index ae84f556d6..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +++ /dev/null @@ -1,48 +0,0 @@ -@group(0) @binding(0) -var a: array; - -struct Params { - offset: u32, // in elements - - // Strides (in elements) - stride1: u32, - stride2: u32, - stride3: u32, - - // Shape - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, - - eps: u32 -}; - -@group(0) @binding(1) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) 
-fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne1 * params.ne2 * params.ne3) { - return; - } - - // one thread per row - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1; - - var sum = 0.0f; - for (var j: u32 = 0; j < params.ne0; j++) { - sum += a[i_row + j] * a[i_row + j]; - } - let eps = bitcast(params.eps); - let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); - for (var j: u32 = 0; j < params.ne0; j++) { - a[i_row + j] = scale * a[i_row + j]; - } -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl new file mode 100644 index 0000000000..9a6ff41128 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl @@ -0,0 +1,282 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f32_inplace", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] + }, + { + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f16_inplace", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] + }, + { + "SHADER_SUFFIX": "f32_ff", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f32_ff_inplace", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] + }, + { + "SHADER_SUFFIX": "f16_ff", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f16_ff_inplace", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(ROTATE) +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + dst[i_dst0] = {{TYPE}}(out0); + dst[i_dst1] = {{TYPE}}(out1); +} +#enddecl(ROTATE) + +#decl(ROTATE_INPLACE) +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + src0[i_dst0] = {{TYPE}}(out0); + src0[i_dst1] = {{TYPE}}(out1); +} +#enddecl(ROTATE_INPLACE) + +#decl(NO_FF_FUNC) +fn freq_factor(i: u32) -> f32 { + return 1.0f; +} +#enddecl(NO_FF_FUNC) + +#decl(FF_FUNC) +fn freq_factor(i: u32) -> f32 { + return src2[params.offset_src2 + i/2]; +} +#enddecl(FF_FUNC) + +#decl(NO_FF_BINDINGS) + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +#enddecl(NO_FF_BINDINGS) + +#decl(NO_FF_BINDINGS_INPLACE) + +@group(0) @binding(2) +var params: Params; + +#enddecl(NO_FF_BINDINGS_INPLACE) + +#decl(FF_BINDINGS) + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var dst: array<{{TYPE}}>; + +@group(0) @binding(4) +var params: Params; + +#enddecl(FF_BINDINGS) + +#decl(FF_BINDINGS_INPLACE) + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var params: Params; + +#enddecl(FF_BINDINGS_INPLACE) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_src2: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + n_threads: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + n_dims: u32, + 
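// mode is a bit field checked in main(): & 2 selects NEOX, & 8 selects MROPE; the exact value 24 selects the vision path +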
+#define(SHADER)
+
+enable f16;
+
+struct Params {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_src2: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src01: u32,
+    stride_src02: u32,
+    stride_src03: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    n_threads: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    n_dims: u32,
+    mode: u32,
+    theta_scale: f32,
+    attn_factor: f32,
+    freq_scale: f32,
+    ext_factor: f32,
+    corr_dim0: f32,
+    corr_dim1: f32,
+    sections0: u32,
+    sections1: u32,
+    sections2: u32,
+    sections3: u32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<i32>;
+
+DECLS
+
+fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
+    let y = (f32(i / 2) - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// returns vector of (cos_theta, sin_theta)
+// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
+fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
+    var mscale = params.attn_factor;
+    var theta = params.freq_scale * theta_extrap;
+    if (params.ext_factor != 0.0f) {
+        let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
+        theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
+        mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
+    }
+    return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
+}
+
+fn pair_base(i0: u32, div_2: bool) -> u32 {
+    if (div_2) {
+        return i0 / 2;
+    } else {
+        return i0;
+    }
+}
+
+fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
+    if (is_vision) {
+        return params.n_dims;
+    } else if (is_neox || is_mrope) {
+        return params.n_dims / 2;
+    } else {
+        return 1;
+    }
+}
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    // two elements per thread
+    if (gid.x >= params.n_threads) {
+        return;
+    }
+
+    let is_neox = bool(params.mode & 2);
+    let is_mrope = bool(params.mode & 8);
+    let is_vision = params.mode == 24;
+
+    var i = gid.x * 2; // start index for this thread
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+    if (i0 >= params.n_dims && !is_vision) {
+        let i_src = i_src_row + i0;
+        let i_dst = i_dst_row + i0;
+        rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
+        return;
+    }
+
+    var theta_base_mult: u32 = 0;
+    var theta_scale_pwr: u32 = i0 / 2;
+    if (is_mrope) {
+        let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
+        let sec_w = params.sections1 + params.sections0;
+        let sec_e = params.sections2 + sec_w;
+        let sector = (i0 / 2) % sect_dims;
+        if (sector >= params.sections0 && sector < sec_w) {
+            theta_base_mult = 1;
+            if (is_vision) {
+                theta_scale_pwr = sector - params.sections0;
+            }
+        } else if (sector >= sec_w && sector < sec_e) {
+            theta_base_mult = 2;
+            if (is_vision) {
+                theta_scale_pwr = sector - sec_w;
+            }
+        } else if (sector >= sec_e) {
+            if (is_vision) {
+                theta_scale_pwr = sector - sec_e;
+                theta_scale_pwr = (i0 / 2) % sec_e;
+            }
+            theta_base_mult = 3;
+        } else if (is_vision) {
+            theta_scale_pwr = sector;
+        }
+    }
+    let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
+    let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
+
+    let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
+    let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
+
+    let x0 = f32(src0[i_src]);
+    let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
+    rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
+}
+
+#end(SHADER)
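For readers comparing this against the CPU path, the YaRN blend in `rope_yarn` above can be written as plain scalar code. This is a hedged sketch mirroring the shader, not code taken from the ggml CPU backend:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>

// Scalar mirror of rope_yarn_ramp/rope_yarn above: blend the extrapolated
// angle with the frequency-scaled one, and widen the magnitude slightly
// when extrapolating (the YaRN "mscale" correction).
static float rope_yarn_ramp(float low, float high, uint32_t i) {
    const float y = (float(i / 2) - low) / std::max(0.001f, high - low);
    return 1.0f - std::min(1.0f, std::max(0.0f, y));
}

// Returns {cos_theta, sin_theta}, both pre-multiplied by mscale.
static std::pair<float, float> rope_yarn(float theta_extrap, uint32_t i,
                                         float freq_scale, float ext_factor, float attn_factor,
                                         float corr_dim0, float corr_dim1) {
    float mscale = attn_factor;
    float theta  = freq_scale * theta_extrap;
    if (ext_factor != 0.0f) {
        const float ramp_mix = rope_yarn_ramp(corr_dim0, corr_dim1, i) * ext_factor;
        theta   = theta * (1.0f - ramp_mix) + theta_extrap * ramp_mix;
        mscale *= 1.0f + 0.1f * std::log(1.0f / freq_scale);
    }
    return { std::cos(theta) * mscale, std::sin(theta) * mscale };
}
```

The shader's TODO about precomputing these per-row values on the CPU applies equally here: the pair depends only on position and dimension index, so it could be tabulated once per batch.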
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
new file mode 100644
index 0000000000..040e80dfea
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
@@ -0,0 +1,90 @@
+#define(VARIANTS)
+
+[
+    {
+        "SHADER_NAME": "scale_f32",
+        "DECLS": ["NOT_INPLACE"]
+    },
+    {
+        "SHADER_NAME": "scale_f32_inplace",
+        "DECLS": ["INPLACE"]
+    }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(NOT_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    dst[offset] = val;
+}
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+@group(0) @binding(1)
+var<uniform> params: Params;
+
+fn store_scale(val: f32, offset: u32) {
+    src[offset] = val;
+}
+#enddecl(INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+struct Params {
+    offset_src: u32,
+    offset_dst: u32,
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    ne: u32,
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    scale: f32,
+    bias: f32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0;
+    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
+
+    store_scale(src[i_src] * params.scale + params.bias, i_dst);
+}
+#end(SHADER)
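Both new shaders (and the removed `rms_norm` body) address tensors the same way: unravel the flat invocation index into `(i3, i2, i1, i0)` using the logical extents, then apply per-dimension strides so non-contiguous views still map to the right buffer element. A standalone C++ rendering of that arithmetic, with names of our own choosing:

```cpp
#include <array>
#include <cstdint>

// Mirror of the gid.x -> i3/i2/i1/i0 arithmetic in the WGSL above: unravel
// a flat element index by the logical extents ne[0..3], then map it to a
// buffer offset via per-dimension strides (stride[0] is 1 in these shaders,
// since i0 is added directly).
static uint32_t unravel_offset(uint32_t idx,
                               const std::array<uint32_t, 4> & ne,
                               const std::array<uint32_t, 4> & stride,
                               uint32_t base_offset) {
    const uint32_t i3 = idx / (ne[2] * ne[1] * ne[0]);
    idx %= ne[2] * ne[1] * ne[0];
    const uint32_t i2 = idx / (ne[1] * ne[0]);
    idx %= ne[1] * ne[0];
    const uint32_t i1 = idx / ne[0];
    const uint32_t i0 = idx % ne[0];
    return base_offset + i3 * stride[3] + i2 * stride[2] + i1 * stride[1] + i0 * stride[0];
}
```

Dividing out extents rather than assuming contiguity is what lets one kernel serve both plain tensors and strided views.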
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index fe36bab836..aecbdad5a3 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3687,6 +3687,7 @@ struct ggml_tensor * ggml_set_rows(
     result->op = GGML_OP_SET_ROWS;
     result->src[0] = b;
     result->src[1] = c;
+    result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931)
 
     return result;
 }
@@ -3927,7 +3928,7 @@ static struct ggml_tensor * ggml_rope_impl(
     memcpy(params + 8, &attn_factor, sizeof(float));
     memcpy(params + 9, &beta_fast, sizeof(float));
     memcpy(params + 10, &beta_slow, sizeof(float));
-    if (mrope_used) {
+    if (mrope_used && sections) {
         memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
     } else {
         memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last
index 733d30cfa2..5e09de499e 100644
--- a/scripts/sync-ggml.last
+++ b/scripts/sync-ggml.last
@@ -1 +1 @@
-978f6e1993f2eeb4e99b63d4e70b4401c0a2dae2
+72632094336524a9c809e129e8b1c52154543a5a
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 2ae9abb446..63655bf651 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -675,10 +675,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_MINICPM:
             {
+                // Backward-compatible defaults for older MiniCPM GGUFs
+                hparams.f_embedding_scale = 12.0f;
+                hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer));
+                hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f;
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
-                ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
-                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
+
+                // Optional KV reads, override defaults if present in newer GGUF exports
+                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /*required=*/false);
+                ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /*required=*/false);
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /*required=*/false);
 
                 // MiniCPM uses rope by default, unlike Granite which uses it as a switch
                 hparams.rope_finetuned = true;
@@ -4818,11 +4825,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers
                         if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
                             layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags);
-                            layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags);
                             layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags);
                             layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags);
-                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags);
-                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags);
+
+                            // Optional tensors
+                            layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
+                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED);
+                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED);
                         }
                     }
                 }
@@ -11744,6 +11753,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
 
         // TODO: skip computing output earlier for unused tokens
         y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+        cb(y, "mamba2_y_add_d", il);
         y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
 
         // grouped RMS norm
@@ -14698,6 +14708,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
         ggml_tensor * inpL;
 
         inpL = build_inp_embd(model.tok_embd);
+        ggml_build_forward_expand(gf, inpL);
 
         auto * inp = build_inp_mem_hybrid();
 
@@ -14729,7 +14740,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
 
             // add residual
             cur = ggml_add(ctx0, cur, inpSA);
-            cb(cur, "block_out", il);
+            cb(cur, "nemotron_h_block_out", il);
 
             // input for next layer
             inpL = cur;
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 8cb36661a0..da938af03b 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1772,7 +1772,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
         const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
         precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
-#ifdef IS_BIG_ENDIAN
+#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
         // correct endiannes of data in precompiled_charsmap binary blob
         uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0];
         *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
diff --git a/src/unicode.h b/src/unicode.h
index 0a5fa2a78c..5bd1362ff4 100644
--- a/src/unicode.h
+++ b/src/unicode.h
@@ -4,6 +4,7 @@
 #include
 #include
 
+// TODO: reimplement this structure in endian-independent way
 struct unicode_cpt_flags {
     enum {
         UNDEFINED = 0x0001,
@@ -15,6 +16,10 @@ struct unicode_cpt_flags {
         SYMBOL = 0x0040, // regex: \p{S}
         CONTROL = 0x0080, // regex: \p{C}
         MASK_CATEGORIES = 0x00FF,
+        WHITESPACE = 0x0100,
+        LOWERCASE = 0x0200,
+        UPPERCASE = 0x0400,
+        NFD = 0x0800,
     };
 
     // codepoint type
@@ -34,11 +39,49 @@ struct unicode_cpt_flags {
 
     // decode from uint16
     inline unicode_cpt_flags(const uint16_t flags = 0) {
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
         *reinterpret_cast<uint16_t *>(this) = flags;
+#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+        is_undefined = (flags & UNDEFINED) ? 1 : 0;
+        is_number = (flags & NUMBER) ? 1 : 0;
+        is_letter = (flags & LETTER) ? 1 : 0;
+        is_separator = (flags & SEPARATOR) ? 1 : 0;
+        is_accent_mark = (flags & ACCENT_MARK) ? 1 : 0;
+        is_punctuation = (flags & PUNCTUATION) ? 1 : 0;
+        is_symbol = (flags & SYMBOL) ? 1 : 0;
+        is_control = (flags & CONTROL) ? 1 : 0;
+        is_whitespace = (flags & WHITESPACE) ? 1 : 0;
+        is_lowercase = (flags & LOWERCASE) ? 1 : 0;
+        is_uppercase = (flags & UPPERCASE) ? 1 : 0;
+        is_nfd = (flags & NFD) ? 1 : 0;
+#else
+#error Unexpected or undefined __BYTE_ORDER__
+#endif
     }
 
     inline uint16_t as_uint() const {
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
         return *reinterpret_cast<const uint16_t *>(this);
+#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+        uint16_t result =
+            is_undefined * UNDEFINED +
+            is_number * NUMBER +
+            is_letter * LETTER +
+            is_separator * SEPARATOR +
+            is_accent_mark * ACCENT_MARK +
+            is_punctuation * PUNCTUATION +
+            is_symbol * SYMBOL +
+            is_control * CONTROL +
+            is_whitespace * WHITESPACE +
+            is_lowercase * LOWERCASE +
+            is_uppercase * UPPERCASE +
+            is_nfd * NFD
+        ;
+
+        return result;
+#else
+#error Unexpected or undefined __BYTE_ORDER__
+#endif
     }
 
     inline uint16_t category_flag() const {
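The little-endian branch keeps the old type-pun, while the big-endian branch rebuilds the bitfield member by member. The pun is endian-sensitive because byte order decides where the first bitfield members land within the `uint16_t`'s two bytes. A tiny self-contained demonstration of that sensitivity (illustrative only, not from the patch):

```cpp
#include <cstdint>
#include <cstdio>

// Why the reinterpret_cast shortcut only works on one byte order: the same
// 16-bit value is laid out in memory low-byte-first on little-endian
// machines and high-byte-first on big-endian ones, so a struct-of-bits
// overlaid on it picks up different flags on each.
int main() {
    const uint16_t v = 0x0102;
    const auto * b = reinterpret_cast<const unsigned char *>(&v);
    // little-endian prints byte0=0x02 byte1=0x01; big-endian is reversed
    std::printf("byte0=0x%02x byte1=0x%02x\n", b[0], b[1]);
    return 0;
}
```

The explicit mask-and-sum path above avoids the issue entirely, at the cost of a little more code, which is presumably why the TODO asks for an endian-independent rework of the whole structure.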
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 3e9e082d93..d9cc5e933f 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -185,7 +185,11 @@ llama_build_and_test(test-json-partial.cpp)
 llama_build_and_test(test-log.cpp)
 llama_build_and_test(test-regex-partial.cpp)
 
-llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)
+if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
+    llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)
+else()
+    llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-be.Q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)
+endif()
 
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
index e2836ca481..a60ca12fe5 100644
--- a/tests/test-arg-parser.cpp
+++ b/tests/test-arg-parser.cpp
@@ -126,52 +126,35 @@ int main(void) {
     assert(params.cpuparams.n_threads == 1010);
 #endif // _WIN32
 
-    if (common_has_curl()) {
-        printf("test-arg-parser: test curl-related functions\n\n");
-        const char * GOOD_URL = "https://ggml.ai/";
-        const char * BAD_URL = "https://www.google.com/404";
-        const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";
+    printf("test-arg-parser: test curl-related functions\n\n");
+    const char * GOOD_URL = "http://ggml.ai/";
+    const char * BAD_URL = "http://ggml.ai/404";
 
-        {
-            printf("test-arg-parser: test good URL\n\n");
-            auto res = common_remote_get_content(GOOD_URL, {});
-            assert(res.first == 200);
-            assert(res.second.size() > 0);
-            std::string str(res.second.data(), res.second.size());
-            assert(str.find("llama.cpp") != std::string::npos);
-        }
+    {
+        printf("test-arg-parser: test good URL\n\n");
+        auto res = common_remote_get_content(GOOD_URL, {});
+        assert(res.first == 200);
+        assert(res.second.size() > 0);
+        std::string str(res.second.data(), res.second.size());
+        assert(str.find("llama.cpp") != std::string::npos);
+    }
 
-        {
-            printf("test-arg-parser: test bad URL\n\n");
-            auto res = common_remote_get_content(BAD_URL, {});
-            assert(res.first == 404);
-        }
+    {
+        printf("test-arg-parser: test bad URL\n\n");
+        auto res = common_remote_get_content(BAD_URL, {});
+        assert(res.first == 404);
+    }
 
-        {
-            printf("test-arg-parser: test max size error\n");
-            common_remote_params params;
-            params.max_size = 1;
-            try {
-                common_remote_get_content(GOOD_URL, params);
-                assert(false && "it should throw an error");
-            } catch (std::exception & e) {
-                printf("  expected error: %s\n\n", e.what());
-            }
+    {
+        printf("test-arg-parser: test max size error\n");
+        common_remote_params params;
+        params.max_size = 1;
+        try {
+            common_remote_get_content(GOOD_URL, params);
+            assert(false && "it should throw an error");
+        } catch (std::exception & e) {
+            printf("  expected error: %s\n\n", e.what());
         }
-
-        {
-            printf("test-arg-parser: test timeout error\n");
-            common_remote_params params;
-            params.timeout = 1;
-            try {
-                common_remote_get_content(BIG_FILE, params);
-                assert(false && "it should throw an error");
-            } catch (std::exception & e) {
-                printf("  expected error: %s\n\n", e.what());
-            }
-        }
-    } else {
-        printf("test-arg-parser: no curl, skipping curl-related functions\n");
     }
 
     printf("test-arg-parser: all tests OK\n\n");
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 8918452cb6..62d815cc26 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2140,6 +2140,27 @@ struct test_set_rows : public test_case {
             }
         }
     }
+
+    double max_nmse_err() override {
+        if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL ||
+            type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q8_0) {
+            // estimate what the max nmse error would be if one quantized value is
+            // off by one. The test values are distributed in [-1,1], so it'll be
+            // roughly (2.0 / 2^bits)^2, divided by the mean square value of the reference,
+            // which is roughly 0.25 times the number of elements.
+            double err_estimate = 1.0f/8.0f;
+            if (type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
+                err_estimate /= 2.0f;
+            }
+            if (type == GGML_TYPE_Q8_0) {
+                err_estimate /= 8.0f;
+            }
+            err_estimate *= err_estimate;
+            err_estimate /= 0.25f*float(ne[0] * r * ne[2]*nr23[0] * ne[3]*nr23[1]);
+            return err_estimate;
+        }
+        return 1e-7;
+    }
 };
 
 // GGML_OP_ARGMAX
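To make that bound concrete: for `q4_0` the quantization step over test values in `[-1, 1]` is about `2/2^4 = 1/8`, the reference mean square is approximated as `0.25` per element, and the estimate divides by the element count. A hedged re-derivation with a hypothetical tensor shape (not code from the test file):

```cpp
#include <cstdio>

// Re-derivation of the max_nmse_err bound above for q4_0: one off-by-one
// quantized value contributes roughly (2/2^4)^2 = (1/8)^2 of squared error,
// while the reference's sum of squares is approximated as 0.25 * n
// (the comment rounds the mean square of values in [-1,1] down to 0.25).
int main() {
    const double step = 1.0 / 8.0;            // 2.0 / 2^bits for 4-bit types
    const long long n = 256LL * 2 * 5 * 4;    // hypothetical ne[0]*r*ne[2]*ne[3]
    const double nmse = (step * step) / (0.25 * n);
    std::printf("estimated max nmse for q4_0, n=%lld: %g\n", n, nmse);
    return 0;
}
```

The `q5` and `q8` divisors in the code simply halve or eighth the step for the finer grids, and the `test_cpy` variant below repeats the same derivation with the wider `[-150, 150]` value range.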
@@ -2430,6 +2451,30 @@ struct test_cpy : public test_case {
     }
 
     double max_nmse_err() override {
+        if (type_src == type_dst) {
+            return 0.0;
+        }
+        if (type_dst == GGML_TYPE_Q4_0 || type_dst == GGML_TYPE_Q4_1 || type_dst == GGML_TYPE_IQ4_NL ||
+            type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1 || type_dst == GGML_TYPE_Q8_0) {
+            // estimate what the max nmse error would be if one quantized value is
+            // off by one. The test values are distributed in [-150,150], so it'll be
+            // roughly (150*2.0 / 2^bits)^2, divided by the mean square value of the reference,
+            // which is roughly 0.25*150^2 times the number of elements.
+            double err_estimate = 1.0f/8.0f * 150.0f;
+            if (type_dst == GGML_TYPE_IQ4_NL) {
+                // iq4_nl values are a bit more spread out
+                err_estimate *= 2.0f;
+            }
+            if (type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1) {
+                err_estimate /= 2.0f;
+            }
+            if (type_dst == GGML_TYPE_Q8_0) {
+                err_estimate /= 8.0f;
+            }
+            err_estimate *= err_estimate;
+            err_estimate /= (150.0f*150.0f*0.25f)*float(ne[0] * ne[1] * ne[2] * ne[3]);
+            return err_estimate;
+        }
         return 1e-6;
     }
 
@@ -2688,23 +2733,30 @@ struct test_scale : public test_case {
     const std::array<int64_t, 4> ne;
     float scale;
     float bias;
+    bool inplace;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale, bias);
+        return VARS_TO_STR5(type, ne, scale, bias, inplace);
     }
 
     test_scale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
            float scale = 2.0f,
-            float bias = 0.0f)
-        : type(type), ne(ne), scale(scale), bias(bias) {}
+            float bias = 0.0f,
+            bool inplace = false)
+        : type(type), ne(ne), scale(scale), bias(bias), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_scale_bias_inplace(ctx, a, scale, bias);
+        } else {
+            out = ggml_scale_bias(ctx, a, scale, bias);
+        }
         ggml_set_name(out, "out");
 
         return out;
@@ -2861,16 +2913,18 @@ struct test_rms_norm : public test_case {
     const std::array<int64_t, 4> ne;
     const bool v; // whether a is a non-contiguous view
     const float eps;
+    const bool inplace; // whether to do the operation inplace
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, v, eps);
+        return VARS_TO_STR5(type, ne, v, eps, inplace);
    }
 
     test_rms_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
             bool v = false,
-            float eps = 1e-6f)
-        : type(type), ne(ne), v(v), eps(eps) {}
+            float eps = 1e-6f,
+            bool inplace = false)
+        : type(type), ne(ne), v(v), eps(eps), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2882,7 +2936,12 @@ struct test_rms_norm : public test_case {
             ggml_set_name(a, "view of a");
         }
 
-        ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
+        ggml_tensor * out;
+        if (inplace) {
+            out = ggml_rms_norm_inplace(ctx, a, eps);
+        } else {
+            out = ggml_rms_norm(ctx, a, eps);
+        }
         ggml_set_name(out, "out");
 
         return out;
@@ -3787,17 +3846,18 @@ struct test_rope : public test_case {
     bool ff;
     int v; // view (1 : non-contiguous a)
     bool forward;
+    bool inplace;
 
     std::string vars() override {
         // forward can be inferred from the op, does not need to be printed
-        return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
+        return VARS_TO_STR11(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v, inplace);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
-            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
-        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
+            int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f,
+            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false)
+        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a;
@@ -3842,7 +3902,11 @@ struct test_rope : public test_case {
                 GGML_ASSERT(n_dims/4 > 0);
                 int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
                 if (forward) {
-                    out = ggml_rope_multi (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    if (inplace) {
+                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
                 } else {
                     out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                 }
@@ -3850,14 +3914,22 @@ struct test_rope : public test_case {
                 GGML_ASSERT(n_dims/3 > 0);
                 int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
                 if (forward) {
-                    out = ggml_rope_multi (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    if (inplace) {
+                        out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
                 } else {
                     out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
                 }
             }
         } else {
             if (forward) {
-                out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                if (inplace) {
+                    out = ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                } else {
+                    out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
             } else {
                 out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
             }
@@ -5753,6 +5825,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+#if 0
+    // >4GB im2col destination. Too slow to run by default.
+    // Test cases taken from Wan2.1 T2V 1.3B.
+ test_cases.emplace_back(new test_im2col (GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {832, 480, 192, 4}, {3, 3, 192, 96}, 1, 1, 1, 1, 1, 1, true)); + test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {834, 482, 6, 96}, {3, 3,3, 9216}, 96, 1, 1, 1, 0, 0, 0, 1, 1, 1, false)); +#endif + // im2col 1D test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); @@ -6131,9 +6210,11 @@ static std::vector> make_test_cases_eval() { //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1}); } - // single in-place tests, especially important for WebGPU backend since kernels for in-place vs. not are different + // single inplace tests, especially important for WebGPU backend since kernels for inplace vs. not are different test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); + test_cases.emplace_back(new test_bin_bcast(ggml_sub_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); + test_cases.emplace_back(new test_bin_bcast(ggml_div_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); // fusion test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2)); @@ -6148,6 +6229,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_add1()); test_cases.emplace_back(new test_scale()); test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f)); + test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test + test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {100, 10, 10, 10}, 2.0f, 1.0f)); test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f)); test_cases.emplace_back(new test_silu_back()); @@ -6159,6 +6242,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } + + // in-place tests + test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true)); + for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false)); test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); @@ -6200,6 +6287,14 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4)); test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4)); +#if 0 + // > 4GB A matrix. Too slow to be enabled by default. 
+ test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 96, 2592, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 3, 2592, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 1, 2592, {1, 1}, {1, 1})); +#endif + for (ggml_type type_a : all_types) { for (int i = 1; i < 10; ++i) { test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1})); @@ -6293,6 +6388,22 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3)); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1})); +#if 0 + // test the mat-mat path for Metal + for (int k = 1; k < 512; ++k) { + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 127, k, {12,1}, {1,1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 127, k, {12,1}, {1,1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 77, k, {12,1}, {1,1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, k, {12,1}, {1,1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 128, k, {12,1}, {1,1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 128, k, {12,1}, {1,1})); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 50, 200, k)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, true, 50, 200, k)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F32, GGML_TYPE_F32, 16, 16, false, 50, 200, k)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F32, GGML_TYPE_F32, 16, 16, true, 50, 200, k)); + } +#endif + for (auto bs2 : {1,3}) { for (auto bs : {1,2,4,8}) { for (auto nr : {1,4}) { @@ -6329,7 +6440,7 @@ static std::vector> make_test_cases_eval() { for (int n_mats : {4, 8}) { for (int n_used : {1, 2, 4}) { for (bool b : {false, true}) { - for (int n : {1, 4, 5, 32, 129}) { + for (int n : {1, 4, 5, 17, 32, 129}) { int m = 512; int k = 256; test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k)); @@ -6482,26 +6593,26 @@ static std::vector> make_test_cases_eval() { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (bool ff : {false, true}) { // freq_factors for (float v : { 0, 1 }) { - test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B + test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B if (all) { - test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B - test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B - test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B + test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B + test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B + test_cases.emplace_back(new 
test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B } if (all) { - test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); - test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) - test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) } if (all) { @@ -6512,7 +6623,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) } - test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) } } @@ -6523,6 +6634,15 @@ static std::vector> make_test_cases_eval() { } } + // single inplace test per type/mode/ff + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) { + for (bool ff : {false, true}) { + test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true)); + } + } + } + for (int v : { 0, 1, 2, 3 }) { for (int dim : { 0, 1, 2, 3, }) { test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v)); @@ -6535,6 +6655,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 
10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order)); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection) } for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) { @@ -6733,7 +6854,7 @@ static std::vector> make_test_cases_perf() { } // qwen3-30b-a3b - for (int bs : {1, 4, 8, 512}) { + for (int bs : {1, 4, 8, 32, 64, 128, 512}) { for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) { for (ggml_type type_b : {GGML_TYPE_F32}) { test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1)); diff --git a/tests/test-tokenizers-repo.sh b/tests/test-tokenizers-repo.sh index 1158aebae0..94a3d05ba5 100755 --- a/tests/test-tokenizers-repo.sh +++ b/tests/test-tokenizers-repo.sh @@ -23,6 +23,13 @@ if [ -d $folder ] && [ -d $folder/.git ]; then (cd $folder; git pull) else git clone $repo $folder + + # byteswap models if on big endian + if [ "$(uname -m)" = s390x ]; then + for f in $folder/*/*.gguf; do + echo YES | python3 "$(dirname $0)/../gguf-py/gguf/scripts/gguf_convert_endian.py" $f big + done + fi fi shopt -s globstar diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 083fc0cf26..498e00e3a5 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -707,6 +707,10 @@ int main(int argc, char ** argv) { embd.push_back(id); + if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog(vocab, id)) { + assistant_ss << common_token_to_piece(ctx, id, false); + } + // echo this to console input_echo = true; @@ -824,11 +828,7 @@ int main(int argc, char ** argv) { } } - // if current token is not EOG, we add it to current assistant message if (params.conversation_mode && !waiting_for_first_input) { - const auto id = common_sampler_last(smpl); - assistant_ss << common_token_to_piece(ctx, id, false); - if (!prompt.empty()) { prompt.clear(); is_interacting = false; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 61420193da..210ecc883f 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -3067,7 +3067,7 @@ struct image_manipulation { dst.buf.resize(3 * target_width * target_height); float Cc; - float C[5]; + float C[5] = {}; float d0, d2, d3, a0, a1, a2, a3; int i, j, k, jj; int x, y; diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index c22d187cd4..caf080e8d1 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -1931,11 +1931,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { LOG("Maximum KLD: %10.6f\n", kld_values.back()); LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f)); LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f)); + LOG("95.0%% KLD: %10.6f\n", percentile(kld_values, 0.950f)); LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f)); LOG("Median KLD: %10.6f\n", kld_median); LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f)); LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f)); LOG(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f)); + LOG(" 0.1%% KLD: %10.6f\n", percentile(kld_values, 0.001f)); LOG("Minimum KLD: %10.6f\n", kld_values.front()); LOG("\n"); diff --git a/tools/run/run.cpp b/tools/run/run.cpp index 772d66c921..b90a7253c4 100644 --- a/tools/run/run.cpp +++ 
b/tools/run/run.cpp
@@ -9,6 +9,7 @@
 #include
 
 #if defined(_WIN32)
+#    define WIN32_LEAN_AND_MEAN
 #    ifndef NOMINMAX
 #        define NOMINMAX
 #    endif
@@ -22,6 +23,8 @@
 #if defined(LLAMA_USE_CURL)
 #    include <curl/curl.h>
+#else
+#    include "http.h"
 #endif
 
 #include
@@ -397,7 +400,6 @@ class File {
 #    endif
 };
 
-#ifdef LLAMA_USE_CURL
 class HttpClient {
   public:
    int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -428,6 +430,8 @@ class HttpClient {
         return 0;
     }
 
+#ifdef LLAMA_USE_CURL
+
     ~HttpClient() {
         if (chunk) {
             curl_slist_free_all(chunk);
         }
@@ -532,6 +536,117 @@ class HttpClient {
         return curl_easy_perform(curl);
     }
 
+#else  // LLAMA_USE_CURL is not defined
+
+#define curl_off_t long long  // temporary hack
+
+  private:
+    // this is a direct translation of the cURL download() above
+    int download(const std::string & url, const std::vector<std::string> & headers_vec, const std::string & output_file,
+                 const bool progress, std::string * response_str = nullptr) {
+        try {
+            auto [cli, url_parts] = common_http_client(url);
+
+            httplib::Headers headers;
+            for (const auto & h : headers_vec) {
+                size_t pos = h.find(':');
+                if (pos != std::string::npos) {
+                    headers.emplace(h.substr(0, pos), h.substr(pos + 2));
+                }
+            }
+
+            File out;
+            if (!output_file.empty()) {
+                if (!out.open(output_file, "ab")) {
+                    printe("Failed to open file for writing\n");
+                    return 1;
+                }
+                if (out.lock()) {
+                    printe("Failed to exclusively lock file\n");
+                    return 1;
+                }
+            }
+
+            size_t resume_offset = 0;
+            if (!output_file.empty() && std::filesystem::exists(output_file)) {
+                resume_offset = std::filesystem::file_size(output_file);
+                if (resume_offset > 0) {
+                    headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-");
+                }
+            }
+
+            progress_data data;
+            data.file_size = resume_offset;
+
+            long long total_size            = 0;
+            long long received_this_session = 0;
+
+            auto response_handler =
+                [&](const httplib::Response & response) {
+                    if (resume_offset > 0 && response.status != 206) {
+                        printe("\nServer does not support resuming. Restarting download.\n");
+                        out.file = freopen(output_file.c_str(), "wb", out.file);
+                        if (!out.file) {
+                            return false;
+                        }
+                        data.file_size = 0;
+                    }
+                    if (progress) {
+                        if (response.has_header("Content-Length")) {
+                            total_size = std::stoll(response.get_header_value("Content-Length"));
+                        } else if (response.has_header("Content-Range")) {
+                            auto range = response.get_header_value("Content-Range");
+                            auto slash = range.find('/');
+                            if (slash != std::string::npos) {
+                                total_size = std::stoll(range.substr(slash + 1));
+                            }
+                        }
+                    }
+                    return true;
+                };
+
+            auto content_receiver =
+                [&](const char * chunk, size_t length) {
+                    if (out.file && fwrite(chunk, 1, length, out.file) != length) {
+                        return false;
+                    }
+                    if (response_str) {
+                        response_str->append(chunk, length);
+                    }
+                    received_this_session += length;
+
+                    if (progress && total_size > 0) {
+                        update_progress(&data, total_size, received_this_session, 0, 0);
+                    }
+                    return true;
+                };
+
+            auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver);
+
+            if (data.printed) {
+                printe("\n");
+            }
+
+            if (!res) {
+                auto err = res.error();
+                printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str());
+                return 1;
+            }
+
+            if (res->status >= 400) {
+                printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status);
+                return 1;
+            }
+
+        } catch (const std::exception & e) {
+            printe("HTTP request failed: %s\n", e.what());
+            return 1;
+        }
+        return 0;
+    }
+
+#endif  // LLAMA_USE_CURL
+
     static std::string human_readable_time(double seconds) {
         int hrs  = static_cast<int>(seconds) / 3600;
         int mins = (static_cast<int>(seconds) % 3600) / 60;
@@ -644,8 +759,8 @@ class HttpClient {
         str->append(static_cast<const char *>(ptr), size * nmemb);
         return size * nmemb;
     }
+
 };
-#endif
 
 class LlamaData {
   public:
@@ -673,7 +788,6 @@ class LlamaData {
     }
 
   private:
-#ifdef LLAMA_USE_CURL
     int download(const std::string & url, const std::string & output_file, const bool progress,
                  const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
         HttpClient http;
@@ -683,14 +797,6 @@ class LlamaData {
         return 0;
     }
 
-#else
-    int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
-                 std::string * = nullptr) {
-        printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
-
-        return 1;
-    }
-#endif
 
     // Helper function to handle model tag extraction and URL construction
     std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt
index 83b608c32a..06df3ee49d 100644
--- a/tools/server/CMakeLists.txt
+++ b/tools/server/CMakeLists.txt
@@ -1,7 +1,5 @@
 set(TARGET llama-server)
 
-option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
-
 include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
 
 if (MINGW)
@@ -37,12 +35,6 @@ target_include_directories(${TARGET} PRIVATE ../mtmd)
 target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
 target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
 
-if (LLAMA_SERVER_SSL)
-    find_package(OpenSSL REQUIRED)
-    target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)
-    target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT)
-endif()
-
 if (WIN32)
     TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
 endif()
diff --git a/tools/server/README.md b/tools/server/README.md
index 73b4cc6f03..9f7ab229f7 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -271,7
+271,7 @@ For more details, please refer to [multimodal documentation](../../docs/multimod - Using `CMake`: ```bash - cmake -B build -DLLAMA_SERVER_SSL=ON + cmake -B build -DLLAMA_OPENSSL=ON cmake --build build --config Release -t llama-server ``` @@ -391,7 +391,7 @@ node index.js ## API Endpoints -### GET `/health`: Returns heath check result +### GET `/health`: Returns health check result This endpoint is public (no API key check). @@ -846,7 +846,7 @@ To use this endpoint with POST method, you need to start server with `--props` ### POST `/embeddings`: non-OpenAI-compatible embeddings API -This endpoint supports all poolings, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using Euclidian norm. +This endpoint supports all poolings, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using Euclidean norm. Note that the response format of this endpoint is different from `/v1/embeddings`. diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index c6d8258d7f..4f18a634ce 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 129801fe06..6062904a8c 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -5262,42 +5262,6 @@ int main(int argc, char ** argv) { svr->Get (params.api_prefix + "/slots", handle_slots); svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action); - // SPA fallback route - serve index.html for any route that doesn't match API endpoints - // This enables client-side routing for dynamic routes like /chat/[id] - if (params.webui && params.public_path.empty()) { - // Only add fallback when using embedded static files - svr->Get(".*", [](const httplib::Request & req, httplib::Response & res) { - // Skip API routes - they should have been handled above - if (req.path.find("/v1/") != std::string::npos || - req.path.find("/health") != std::string::npos || - req.path.find("/metrics") != std::string::npos || - req.path.find("/props") != std::string::npos || - req.path.find("/models") != std::string::npos || - req.path.find("/api/tags") != std::string::npos || - req.path.find("/completions") != std::string::npos || - req.path.find("/chat/completions") != std::string::npos || - req.path.find("/embeddings") != std::string::npos || - req.path.find("/tokenize") != std::string::npos || - req.path.find("/detokenize") != std::string::npos || - req.path.find("/lora-adapters") != std::string::npos || - req.path.find("/slots") != std::string::npos) { - return false; // Let other handlers process API routes - } - - // Serve index.html for all other routes (SPA fallback) - if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { - res.set_content("Error: gzip is not supported by this browser", "text/plain"); - } else { - res.set_header("Content-Encoding", "gzip"); - // COEP and COOP headers, required by pyodide (python interpreter) - res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); - res.set_header("Cross-Origin-Opener-Policy", "same-origin"); - res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); - } - return 
false; - }); - } - // // Start the server // diff --git a/tools/server/webui/package.json b/tools/server/webui/package.json index 30c7cd15c6..7bf21bf57c 100644 --- a/tools/server/webui/package.json +++ b/tools/server/webui/package.json @@ -4,7 +4,7 @@ "version": "1.0.0", "type": "module", "scripts": { - "dev": "vite dev --host 0.0.0.0 & storybook dev -p 6006 --ci", + "dev": "bash scripts/dev.sh", "build": "vite build && ./scripts/post-build.sh", "preview": "vite preview", "prepare": "svelte-kit sync || echo ''", @@ -20,7 +20,8 @@ "test:ui": "vitest --project=ui", "test:unit": "vitest", "storybook": "storybook dev -p 6006", - "build-storybook": "storybook build" + "build-storybook": "storybook build", + "cleanup": "rm -rf .svelte-kit build node_modules test-results" }, "devDependencies": { "@chromatic-com/storybook": "^4.0.1", diff --git a/tools/server/webui/scripts/dev.sh b/tools/server/webui/scripts/dev.sh new file mode 100644 index 0000000000..2bda8f22c8 --- /dev/null +++ b/tools/server/webui/scripts/dev.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +# Development script for llama.cpp webui +# +# This script starts the webui development servers (Storybook and Vite). +# Note: You need to start llama-server separately. +# +# Usage: +# bash scripts/dev.sh +# npm run dev + +cd ../../../ + +# Check and install git hooks if missing +check_and_install_hooks() { + local hooks_missing=false + + # Check for required hooks + if [ ! -f ".git/hooks/pre-commit" ] || [ ! -f ".git/hooks/pre-push" ] || [ ! -f ".git/hooks/post-push" ]; then + hooks_missing=true + fi + + if [ "$hooks_missing" = true ]; then + echo "🔧 Git hooks missing, installing them..." + cd tools/server/webui + if bash scripts/install-git-hooks.sh; then + echo "✅ Git hooks installed successfully" + else + echo "⚠️ Failed to install git hooks, continuing anyway..." + fi + cd ../../../ + else + echo "✅ Git hooks already installed" + fi +} + +# Install git hooks if needed +check_and_install_hooks + +# Cleanup function +cleanup() { + echo "🧹 Cleaning up..." + exit +} + +# Set up signal handlers +trap cleanup SIGINT SIGTERM + +echo "🚀 Starting development servers..." +echo "📝 Note: Make sure to start llama-server separately if needed" +cd tools/server/webui +storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 & + +# Wait for all background processes +wait diff --git a/tools/server/webui/scripts/install-git-hooks.sh b/tools/server/webui/scripts/install-git-hooks.sh index a2a3ca76f5..d14e281389 100755 --- a/tools/server/webui/scripts/install-git-hooks.sh +++ b/tools/server/webui/scripts/install-git-hooks.sh @@ -1,14 +1,14 @@ #!/bin/bash -# Script to install pre-commit and post-commit hooks for webui -# Pre-commit: formats, lints, checks, and builds code, stashes unstaged changes -# Post-commit: automatically unstashes changes +# Script to install pre-commit and pre-push hooks for webui +# Pre-commit: formats code and runs checks +# Pre-push: builds the project, stashes unstaged changes REPO_ROOT=$(git rev-parse --show-toplevel) PRE_COMMIT_HOOK="$REPO_ROOT/.git/hooks/pre-commit" -POST_COMMIT_HOOK="$REPO_ROOT/.git/hooks/post-commit" +PRE_PUSH_HOOK="$REPO_ROOT/.git/hooks/pre-push" -echo "Installing pre-commit and post-commit hooks for webui..." +echo "Installing pre-commit and pre-push hooks for webui..." 
# Create the pre-commit hook cat > "$PRE_COMMIT_HOOK" << 'EOF' @@ -16,7 +16,7 @@ cat > "$PRE_COMMIT_HOOK" << 'EOF' # Check if there are any changes in the webui directory if git diff --cached --name-only | grep -q "^tools/server/webui/"; then - echo "Formatting webui code..." + echo "Formatting and checking webui code..." # Change to webui directory and run format cd tools/server/webui @@ -27,20 +27,12 @@ if git diff --cached --name-only | grep -q "^tools/server/webui/"; then exit 1 fi - # Stash any unstaged changes to avoid conflicts during format/build - echo "Stashing unstaged changes..." - git stash push --keep-index --include-untracked -m "Pre-commit hook: stashed unstaged changes" - STASH_CREATED=$? - # Run the format command npm run format # Check if format command succeeded if [ $? -ne 0 ]; then echo "Error: npm run format failed" - if [ $STASH_CREATED -eq 0 ]; then - echo "You can restore your unstaged changes with: git stash pop" - fi exit 1 fi @@ -50,9 +42,6 @@ if git diff --cached --name-only | grep -q "^tools/server/webui/"; then # Check if lint command succeeded if [ $? -ne 0 ]; then echo "Error: npm run lint failed" - if [ $STASH_CREATED -eq 0 ]; then - echo "You can restore your unstaged changes with: git stash pop" - fi exit 1 fi @@ -62,73 +51,151 @@ if git diff --cached --name-only | grep -q "^tools/server/webui/"; then # Check if check command succeeded if [ $? -ne 0 ]; then echo "Error: npm run check failed" - if [ $STASH_CREATED -eq 0 ]; then - echo "You can restore your unstaged changes with: git stash pop" - fi exit 1 fi - # Run the build command - npm run build - - # Check if build command succeeded - if [ $? -ne 0 ]; then - echo "Error: npm run build failed" - if [ $STASH_CREATED -eq 0 ]; then - echo "You can restore your unstaged changes with: git stash pop" - fi - exit 1 - fi - - # Go back to repo root to add build output + # Go back to repo root cd ../../.. - # Add the build output to staging area - git add tools/server/public/index.html.gz - - if [ $STASH_CREATED -eq 0 ]; then - echo "✅ Build completed. Your unstaged changes have been stashed." - echo "They will be automatically restored after the commit." - # Create a marker file to indicate stash was created by pre-commit hook - touch .git/WEBUI_STASH_MARKER - fi - - echo "Webui code formatted successfully" + echo "✅ Webui code formatted and checked successfully" fi exit 0 EOF -# Create the post-commit hook -cat > "$POST_COMMIT_HOOK" << 'EOF' +# Create the pre-push hook +cat > "$PRE_PUSH_HOOK" << 'EOF' #!/bin/bash -# Check if we have a stash marker from the pre-commit hook -if [ -f .git/WEBUI_STASH_MARKER ]; then - echo "Restoring your unstaged changes..." +# Check if there are any webui changes that need building +WEBUI_CHANGES=$(git diff --name-only @{push}..HEAD | grep "^tools/server/webui/" || true) + +if [ -n "$WEBUI_CHANGES" ]; then + echo "Webui changes detected, checking if build is up-to-date..." + + # Change to webui directory + cd tools/server/webui + + # Check if npm is available and package.json exists + if [ ! -f "package.json" ]; then + echo "Error: package.json not found in tools/server/webui" + exit 1 + fi + + # Check if build output exists and is newer than source files + BUILD_FILE="../public/index.html.gz" + NEEDS_BUILD=false + + if [ ! -f "$BUILD_FILE" ]; then + echo "Build output not found, building..." 
+ NEEDS_BUILD=true + else + # Check if any source files are newer than the build output + if find src -newer "$BUILD_FILE" -type f | head -1 | grep -q .; then + echo "Source files are newer than build output, rebuilding..." + NEEDS_BUILD=true + fi + fi + + if [ "$NEEDS_BUILD" = true ]; then + echo "Building webui..." + + # Stash any unstaged changes to avoid conflicts during build + echo "Checking for unstaged changes..." + if ! git diff --quiet || ! git diff --cached --quiet --diff-filter=A; then + echo "Stashing unstaged changes..." + git stash push --include-untracked -m "Pre-push hook: stashed unstaged changes" + STASH_CREATED=$? + else + echo "No unstaged changes to stash" + STASH_CREATED=1 + fi + + # Run the build command + npm run build + + # Check if build command succeeded + if [ $? -ne 0 ]; then + echo "Error: npm run build failed" + if [ $STASH_CREATED -eq 0 ]; then + echo "You can restore your unstaged changes with: git stash pop" + fi + exit 1 + fi + + # Go back to repo root + cd ../../.. + + # Check if build output was created/updated + if [ -f "tools/server/public/index.html.gz" ]; then + # Add the build output and commit it + git add tools/server/public/index.html.gz + if ! git diff --cached --quiet; then + echo "Committing updated build output..." + git commit -m "chore: update webui build output" + echo "✅ Build output committed successfully" + else + echo "Build output unchanged" + fi + else + echo "Error: Build output not found after build" + if [ $STASH_CREATED -eq 0 ]; then + echo "You can restore your unstaged changes with: git stash pop" + fi + exit 1 + fi + + if [ $STASH_CREATED -eq 0 ]; then + echo "✅ Build completed. Your unstaged changes have been stashed." + echo "They will be automatically restored after the push." + # Create a marker file to indicate stash was created by pre-push hook + touch .git/WEBUI_PUSH_STASH_MARKER + fi + else + echo "✅ Build output is up-to-date" + fi + + echo "✅ Webui ready for push" +fi + +exit 0 +EOF + +# Create the post-push hook (for restoring stashed changes after push) +cat > "$REPO_ROOT/.git/hooks/post-push" << 'EOF' +#!/bin/bash + +# Check if we have a stash marker from the pre-push hook +if [ -f .git/WEBUI_PUSH_STASH_MARKER ]; then + echo "Restoring your unstaged changes after push..." git stash pop - rm -f .git/WEBUI_STASH_MARKER + rm -f .git/WEBUI_PUSH_STASH_MARKER echo "✅ Your unstaged changes have been restored." fi exit 0 EOF -# Make both hooks executable +# Make all hooks executable chmod +x "$PRE_COMMIT_HOOK" -chmod +x "$POST_COMMIT_HOOK" +chmod +x "$PRE_PUSH_HOOK" +chmod +x "$REPO_ROOT/.git/hooks/post-push" if [ $? -eq 0 ]; then - echo "✅ Pre-commit and post-commit hooks installed successfully!" - echo " Pre-commit: $PRE_COMMIT_HOOK" - echo " Post-commit: $POST_COMMIT_HOOK" + echo "✅ Git hooks installed successfully!" + echo " Pre-commit: $PRE_COMMIT_HOOK" + echo " Pre-push: $PRE_PUSH_HOOK" + echo " Post-push: $REPO_ROOT/.git/hooks/post-push" echo "" echo "The hooks will automatically:" - echo " • Format, lint, check, and build webui code before commits" - echo " • Stash unstaged changes during the process" - echo " • Restore your unstaged changes after the commit" + echo " • Format and check webui code before commits (pre-commit)" + echo " • Build webui code before pushes (pre-push)" + echo " • Stash unstaged changes during build process" + echo " • Restore your unstaged changes after the push" echo "" - echo "To test the hooks, make a change to a file in the webui directory and commit it." 
+ echo "To test the hooks:" + echo " • Make a change to a file in the webui directory and commit it (triggers format/check)" + echo " • Push your commits to trigger the build process" else echo "❌ Failed to make hooks executable" exit 1 diff --git a/tools/server/webui/scripts/post-build.sh b/tools/server/webui/scripts/post-build.sh index ae5c1f30be..a49d6cc107 100755 --- a/tools/server/webui/scripts/post-build.sh +++ b/tools/server/webui/scripts/post-build.sh @@ -1,3 +1,3 @@ rm -rf ../public/_app; rm ../public/favicon.svg; -rm ../public/index.html; \ No newline at end of file +rm ../public/index.html; diff --git a/tools/server/webui/src/app.css b/tools/server/webui/src/app.css index a9ac80ab9c..2ca1536409 100644 --- a/tools/server/webui/src/app.css +++ b/tools/server/webui/src/app.css @@ -37,8 +37,9 @@ --sidebar-accent-foreground: oklch(0.205 0 0); --sidebar-border: oklch(0.922 0 0); --sidebar-ring: oklch(0.708 0 0); - --code-background: oklch(0.225 0 0); - --code-foreground: oklch(0.875 0 0); + --code-background: oklch(0.975 0 0); + --code-foreground: oklch(0.145 0 0); + --layer-popover: 1000000; } .dark { @@ -73,6 +74,8 @@ --sidebar-accent-foreground: oklch(0.985 0 0); --sidebar-border: oklch(1 0 0 / 10%); --sidebar-ring: oklch(0.556 0 0); + --code-background: oklch(0.225 0 0); + --code-foreground: oklch(0.875 0 0); } @theme inline { diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte index 287acac7a1..c16a3105cb 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte @@ -4,7 +4,6 @@ import ChatMessageBranchingControls from './ChatMessageBranchingControls.svelte'; interface Props { - message: DatabaseMessage; role: 'user' | 'assistant'; justify: 'start' | 'end'; actionsPosition: 'left' | 'right'; @@ -29,7 +28,6 @@ actionsPosition, deletionInfo, justify, - message, onCopy, onEdit, onConfirmDelete, @@ -49,26 +47,17 @@
-	<div>
-		{new Date(message.timestamp).toLocaleTimeString(undefined, {
-			hour: '2-digit',
-			minute: '2-digit'
-		})}
-	</div>
-
 	{#if siblingInfo && siblingInfo.totalSiblings > 1}
 		<ChatMessageBranchingControls {siblingInfo} />
 	{/if}
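The hunk above removes the inline timestamp (and the now-unused `message` prop) from the shared actions bar; the role-specific message components render timestamps themselves. For reference, a minimal TypeScript sketch of the formatting call the removed block used; the helper name is illustrative, and `timestamp` is assumed to be epoch milliseconds, as the `new Date(...)` call implies:

```ts
// Sketch only: locale-aware "HH:MM" formatting for a message timestamp.
// `formatMessageTime` is a hypothetical name, not part of the patch.
function formatMessageTime(timestamp: number): string {
    return new Date(timestamp).toLocaleTimeString(undefined, {
        hour: '2-digit',  // zero-padded hour, e.g. "09"
        minute: '2-digit' // zero-padded minute, e.g. "41"
    });
}

// Passing `undefined` as the locale defers to the browser default, so the
// output ("09:41" vs. "9:41 AM") depends on the user's environment.
console.log(formatMessageTime(Date.now()));
```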
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte
index 80a6bbaad4..ad3ffa3792 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte
@@ -3,12 +3,14 @@
 	import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
 	import { isLoading } from '$lib/stores/chat.svelte';
 	import { fade } from 'svelte/transition';
-	import { Check, X } from '@lucide/svelte';
+	import { Check, Copy, Package, X } from '@lucide/svelte';
 	import { Button } from '$lib/components/ui/button';
 	import { Checkbox } from '$lib/components/ui/checkbox';
 	import { INPUT_CLASSES } from '$lib/constants/input-classes';
 	import ChatMessageActions from './ChatMessageActions.svelte';
 	import Label from '$lib/components/ui/label/label.svelte';
+	import { config } from '$lib/stores/settings.svelte';
+	import { copyToClipboard } from '$lib/utils/copy';
 
 	interface Props {
 		class?: string;
@@ -136,9 +138,25 @@
 			{/if}
 
+			{#if config().showModelInfo && message.model}
+				<div>
+					<Package />
+					<span>
+						Model used:
+						{message.model}
+					</span>
+					<button onclick={() => copyToClipboard(message.model)}>
+						<Copy />
+					</button>
+				</div>
+			{/if}
+
+			{#if message.timestamp && !isEditing}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte
--- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte
+	$effect(() => {
+		if (!messageElement || !message.content.trim()) return;
+
+		if (message.content.includes('\n')) {
+			isMultiline = true;
+			return;
+		}
+
+		const resizeObserver = new ResizeObserver((entries) => {
+			for (const entry of entries) {
+				const element = entry.target as HTMLElement;
+				const estimatedSingleLineHeight = 24; // Typical line height for text-md
+
+				isMultiline = element.offsetHeight > estimatedSingleLineHeight * 1.5;
+			}
+		});
+
+		resizeObserver.observe(messageElement);
+
+		return () => {
+			resizeObserver.disconnect();
+		};
+	});
 	{#if isEditing}
@@ -92,10 +119,13 @@
 	{/if}
 
 	{#if message.content.trim()}
-		<p>
+		<p
+			bind:this={messageElement}>
 			{message.content}
-		</p>
+		</p>
 	{/if}
@@ -105,7 +135,6 @@
 			actionsPosition="right"
 			{deletionInfo}
 			justify="end"
-			{message}
 			{onConfirmDelete}
 			{onCopy}
 			{onDelete}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte
index 17215b713a..666febf0d2 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte
@@ -3,9 +3,11 @@
 	import {
 		ChatForm,
 		ChatScreenHeader,
+		ChatScreenWarning,
 		ChatMessages,
 		ChatProcessingInfo,
 		EmptyFileAlertDialog,
+		ServerErrorSplash,
 		ServerInfo,
 		ServerLoadingSplash,
 		ConfirmationDialog
@@ -29,6 +31,7 @@
 		supportsVision,
 		supportsAudio,
 		serverLoading,
+		serverWarning,
 		serverStore
 	} from '$lib/stores/server.svelte';
 	import { contextService } from '$lib/services';
@@ -303,6 +306,10 @@
 	>
 
+		{#if serverWarning()}
+			<ChatScreenWarning />
+		{/if}
+
+{:else if serverStore.error && !serverStore.modelName}
+	<ServerErrorSplash />
 {:else if serverStore.modelName}
+	{#if serverWarning()}
+		<ChatScreenWarning />
+	{/if}
+
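The ChatScreen changes split `/props` failures into two paths: a blocking `ServerErrorSplash` only when no model name is cached, and a non-blocking `ChatScreenWarning` banner when cached data can keep the chat usable. Below is a minimal sketch of that branching; the `ServerState` shape is an assumption, since the real accessors live in `$lib/stores/server.svelte`:

```ts
// Illustrative only: field names mirror the store accessors used above,
// but this ServerState shape is not the store's real type.
interface ServerState {
    modelName: string | null; // last model name obtained from /props
    error: string | null;     // last fetch error, if any
    warning: boolean;         // /props unreachable, but cached data exists
}

type Screen = 'error-splash' | 'chat-with-warning' | 'chat' | 'loading';

function resolveScreen(s: ServerState): Screen {
    if (s.error && !s.modelName) return 'error-splash'; // nothing cached: block the UI
    if (s.modelName) return s.warning ? 'chat-with-warning' : 'chat';
    return 'loading';
}

console.log(resolveScreen({ modelName: 'llama', error: null, warning: true }));
// -> "chat-with-warning"
```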
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte
new file mode 100644
--- /dev/null
+++ b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreenWarning.svelte
+<script lang="ts">
+	import { AlertTriangle, RefreshCw } from '@lucide/svelte';
+	import { serverLoading, serverStore } from '$lib/stores/server.svelte';
+	import { fly } from 'svelte/transition';
+
+	interface Props {
+		class?: string;
+	}
+
+	let { class: className = '' }: Props = $props();
+
+	function handleRefreshServer() {
+		serverStore.fetchServerProps();
+	}
+</script>
+
+<div class={className} transition:fly>
+	<div>
+		<AlertTriangle />
+
+		<span>
+			Server `/props` endpoint not available - using cached data
+		</span>
+
+		<button onclick={handleRefreshServer} disabled={serverLoading()}>
+			<RefreshCw />
+		</button>
+	</div>
+</div>
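The banner's refresh button calls `serverStore.fetchServerProps()`. A minimal sketch of how such a helper could behave, assuming it re-queries `/props` and keeps the warning up while the endpoint stays unreachable; only the endpoint path comes from the component above, and the state fields are illustrative:

```ts
// Hypothetical module-level state for the sketch; the real store keeps
// equivalent fields internally.
let serverLoadingFlag = false;
let warningFlag = false;
let cachedProps: unknown = null;

async function fetchServerProps(baseUrl = ''): Promise<void> {
    serverLoadingFlag = true;
    try {
        const res = await fetch(`${baseUrl}/props`);
        if (!res.ok) throw new Error(`HTTP ${res.status}`);
        cachedProps = await res.json(); // fresh data: the warning can clear
        warningFlag = false;
    } catch (err) {
        warningFlag = true; // keep serving cached data, leave the banner up
        console.warn('/props unavailable, using cached data', err);
    } finally {
        serverLoadingFlag = false;
    }
}
```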
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte
index 0d9cca0589..d832abc2e2 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsDialog.svelte
@@ -75,6 +75,11 @@
 				key: 'pdfAsImage',
 				label: 'Parse PDF as image',
 				type: 'checkbox'
+			},
+			{
+				key: 'showModelInfo',
+				label: 'Show model information',
+				type: 'checkbox'
 			}
 		]
 	},
@@ -362,7 +367,8 @@
@@ -441,7 +447,7 @@
-					<ChatSettingsFields {fields} {localConfig} {onConfigChange} {onThemeChange} isMobile />
+					<ChatSettingsFields {fields} {localConfig} {onConfigChange} {onThemeChange} />
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
index c1eeed3b32..e06399e0bc 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
@@ -13,10 +13,9 @@
 		localConfig: SettingsConfigType;
 		onConfigChange: (key: string, value: string | boolean) => void;
 		onThemeChange?: (theme: string) => void;
-		isMobile?: boolean;
 	}
 
-	let { fields, localConfig, onConfigChange, onThemeChange, isMobile = false }: Props = $props();
+	let { fields, localConfig, onConfigChange, onThemeChange }: Props = $props();
 
 {#each fields as field (field.key)}
@@ -28,10 +27,10 @@
 				onConfigChange(field.key, e.currentTarget.value)}
-				placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
-				class={isMobile ? 'w-full' : 'max-w-md'}
+				placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] ?? 'none'}`}
+				class="w-full md:max-w-md"
 			/>
 			{#if field.help || SETTING_CONFIG_INFO[field.key]}

@@ -45,10 +44,10 @@